Gemma

This model was released on 2024-03-13 and added to Hugging Face Transformers on 2024-02-21.

Supports: PyTorch, FlashAttention, SDPA, tensor parallelism

Gemma is a family of lightweight language models with pretrained and instruction-tuned variants, available in 2B and 7B parameter sizes. The architecture is a decoder-only transformer that uses rotary positional embeddings (RoPE), GeGLU activation functions, and RMSNorm normalization. The 2B model uses multi-query attention, while the 7B model uses multi-head attention.
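To make two of the components named above more concrete, here is a minimal, framework-free sketch of RMSNorm and GeGLU gating on plain Python lists. The function names (`rms_norm`, `gelu`, `geglu`), the `eps` value, and the sample inputs are assumptions chosen for illustration; the actual Gemma implementation in Transformers operates on tensors and may parameterize the gain and the GELU variant differently.

```python
import math


def rms_norm(x, gain, eps=1e-6):
    """Divide x by its root mean square, then apply a per-feature gain."""
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [v * inv_rms * g for v, g in zip(x, gain)]


def gelu(v):
    """Exact GELU, written with the Gaussian CDF via erf."""
    return 0.5 * v * (1.0 + math.erf(v / math.sqrt(2.0)))


def geglu(gate, up):
    """GeGLU gating: GELU(gate) multiplied elementwise by up."""
    return [gelu(g) * u for g, u in zip(gate, up)]


# Illustrative inputs (not taken from any real checkpoint).
hidden = [1.0, -2.0, 3.0, -4.0]
normed = rms_norm(hidden, gain=[1.0] * len(hidden))
gated = geglu(gate=[0.0, 1.0, -1.0], up=[2.0, 2.0, 2.0])

print(normed)
print(gated)
```

In a full feed-forward block, `gate` and `up` would be two separate linear projections of the same hidden state, and the gated result would pass through a third projection back to the hidden size.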

The instruction-tuned variant was fine-tuned with supervised learning on instruction-following data, followed by reinforcement learning from human feedback (RLHF) to align the model outputs with human preferences.

You can find all the original Gemma checkpoints under the Gemma release.

The example below demonstrates how to generate text with Pipeline or the AutoModel class, and from the command line.

# Pipeline
import torch
from transformers import pipeline

# Use a distinct variable name so the imported pipeline() function is not shadowed.
pipe = pipeline(
    task="text-generation",
    model="google/gemma-2b",
    dtype=torch.bfloat16,
    device_map="auto",
)
pipe("LLMs generate text through a process known as", max_new_tokens=50)

# AutoModel
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2b",
    dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="sdpa"
)

input_text = "LLMs generate text through a process known as"
input_ids = tokenizer(input_text, return_tensors="pt").to(model.device)
outputs = model.generate(**input_ids, max_new_tokens=50, cache_implementation="static")
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

# Command line
echo -e "LLMs generate text through a process known as" | transformers run --task text-generation --model google/gemma-2b --device 0

Quantization reduces the memory burden of large models by representing the weights in a lower precision. Refer to the Quantization overview for more available quantization backends.
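As a rough illustration of the memory saving, the arithmetic can be sketched as below. The helper name `weight_memory_gib` and the nominal count of 7 billion parameters (taken from the model name) are assumptions; the estimate covers weights only and ignores quantization constants, activations, and the KV cache, so real usage will be higher.

```python
def weight_memory_gib(num_params, bits_per_weight):
    """Approximate memory for the weights alone, in GiB."""
    return num_params * bits_per_weight / 8 / 2**30


# Assumption: a nominal 7e9 parameters, based only on the "7B" name.
nominal_params = 7e9

bf16_gib = weight_memory_gib(nominal_params, 16)
four_bit_gib = weight_memory_gib(nominal_params, 4)

print(f"bfloat16 weights: {bf16_gib:.2f} GiB")
print(f"4-bit weights:    {four_bit_gib:.2f} GiB")
```

Going from 16 bits to 4 bits per weight cuts the weight footprint by a factor of four under these assumptions.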

The example below uses bitsandbytes to quantize only the weights to 4-bit precision (NF4).

# pip install bitsandbytes
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

# Store weights in 4-bit NF4 and run computation in bfloat16.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_type="nf4"
)

tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b")
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-7b",
    quantization_config=quantization_config,
    device_map="auto",
    attn_implementation="sdpa"
)

input_text = "LLMs generate text through a process known as"
input_ids = tokenizer(input_text, return_tensors="pt").to(model.device)
outputs = model.generate(
    **input_ids,
    max_new_tokens=50,
    cache_implementation="static"
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

Use the AttentionMaskVisualizer to better understand what tokens the model can and cannot attend to.

from transformers.utils.attention_visualizer import AttentionMaskVisualizer
visualizer = AttentionMaskVisualizer("google/gemma-2b")
visualizer("LLMs generate text through a process known as")
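
For intuition about what such a visualization shows for a decoder-only model, the sketch below builds a plain causal mask by hand. It is a simplified illustration, not the output or implementation of AttentionMaskVisualizer; the helper names `causal_mask` and `render` are hypothetical.

```python
def causal_mask(num_tokens):
    """Row = query position, column = key position; True means attention is allowed."""
    return [[key <= query for key in range(num_tokens)] for query in range(num_tokens)]


def render(mask):
    """Draw allowed positions as 'x' and blocked (future) positions as '.'."""
    return "\n".join(
        " ".join("x" if allowed else "." for allowed in row) for row in mask
    )


mask = causal_mask(4)
print(render(mask))
```

Each token can attend to itself and to earlier tokens, so the allowed region forms a lower triangle.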
  • The original Gemma models support the standard KV caching used in many transformer-based language models. You can use the default DynamicCache instance or a tuple of tensors for past key values during generation, which makes the models compatible with typical autoregressive generation workflows.

    import torch
    from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache

    tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
    model = AutoModelForCausalLM.from_pretrained(
        "google/gemma-2b",
        dtype=torch.bfloat16,
        device_map="auto",
        attn_implementation="sdpa"
    )

    input_text = "LLMs generate text through a process known as"
    input_ids = tokenizer(input_text, return_tensors="pt").to(model.device)
    past_key_values = DynamicCache(config=model.config)
    outputs = model.generate(**input_ids, max_new_tokens=50, past_key_values=past_key_values)
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))
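
The idea behind KV caching can be illustrated with a toy, framework-free sketch: keys and values from earlier steps are stored once and reused, so each new step only adds one entry. `ToyKVCache` and `attend` are hypothetical names for illustration and do not reflect how DynamicCache is implemented; the vectors are made-up values for a single attention head.

```python
import math


class ToyKVCache:
    """Append-only store of per-token keys and values (illustration only)."""

    def __init__(self):
        self.keys = []
        self.values = []

    def update(self, key, value):
        # Store the new token's key/value and return everything cached so far.
        self.keys.append(key)
        self.values.append(value)
        return self.keys, self.values


def attend(query, keys, values):
    """Softmax-weighted average of values, scored by query-key dot products."""
    scores = [sum(q * k for q, k in zip(query, key)) for key in keys]
    top = max(scores)
    weights = [math.exp(s - top) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[i] for w, v in zip(weights, values)) / total for i in range(dim)]


# One (query, key, value) triple per decoding step; values are arbitrary.
steps = [
    ([1.0, 0.0], [1.0, 0.0], [1.0, 2.0]),
    ([0.0, 1.0], [0.0, 1.0], [3.0, 4.0]),
    ([1.0, 1.0], [1.0, 1.0], [5.0, 6.0]),
]

cache = ToyKVCache()
outputs = []
for query, key, value in steps:
    keys, values = cache.update(key, value)  # reuse earlier keys/values
    outputs.append(attend(query, keys, values))

print(outputs[-1])
```

Without a cache, the keys and values for all previous tokens would have to be recomputed at every step; with it, only the newest token's projections are computed.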

[[autodoc]] GemmaConfig

[[autodoc]] GemmaTokenizer

[[autodoc]] GemmaTokenizerFast

[[autodoc]] GemmaModel - forward

[[autodoc]] GemmaForCausalLM - forward

[[autodoc]] GemmaForSequenceClassification - forward

[[autodoc]] GemmaForTokenClassification - forward