"""Source code for mistral_v0_2.model.mistral_model."""
from jax import Array
from transformers.models.mistral.modeling_mistral import MistralModel
from .decoder import DecoderParams, convert_decoder_params, forward_decoder, shard_decoder_params
from .embedding import EmbeddingParams, convert_embedding_params, forward_embedding, shard_embedding_params
from .kvcache import KVCache
from .rms_norm import RMSNormParams, convert_rms_norm_params, forward_rms_norm, shard_rms_norm_params
from .rotary_embedding import RotaryValues
# Parameter bundle for the whole model, always in this order:
# (token embedding params, decoder params, final RMS-norm params).
# The functions in this module pack and unpack the tuple positionally.
MistralModelParams = tuple[EmbeddingParams, DecoderParams, RMSNormParams]
[docs]
def convert_mistral_model_params(model: MistralModel) -> MistralModelParams:
    """
    Build a ``MistralModelParams`` tuple from a ``transformers`` ``MistralModel``.

    Each sub-module of ``model`` is handed to its dedicated converter and the
    results are packed positionally.

    Args:
        model (MistralModel): Source model. Its ``embed_tokens``, ``layers``
            and ``norm`` attributes are read.

    Returns:
        MistralModelParams: ``(embedding, decoder, norm)`` converted parameters.
    """
    # Tuple elements are evaluated left to right, so the converters run in the
    # same order as before: embedding, decoder layers, final norm.
    return (
        convert_embedding_params(model.embed_tokens),
        convert_decoder_params(model.layers),
        convert_rms_norm_params(model.norm),
    )
def convert_back_mistral_model_params():
    """
    Placeholder with no implementation: takes no arguments and returns ``None``.

    NOTE(review): judging by the name, this is presumably meant to be the
    inverse of ``convert_mistral_model_params`` - confirm before relying on it.
    """
    return None
[docs]
def shard_mistral_model_params(params: MistralModelParams):
    """
    Apply the per-component shard helpers to every part of the model parameters.

    The tuple is unpacked, each component is passed to its matching
    ``shard_*_params`` helper, and the results are re-packed in the same order.
    How each component is actually partitioned is decided by those helpers.

    Args:
        params (MistralModelParams): ``(embedding, decoder, norm)`` parameters.

    Returns:
        MistralModelParams: The sharded ``(embedding, decoder, norm)`` parameters.
    """
    embedding_params, decoder_params, norm_params = params
    return (
        shard_embedding_params(embedding_params),
        shard_decoder_params(decoder_params),
        shard_rms_norm_params(norm_params),
    )
[docs]
def forward_mistral_model(
    params: MistralModelParams,
    input_ids: Array,
    qk_mask: Array,
    rotary_values: RotaryValues,
    kv_cache_pre: KVCache,
) -> tuple[Array, KVCache]:
    """
    Run the Mistral model forward pass: embedding, decoder, then final RMS norm.

    Args:
        params (MistralModelParams): ``(embedding, decoder, norm)`` parameters.
        input_ids (Array): Token ids passed to the embedding lookup.
        qk_mask (Array): Attention mask, forwarded unchanged to the decoder.
        rotary_values (RotaryValues): Rotary position values, forwarded
            unchanged to the decoder.
        kv_cache_pre (KVCache): KV cache passed in to the decoder.

    Returns:
        tuple[Array, KVCache]: The normalised decoder output and the KV cache
        returned by the decoder.
    """
    embedding_params, decoder_params, norm_params = params

    hidden = forward_embedding(embedding_params, input_ids)
    # The decoder hands back its own cache; keep it under a separate name so it
    # is not confused with the cache that was passed in.
    hidden, kv_cache = forward_decoder(decoder_params, hidden, qk_mask, rotary_values, kv_cache_pre)
    return forward_rms_norm(norm_params, hidden), kv_cache
def test_forward_mistral_model():
    """
    Placeholder test for ``forward_mistral_model``.

    It contains no assertions and returns ``None``, so it always passes.
    """
    return None