mistral_v0_2.model.mistral_lm
- mistral_v0_2.model.mistral_lm.convert_mistral_lm_params(model)
Converts the parameters of a MistralForCausalLM model (PyTorch tensors) to MistralLMParams (JAX arrays).
- Parameters:
model (MistralForCausalLM) – The PyTorch Mistral LM to convert.
- Returns:
The converted Mistral LM parameters.
- Return type:
MistralLMParams
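As a rough illustration of the conversion step, the sketch below maps a nested structure of NumPy arrays to JAX arrays with `jax.tree_util.tree_map`. It assumes the PyTorch weights have already been extracted as NumPy arrays (bfloat16 tensors need a cast such as `.float()` first, because `Tensor.numpy()` does not support bfloat16). The helper `to_jax_tree` and the dictionary keys are hypothetical and do not reflect the actual layout of `MistralLMParams`. The transpose reflects the fact that PyTorch `nn.Linear` stores weights as `(out_features, in_features)`; whether `convert_mistral_lm_params` transposes in the same way is an assumption here.

```python
import numpy as np
import jax
import jax.numpy as jnp

def to_jax_tree(tree, dtype=jnp.float32):
    """Hypothetical helper: convert every NumPy leaf of a nested structure to a JAX array."""
    return jax.tree_util.tree_map(lambda x: jnp.asarray(x, dtype=dtype), tree)

# Stand-ins for weights read from a PyTorch model, e.g. via
# `model.state_dict()[name].float().cpu().numpy()`. Key names are illustrative only.
embed_tokens = np.ones((8, 4), dtype=np.float32)               # (vocab_size, hidden_size)
q_proj_weight = np.arange(12, dtype=np.float32).reshape(3, 4)  # nn.Linear layout: (out, in)

params = to_jax_tree({
    "embedding": embed_tokens,
    # Transpose to (in, out) so the projection can be applied as `x @ w` in JAX.
    "q_proj": q_proj_weight.T,
})
print(params["q_proj"].shape)
```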
- mistral_v0_2.model.mistral_lm.forward_mistral_lm(params, input_ids, qk_mask, rotary_values, kv_cache_pre)
Executes the forward pass of the Mistral LM.
- Parameters:
params (MistralLMParams) – The Mistral LM parameters.
input_ids (Array) – The input token IDs.
qk_mask (Array) – The query-key mask for the attention mechanism, determining which positions in the sequence are allowed to attend to each other.
rotary_values (RotaryValues) – Rotary positional embedding values.
kv_cache_pre (KVCache) – The previous key-value cache.
- Returns:
A tuple containing the output sequence of the Mistral LM and the updated KVCache.
- Return type:
tuple[Array, KVCache]
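To make the role of `qk_mask` concrete, here is a minimal sketch of how a boolean query-key mask can be built from a padding mask and applied inside single-head attention. The mask shape `(batch, 1, seq_len, seq_len)`, the helper names `make_qk_mask` and `masked_attention`, and the choice to fill masked scores with the dtype's minimum value are assumptions for illustration; they are not taken from this module.

```python
import jax
import jax.numpy as jnp

def make_qk_mask(attn_mask):
    """Hypothetical helper: build a boolean qk mask from a padding mask.

    attn_mask: bool array (batch, seq_len), True for real tokens.
    Returns a bool array (batch, 1, seq_len, seq_len); entry [b, 0, i, j]
    is True when query position i may attend to key position j.
    """
    seq_len = attn_mask.shape[-1]
    positions = jnp.arange(seq_len)
    causal = positions[:, None] >= positions[None, :]  # key j is not after query i
    return causal[None, None, :, :] & attn_mask[:, None, None, :]

def masked_attention(q, k, v, qk_mask):
    """Single-head scaled dot-product attention; q, k, v: (batch, 1, seq_len, d)."""
    scores = q @ jnp.swapaxes(k, -1, -2) / jnp.sqrt(q.shape[-1])
    # Most negative finite value instead of -inf, so fully masked rows stay finite.
    scores = jnp.where(qk_mask, scores, jnp.finfo(scores.dtype).min)
    weights = jax.nn.softmax(scores, axis=-1)
    return weights @ v

# Usage: a batch with one sequence whose last token is padding.
attn_mask = jnp.array([[True, True, False]])
qk_mask = make_qk_mask(attn_mask)
q = k = jnp.ones((1, 1, 3, 2), dtype=jnp.float32)
v = jnp.arange(6, dtype=jnp.float32).reshape(1, 1, 3, 2)
out = masked_attention(q, k, v, qk_mask)
```

Filling with a large finite negative value rather than `-inf` keeps rows that are fully masked (such as padded query positions under left padding) from producing NaN after the softmax.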
- mistral_v0_2.model.mistral_lm.shard_mistral_lm_params(params)
Shards the MistralLMParams for distributed computation.
- Parameters:
params (MistralLMParams) – The Mistral LM parameters.
- Returns:
The Mistral LM parameters, sharded for tensor parallelism so that computation can be distributed across multiple devices.
- Return type:
MistralLMParams
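A minimal sketch of tensor-parallel sharding with `jax.sharding`, assuming a one-dimensional device mesh with a hypothetical axis name `"model"`. It splits a projection weight along its output dimension (column-parallel, in the Megatron style) and replicates a small array across all devices. Which weights `shard_mistral_lm_params` actually splits, and along which axis, is not stated here, so treat the partition specs below as assumptions.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# One-dimensional mesh over all available devices; "model" is a hypothetical axis name.
devices = jax.devices()
n_devices = len(devices)
mesh = Mesh(np.array(devices), axis_names=("model",))

def shard_columns(w):
    """Split an (in_features, out_features) weight along out_features (column-parallel)."""
    return jax.device_put(w, NamedSharding(mesh, P(None, "model")))

def replicate(x):
    """Place a full copy of x on every device in the mesh."""
    return jax.device_put(x, NamedSharding(mesh, P()))

# Illustrative shapes only: the output dimension is a multiple of the device count.
w = jnp.arange(4 * 8 * n_devices, dtype=jnp.float32).reshape(4, 8 * n_devices)
w_sharded = shard_columns(w)
norm_scale = replicate(jnp.ones((4,), dtype=jnp.float32))

# Ordinary jnp operations accept sharded arrays; JAX inserts any needed communication.
x = jnp.ones((2, 4), dtype=jnp.float32)
y = (x * norm_scale) @ w_sharded
print(y.shape, w_sharded.sharding.spec)
```

On a machine with a single device the mesh has one entry and the single shard holds the full weight, so the sketch still runs; with N devices each shard holds 8 of the 8 * N output columns.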
- mistral_v0_2.model.mistral_lm.test_forward_mistral_lm(model)
Tests the forward pass of the Mistral LM.
This function validates the correctness of the JAX implementation by comparing it against a PyTorch Mistral model.
- Parameters:
model (MistralForCausalLM) – The PyTorch Mistral model to compare against this implementation.
- Returns:
None.
- Return type:
None
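A test like this typically runs both models on the same input and checks that their outputs agree within a tolerance. The sketch below shows one such check, assuming the PyTorch output has already been converted to NumPy; `max_abs_diff`, `outputs_match`, and the `1e-3` tolerance are hypothetical choices, not values taken from `test_forward_mistral_lm`. It compares only non-padded positions, since outputs at padded positions carry no meaning and may legitimately differ between implementations.

```python
import numpy as np
import jax.numpy as jnp

def max_abs_diff(jax_out, ref_out, attn_mask):
    """Largest absolute difference between two (batch, seq_len, hidden) outputs,
    taken only over positions where attn_mask (batch, seq_len) is True."""
    a = np.asarray(jax_out, dtype=np.float32)
    b = np.asarray(ref_out, dtype=np.float32)
    mask = np.asarray(attn_mask, dtype=bool)
    return float(np.abs(a - b)[mask].max())

def outputs_match(jax_out, ref_out, attn_mask, atol=1e-3):
    """Hypothetical pass/fail check with an illustrative absolute tolerance."""
    return max_abs_diff(jax_out, ref_out, attn_mask) <= atol

# Usage: the last position is padding, so its large mismatch is ignored.
ref_out = np.zeros((1, 3, 2), dtype=np.float32)
jax_out = jnp.array([[[5e-4, 0.0], [0.0, 0.0], [9.0, 9.0]]], dtype=jnp.float32)
attn_mask = np.array([[True, True, False]])
print(outputs_match(jax_out, ref_out, attn_mask))
```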