Mamba 2
This model was released on 2024-05-31 and added to Hugging Face Transformers on 2024-08-06.
Mamba 2 is based on the state space duality (SSD) framework, which connects structured state space models (SSMs) and attention variants. It uses a more efficient SSD algorithm that is 2-8x faster than Mamba and modifies the architecture to enable tensor parallelism and a grouped-value attention (GVA) head structure.
You can find all the original Mamba 2 checkpoints under the State Space Models organization, but the examples shown below use mistralai/Mamba-Codestral-7B-v0.1 because the original checkpoints aren't yet supported by the Hugging Face implementation.
Other Mamba 2-based architectures include Bamba, FalconH1, and Zamba2.
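A quick way to see the grouped head structure described above is to inspect the model configuration. The sketch below instantiates a default `Mamba2Config`; the attribute names (`n_groups`, `num_heads`, `head_dim`, `state_size`) are assumed from the current Transformers implementation, and the printed defaults are illustrative only.

```py
from transformers import Mamba2Config

# Default configuration; Codestral Mamba ships its own values in its config.json.
config = Mamba2Config()

# Fields that define the grouped (GVA-like) head structure and the SSM state.
# Attribute names are assumed from the current Transformers implementation.
print(config.n_groups)    # number of groups shared across heads (8 for Codestral Mamba)
print(config.num_heads)   # number of SSM heads
print(config.head_dim)    # per-head dimension
print(config.state_size)  # SSM state dimension
```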
The example below demonstrates how to generate text with Pipeline, AutoModel, and from the command line.
```py
import torch
from transformers import pipeline

pipeline = pipeline(
    task="text-generation",
    model="mistralai/Mamba-Codestral-7B-v0.1",
    dtype=torch.bfloat16,
    device=0
)
pipeline("Plants create energy through a process known as")
```

```py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mamba-Codestral-7B-v0.1")
model = AutoModelForCausalLM.from_pretrained("mistralai/Mamba-Codestral-7B-v0.1", dtype=torch.bfloat16, device_map="auto")
input_ids = tokenizer("Plants create energy through a process known as", return_tensors="pt").to(model.device)

output = model.generate(**input_ids)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```

```bash
echo -e "Plants create energy through a process known as" | transformers run --task text-generation --model mistralai/Mamba-Codestral-7B-v0.1 --device 0
```

Quantization reduces the memory burden of large models by representing the weights in a lower precision. Refer to the Quantization overview for more available quantization backends.
The example below uses torchao to quantize only the weights to 4-bit integers.
```py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig

quantization_config = TorchAoConfig("int4_weight_only", group_size=128)
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mamba-Codestral-7B-v0.1")
model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mamba-Codestral-7B-v0.1",
    dtype=torch.bfloat16,
    quantization_config=quantization_config,
    device_map="auto"
)
input_ids = tokenizer("Plants create energy through a process known as", return_tensors="pt").to(model.device)

output = model.generate(**input_ids)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```
- Codestral Mamba has `groups=8`, which is similar to the number of kv heads in an attention-based model.
- Codestral Mamba has two different forward passes, `torch_forward` or `cuda_kernels_forward`, and their results are expected to be slightly different.
    - `torch_forward` without compilation is 3-4x faster than `cuda_kernels_forward`.
    - `cuda_kernels_forward` uses the original CUDA kernels if they're available in your environment. It is slower during prefill because it requires a "warmup run" due to the higher CPU overhead (see these comments for more details). A quick availability check is sketched after the fine-tuning example below.
- There are no positional embeddings in this model, but there is an `attention_mask` and specific logic to mask out hidden states in two places in the case of batched generation (see this comment for more details). This (and the addition of the reimplemented Mamba 2 kernels) results in a slight discrepancy between batched and cached generation.
- The SSM algorithm heavily relies on tensor contractions, which have matmul equivalents, but the order of operations is slightly different. This makes the difference greater at smaller precisions. A small numerical illustration follows the fine-tuning example below.
- Hidden states that correspond to padding tokens are shut down in two places, and this is mostly tested with left-padding. Right-padding propagates noise down the line and is not guaranteed to yield satisfactory results. Setting `tokenizer.padding_side = "left"` ensures you are using the correct padding side; a left-padded batched generation sketch follows the fine-tuning example below.
- The example below demonstrates how to fine-tune Mamba 2 with PEFT.
```py
from datasets import load_dataset
from peft import LoraConfig
from trl import SFTConfig, SFTTrainer

model_id = "mistralai/Mamba-Codestral-7B-v0.1"
dataset = load_dataset("Abirate/english_quotes", split="train")
training_args = SFTConfig(dataset_text_field="quote", gradient_checkpointing=True, per_device_train_batch_size=4)
lora_config = LoraConfig(target_modules=["x_proj", "embeddings", "in_proj", "out_proj"])
trainer = SFTTrainer(
    model=model_id,
    args=training_args,
    train_dataset=dataset,
    peft_config=lora_config,
)
trainer.train()
```
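As a companion to the forward-pass note above, the sketch below checks whether the optional kernel packages are importable. It assumes, as that note implies, that `cuda_kernels_forward` is only used when the original kernels (distributed in the `mamba-ssm` and `causal-conv1d` packages) are installed, and that the model otherwise falls back to `torch_forward`.

```py
import importlib.util

# Package/module names assumed: `mamba-ssm` (module `mamba_ssm`) and
# `causal-conv1d` (module `causal_conv1d`).
has_mamba_ssm = importlib.util.find_spec("mamba_ssm") is not None
has_causal_conv1d = importlib.util.find_spec("causal_conv1d") is not None

if has_mamba_ssm and has_causal_conv1d:
    print("Original CUDA kernels found: cuda_kernels_forward can be used.")
else:
    print("Optional kernels missing: falling back to the pure PyTorch torch_forward path.")
```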
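To illustrate the tensor-contraction note above, here is a minimal, model-agnostic sketch: the same chain of matrix products computed in two mathematically equivalent orders diverges slightly, and the gap typically grows at lower precision. The shapes and values are arbitrary and purely illustrative.

```py
import torch

torch.manual_seed(0)
a = torch.randn(64, 128)
b = torch.randn(128, 256)
c = torch.randn(256, 32)

def max_gap(dtype):
    x, y, z = a.to(dtype), b.to(dtype), c.to(dtype)
    # Two associativity orders of the same contraction.
    left = (x @ y) @ z
    right = x @ (y @ z)
    return (left.float() - right.float()).abs().max().item()

print("float32 :", max_gap(torch.float32))   # small numerical gap
print("bfloat16:", max_gap(torch.bfloat16))  # typically a noticeably larger gap
```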
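For the padding note above, the sketch below runs batched generation with left-padding. The prompts and generation settings are illustrative; the Mamba 2-specific point is setting `tokenizer.padding_side = "left"` (and a pad token, if the tokenizer doesn't define one) before tokenizing the batch.

```py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mamba-Codestral-7B-v0.1")
model = AutoModelForCausalLM.from_pretrained("mistralai/Mamba-Codestral-7B-v0.1", dtype=torch.bfloat16, device_map="auto")

# Left-padding is the tested configuration; right-padding propagates noise.
tokenizer.padding_side = "left"
if tokenizer.pad_token is None:  # reuse EOS as padding if no pad token is defined
    tokenizer.pad_token = tokenizer.eos_token

prompts = [
    "Plants create energy through a process known as",
    "The capital of France is",
]
inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```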
Mamba2Config
[[autodoc]] Mamba2Config
Mamba2Model
[[autodoc]] Mamba2Model
    - forward
Mamba2LMHeadModel
[[autodoc]] Mamba2ForCausalLM
    - forward