# Source code for mistral_v0_2.model.mistral_lm

from einshard import einshard
import jax
from jax import Array
import jax.numpy as jnp
from transformers import MistralForCausalLM

from .kvcache import KVCache
from .mistral_model import MistralModelParams, convert_mistral_model_params, forward_mistral_model, shard_mistral_model_params
from .rotary_embedding import RotaryValues, make_rotary_values
from ..lib.array_conversion import pt2jax

# Parameters of the Mistral LM as a pair: (Mistral model parameters, LM head array).
MistralLMParams = tuple[MistralModelParams, Array]

def convert_mistral_lm_params(model: MistralForCausalLM) -> MistralLMParams:
    """
    Build JAX ``MistralLMParams`` from a PyTorch ``MistralForCausalLM``.

    Args:
        model (MistralForCausalLM): The PyTorch Mistral causal LM to convert.

    Returns:
        MistralLMParams: A ``(model_params, lm_head)`` pair. ``model_params`` is the result of
        ``convert_mistral_model_params(model.model)``; ``lm_head`` is the transposed
        ``model.lm_head.weight`` converted with ``pt2jax``.
    """
    backbone = model.model
    head = model.lm_head
    # The head weight is transposed before conversion so that the forward pass
    # can apply it as ``outputs @ lm_head``.
    return convert_mistral_model_params(backbone), pt2jax(head.weight.T)
def convert_back_mistral_lm_params(params: MistralLMParams) -> MistralForCausalLM:
    """
    Convert MistralLMParams (JAX Array) back to a MistralForCausalLM (PyTorch tensor).

    This conversion is not implemented yet. The function raises instead of
    silently returning ``None``, which would contradict the declared return
    type and only fail later at the first use of the result.

    Args:
        params (MistralLMParams): The Mistral LM parameters to convert back.

    Returns:
        MistralForCausalLM: The converted PyTorch model (once implemented).

    Raises:
        NotImplementedError: Always, until the reverse conversion is implemented.
    """
    raise NotImplementedError('convert_back_mistral_lm_params is not implemented yet')
def shard_mistral_lm_params(params: MistralLMParams) -> MistralLMParams:
    """
    Shard the Mistral LM parameters for use across multiple devices.

    Args:
        params (MistralLMParams): The ``(model_params, lm_head)`` pair to shard.

    Returns:
        MistralLMParams: A new pair in which the model parameters have been passed through
        ``shard_mistral_model_params`` and the LM head through ``einshard`` with the
        pattern ``'... -> * ...'``.
    """
    backbone, head = params
    # NOTE(review): the exact device placement produced by the '... -> * ...'
    # pattern is defined by einshard - confirm against its documentation.
    return shard_mistral_model_params(backbone), einshard(head, '... -> * ...')
def forward_mistral_lm(params: MistralLMParams, input_ids: Array, qk_mask: Array, rotary_values: RotaryValues, kv_cache_pre: KVCache) -> tuple[Array, KVCache]:
    """
    Run the forward pass of the Mistral LM: the Mistral model followed by the LM head.

    Args:
        params (MistralLMParams): The ``(model_params, lm_head)`` pair.
        input_ids (Array): The input token IDs, forwarded to ``forward_mistral_model``.
        qk_mask (Array): The qk mask for the attention mechanism, forwarded to ``forward_mistral_model``.
        rotary_values (RotaryValues): Rotary positional embedding values.
        kv_cache_pre (KVCache): The KV cache passed in to ``forward_mistral_model``.

    Returns:
        tuple[Array, KVCache]: The logits (model outputs multiplied by the LM head) and the
        KV cache returned by ``forward_mistral_model``.
    """
    backbone, head = params
    hidden, kv_cache = forward_mistral_model(backbone, input_ids, qk_mask, rotary_values, kv_cache_pre)
    # Project the model outputs with the LM head to obtain the logits.
    return hidden @ head, kv_cache
def test_forward_mistral_lm(model: MistralForCausalLM) -> None:
    """
    Check that the JAX Mistral LM forward pass matches the PyTorch model.

    The same tokenized sentence is run through the PyTorch model and through
    ``forward_mistral_lm``; outputs at padded positions are zeroed on both sides
    and the results are compared with an absolute tolerance of ``1e-5``.

    Args:
        model (MistralForCausalLM): PyTorch Mistral model to compare with this implementation.

    Returns:
        None.

    Raises:
        AssertionError: If the two outputs differ by more than the tolerance.
    """
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained('mistralai/Mistral-7B-v0.1')
    # Reuse the EOS token as the padding token for tokenization.
    tokenizer.pad_token = tokenizer.eos_token

    encoded = tokenizer(['I have a cat.'], padding=True, return_tensors='pt')
    ids_pt = encoded.input_ids
    mask_pt = encoded.attention_mask

    # Reference output from the PyTorch model.
    expected = pt2jax(model(ids_pt, mask_pt)[0])

    # load on CPU first to avoid OOM
    # NOTE(review): only the conversion is assumed to run under the CPU default
    # device, with sharding applied afterwards - confirm against the upstream file.
    with jax.default_device(jax.devices('cpu')[0]):
        params = convert_mistral_lm_params(model)
    params = shard_mistral_lm_params(params)

    ids = pt2jax(ids_pt)
    mask = pt2jax(mask_pt).astype(jnp.bool_)

    # Pairwise validity of (query, key) positions, restricted to the lower triangle.
    pairwise = jnp.einsum('bi,bj->bij', mask, mask)
    qk_mask = jnp.tril(pairwise)[:, None, None]

    n_batch, n_tokens = ids.shape
    rotary_values = make_rotary_values(n_batch, n_tokens)

    actual, _ = forward_mistral_lm(params, ids, qk_mask, rotary_values, None)

    # Zero out padded positions on both sides before comparing.
    keep = mask[:, :, None]
    expected = jnp.where(keep, expected, 0.)
    actual = jnp.where(keep, actual, 0.)
    assert jnp.allclose(expected, actual, atol=1e-5)