mistral_v0_2.model.mistral_model

mistral_v0_2.model.mistral_model.convert_mistral_model_params(model)

Converts a MistralModel (PyTorch tensors) to MistralModelParams (JAX arrays).

Parameters:

model (MistralModel) – Mistral v0.2 model.

Returns:

The converted Mistral parameters.

Return type:

MistralModelParams
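The tensor-to-array conversion described above can be illustrated with a minimal sketch. It is not the library's implementation: the dictionary layout, the parameter names, and the helper `to_jax_array` are hypothetical, and NumPy arrays stand in for PyTorch weights (with real tensors, one would typically obtain such arrays via `tensor.detach().cpu().numpy()`).

```python
import jax
import jax.numpy as jnp
import numpy as np


def to_jax_array(weight: np.ndarray) -> jax.Array:
    """Copy a host-side weight into a JAX array (hypothetical helper)."""
    return jnp.asarray(weight)


# Stand-ins for PyTorch weights that have already been moved to NumPy.
# The names and shapes are illustrative, not the real MistralModel layout.
torch_like_weights = {
    "embedding": np.arange(6, dtype=np.float32).reshape(3, 2),
    "norm": np.ones(2, dtype=np.float32),
}

# Convert every weight, preserving the container structure.
params = {name: to_jax_array(w) for name, w in torch_like_weights.items()}
```

The actual MistralModelParams is a dedicated type rather than a plain dictionary; the sketch only shows the per-weight conversion step.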

mistral_v0_2.model.mistral_model.forward_mistral_model(params, input_ids, qk_mask, rotary_values, kv_cache_pre)

Executes the forward pass of the Mistral model.

Parameters:
  • params (MistralModelParams) – The Mistral model parameters.

  • input_ids (Array) – The input token IDs passed to the model.

  • qk_mask (Array) – The qk mask for the attention mechanism, determining which parts of the sequence are allowed to attend to each other.

  • rotary_values (RotaryValues) – Rotary positional embedding values.

  • kv_cache_pre (KVCache) – The previous KVCache.

Returns:

A tuple containing the output sequence of the Mistral model and the KVCache.

Return type:

tuple[Array, KVCache]
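To illustrate how a qk mask determines which positions may attend to each other, here is a minimal sketch of a plain causal mask applied to attention scores. The mask shape and dtype that forward_mistral_model actually expects are not stated in this reference, so the `(seq_len, seq_len)` boolean layout and both helper functions are assumptions made for illustration only.

```python
import jax
import jax.numpy as jnp


def make_causal_qk_mask(seq_len: int) -> jax.Array:
    """Boolean mask where query position i may attend to key positions j <= i."""
    return jnp.tril(jnp.ones((seq_len, seq_len), dtype=jnp.bool_))


def masked_attention_weights(scores: jax.Array, qk_mask: jax.Array) -> jax.Array:
    """Set disallowed positions to -inf, then normalize over the key axis."""
    masked_scores = jnp.where(qk_mask, scores, -jnp.inf)
    return jax.nn.softmax(masked_scores, axis=-1)


qk_mask = make_causal_qk_mask(4)
scores = jnp.zeros((4, 4))  # uniform scores keep the effect of the mask visible
weights = masked_attention_weights(scores, qk_mask)
```

With uniform scores, the first query position places all of its weight on itself, while the last spreads its weight evenly over all four positions.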

mistral_v0_2.model.mistral_model.shard_mistral_model_params(params)

Shards the MistralModelParams for distributed computing.

Parameters:

params (MistralModelParams) – The Mistral parameters.

Returns:

The Mistral parameters sharded with tensor parallelism, allowing distributed computation across multiple devices.

Return type:

MistralModelParams
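As a rough illustration of tensor parallelism in JAX, the sketch below splits a single weight matrix column-wise across all available devices. It is not how shard_mistral_model_params is implemented: the mesh axis name `"model"`, the helper `shard_columns`, and the choice of which axis to split are assumptions made for this example.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# One-dimensional device mesh; "model" is an illustrative axis name.
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("model",))


def shard_columns(weight: jax.Array) -> jax.Array:
    """Split the last axis of a 2-D weight across the "model" mesh axis."""
    sharding = NamedSharding(mesh, PartitionSpec(None, "model"))
    return jax.device_put(weight, sharding)


# The column count is a multiple of the device count so the split is even.
weight = jnp.ones((8, 4 * len(devices)))
weight_sharded = shard_columns(weight)
```

Each device then holds an `(8, 4)` slice of the matrix. On a single-device host the code still runs, with the whole matrix placed on that one device.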