
Mamba 2

This model was released on 2024-05-31 and added to Hugging Face Transformers on 2024-08-06.

PyTorch

Mamba 2 is based on the state space duality (SSD) framework which connects structured state space models (SSMs) and attention variants. It uses a more efficient SSD algorithm that is 2-8x faster than Mamba and modifies the architecture to enable tensor parallelism and a grouped-value attention (GVA) head structure.
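
At its core, SSD observes that the selective SSM recurrence, when unrolled, is equivalent to multiplying the inputs by a masked, attention-like matrix. The snippet below is a minimal single-head sketch of that equivalence with a scalar-per-step state decay; all shapes and variable names are illustrative and it is not the code path used by Transformers.

import torch

# Minimal single-head sketch of the SSD recurrence and its attention-like dual.
# a_t is a scalar decay per step (scalar-times-identity state matrix), while
# B_t and C_t play roles analogous to keys and queries.
T, d_state, d_head = 8, 16, 32
a = 0.9 * torch.rand(T) + 0.05     # per-step decay in (0.05, 0.95)
B = torch.randn(T, d_state)        # input projections ("keys")
C = torch.randn(T, d_state)        # output projections ("queries")
x = torch.randn(T, d_head)         # inputs for this head ("values")

# Linear-time recurrent form: h_t = a_t * h_{t-1} + B_t x_t^T, y_t = C_t h_t
h = torch.zeros(d_state, d_head)
y_recurrent = []
for t in range(T):
    h = a[t] * h + torch.outer(B[t], x[t])
    y_recurrent.append(C[t] @ h)
y_recurrent = torch.stack(y_recurrent)

# Quadratic "dual" form: y = M @ x with M[i, j] = (C_i . B_j) * prod(a[j+1 : i+1])
cum_log_a = torch.cumsum(torch.log(a), dim=0)
decay = torch.exp(cum_log_a[:, None] - cum_log_a[None, :])
M = torch.tril((C @ B.T) * decay)
y_dual = M @ x

print(torch.allclose(y_recurrent, y_dual, atol=1e-4))  # both forms agree

The first form corresponds to a sequential scan; the second exposes the matmul structure that the SSD algorithm exploits (in practice, block by block) on hardware accelerators.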

You can find all the original Mamba 2 checkpoints under the State Space Models organization, but the examples below use mistralai/Mamba-Codestral-7B-v0.1 because the original checkpoints aren't yet supported by the Hugging Face implementation.

Other Mamba 2-based architectures include Bamba, FalconH1, and Zamba2.

The examples below demonstrate how to generate text with Pipeline or AutoModel, and from the command line.

import torch
from transformers import pipeline
pipeline = pipeline(
    task="text-generation",
    model="mistralai/Mamba-Codestral-7B-v0.1",
    dtype=torch.bfloat16,
    device=0
)
pipeline("Plants create energy through a process known as")
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mamba-Codestral-7B-v0.1")
model = AutoModelForCausalLM.from_pretrained("mistralai/Mamba-Codestral-7B-v0.1", dtype=torch.bfloat16, device_map="auto")
input_ids = tokenizer("Plants create energy through a process known as", return_tensors="pt").to(model.device)
output = model.generate(**input_ids)
print(tokenizer.decode(output[0], skip_special_tokens=True))
Terminal window
echo -e "Plants create energy through a process known as" | transformers run --task text-generation --model mistralai/Mamba-Codestral-7B-v0.1 --device 0

Quantization reduces the memory burden of large models by representing the weights in a lower precision. Refer to the Quantization overview for more available quantization backends.

The example below uses torchao to quantize only the weights to 4-bit integers.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig
quantization_config = TorchAoConfig("int4_weight_only", group_size=128)
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mamba-Codestral-7B-v0.1")
model = AutoModelForCausalLM.from_pretrained("mistralai/Mamba-Codestral-7B-v0.1", dtype=torch.bfloat16, quantization_config=quantization_config, device_map="auto")
input_ids = tokenizer("Plants create energy through a process known as", return_tensors="pt").to(model.device)
output = model.generate(**input_ids)
print(tokenizer.decode(output[0], skip_special_tokens=True))
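
To see how much memory the quantized weights actually use, you can check the loaded model's footprint. Continuing from the snippet above, get_memory_footprint reports the size of the model's parameters and buffers in bytes.

# Continues the quantization example above.
print(f"Memory footprint: {model.get_memory_footprint() / 1e9:.2f} GB")
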
  • Codestral Mamba has groups=8, which is analogous to the number of KV heads in an attention-based model.

  • Codestral Mamba has two different forward passes, torch_forward and cuda_kernels_forward, and their results are expected to differ slightly.

    • torch_forward without compilation is 3-4x faster than cuda_kernels_forward.
    • cuda_kernels_forward uses the original CUDA kernels if they’re available in your environment. It is slower during prefill because it requires a “warmup run” due to the higher CPU overhead (see these comments for more details).
  • There are no positional embeddings in this model, but there is an attention_mask and specific logic to mask out hidden states in two places during batched generation (see this comment for more details). This, along with the reimplemented Mamba 2 kernels, results in a slight discrepancy between batched and cached generation.

  • The SSM algorithm heavily relies on tensor contractions, which have matmul equivalents but differ in their order of operations. This makes the discrepancy larger at lower precisions.

  • Hidden states that correspond to padding tokens are masked out in two places, and this is mostly tested with left-padding. Right-padding propagates noise down the line and is not guaranteed to yield satisfactory results. Set tokenizer.padding_side = "left" to make sure you are using the correct padding side (see the batched generation sketch after the fine-tuning example below).

  • The example below demonstrates how to fine-tune Mamba 2 with PEFT.

from datasets import load_dataset
from peft import LoraConfig
from trl import SFTConfig, SFTTrainer
model_id = "mistralai/Mamba-Codestral-7B-v0.1"
dataset = load_dataset("Abirate/english_quotes", split="train")
training_args = SFTConfig(dataset_text_field="quote", gradient_checkpointing=True, per_device_train_batch_size=4)
lora_config = LoraConfig(target_modules=["x_proj", "embeddings", "in_proj", "out_proj"])
trainer = SFTTrainer(
    model=model_id,
    args=training_args,
    train_dataset=dataset,
    peft_config=lora_config,
)
trainer.train()
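
As noted above, batched generation relies on left-padding and the attention_mask to mask out padding positions. The snippet below is a minimal sketch of batched generation with the same checkpoint as the earlier examples; the prompts and the fallback of the pad token to the EOS token are illustrative assumptions.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "mistralai/Mamba-Codestral-7B-v0.1"
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.padding_side = "left"    # right-padding propagates noise (see notes above)
if tokenizer.pad_token is None:    # assumption: fall back to EOS if no pad token is defined
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(model_id, dtype=torch.bfloat16, device_map="auto")

prompts = [
    "Plants create energy through a process known as",
    "Mamba 2 differs from attention-based models because",
]
inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))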

[[autodoc]] Mamba2Config

[[autodoc]] Mamba2Model - forward

[[autodoc]] Mamba2ForCausalLM - forward