mistral_v0_2.model.decoder

mistral_v0_2.model.decoder.convert_decoder_params(layers)

Converts the decoder layers (a TorchModuleList holding PyTorch tensors) to DecoderParams (JAX arrays).

Parameters:

layers (TorchModuleList) – The decoder layers of the PyTorch model.

Returns:

The converted decoder parameters.

Return type:

DecoderParams
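The conversion can be sketched as follows. This is a minimal illustration under stated assumptions, not the module's implementation: the helper names `tensor_to_array` and `convert_layers` are hypothetical, each layer is assumed to be either a `torch.nn.Module` (read through its `state_dict()`) or a plain dict mapping weight names to tensors, and 2-D weights are assumed to be transposed from PyTorch's `(out_features, in_features)` layout. The sketch returns NumPy arrays; wrapping each one in `jax.numpy.asarray` would place it on a JAX device.

```python
import numpy as np


def tensor_to_array(t) -> np.ndarray:
    """Convert a PyTorch tensor (or any array-like) to a NumPy array."""
    if hasattr(t, "detach"):
        # torch.Tensor: drop the autograd graph, move to host memory, and
        # upcast to float32 because NumPy has no bfloat16 dtype.
        t = t.detach().cpu().float().numpy()
    return np.asarray(t)


def convert_layers(layers) -> list[dict[str, np.ndarray]]:
    """Hypothetical sketch: convert every layer's weights to host arrays.

    Each layer is either a torch.nn.Module (read via state_dict()) or a plain
    dict of name -> tensor. 2-D weights are transposed from PyTorch's
    (out_features, in_features) layout to (in_features, out_features) so that
    a forward pass can compute x @ W; 1-D weights (e.g. norms) are kept as-is.
    """
    converted = []
    for layer in layers:
        weights = layer.state_dict() if hasattr(layer, "state_dict") else layer
        params = {}
        for name, w in weights.items():
            arr = tensor_to_array(w)
            params[name] = arr.T if arr.ndim == 2 else arr
        converted.append(params)
    return converted
```

Upcasting to float32 trades memory for portability; a real converter might instead keep bfloat16 by going through a dtype JAX supports directly.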

mistral_v0_2.model.decoder.forward_decoder(params, seq, qk_mask, rotary_values, kv_cache_pre)

Executes the forward pass of all decoder blocks.

Parameters:
  • params (DecoderParams) – The decoder parameters.

  • seq (Array) – The input sequences to the decoder blocks.

  • qk_mask (Array) – The qk mask for the attention mechanism, determining which parts of the sequence are allowed to attend to each other.

  • rotary_values (RotaryValues) – Rotary positional embedding values.

  • kv_cache_cur (KVCache) – The current KVCache.

  • kv_cache_pre (KVCache) – The previous KVCache.

Returns:

A tuple containing the output sequence after all decoder blocks and the previous KVCache.

Return type:

tuple[Array, KVCache]
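A qk mask of the kind described for qk_mask can be sketched for a causal decoder with right padding. The helpers `make_qk_mask` and `masked_softmax`, the `(batch, seq)` padding-mask input, and the `(batch, 1, seq, seq)` output shape (whose singleton axis broadcasts over attention heads) are assumptions for illustration; the document does not specify how `qk_mask` is built or shaped.

```python
import numpy as np


def make_qk_mask(attn_mask: np.ndarray) -> np.ndarray:
    """Hypothetical: build a boolean qk mask of shape (batch, 1, seq, seq).

    attn_mask has shape (batch, seq): 1 marks a real token, 0 marks padding.
    Query q may attend to key k only if k <= q (causal) and k is not padding.
    """
    _, seq_len = attn_mask.shape
    causal = np.tril(np.ones((seq_len, seq_len), dtype=bool))  # (seq, seq)
    key_ok = attn_mask.astype(bool)[:, None, None, :]  # (batch, 1, 1, seq)
    return causal[None, None, :, :] & key_ok


def masked_softmax(scores: np.ndarray, qk_mask: np.ndarray) -> np.ndarray:
    """Softmax over the key axis that gives disallowed keys zero weight."""
    # Use the most negative finite float (not -inf) so fully masked rows
    # produce uniform weights instead of NaN.
    scores = np.where(qk_mask, scores, np.finfo(scores.dtype).min)
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    return weights / weights.sum(axis=-1, keepdims=True)
```

For a batch with mask `[[1, 1, 0]]`, the last (padding) key is hidden from every query, and each query sees only itself and earlier real tokens.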

mistral_v0_2.model.decoder.shard_decoder_params(layers)

Shards the decoder parameters (DecoderParams) for distributed computation.

Parameters:

layers (DecoderParams) – The decoder parameters.

Returns:

The decoder parameters sharded for tensor parallelism, so that computation can be distributed across multiple devices.

Return type:

DecoderParams
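The idea behind tensor-parallel sharding can be simulated on a single host with NumPy. This sketch assumes a Megatron-style split of a two-layer MLP (first weight split by columns, second by rows) and uses ReLU as a stand-in for the model's activation; the document does not state which axes shard_decoder_params partitions, and every name below is hypothetical. On real devices, JAX typically expresses such a placement with `jax.sharding.Mesh`, `NamedSharding`, and `PartitionSpec` passed to `jax.device_put`, and the final cross-device sum becomes an all-reduce.

```python
import numpy as np


def shard_columns(w: np.ndarray, n: int) -> list[np.ndarray]:
    # Column-parallel: each simulated device holds a slice of the output features.
    return np.split(w, n, axis=1)


def shard_rows(w: np.ndarray, n: int) -> list[np.ndarray]:
    # Row-parallel: each simulated device holds a slice of the input features.
    return np.split(w, n, axis=0)


def tensor_parallel_mlp(x, w_up, w_down, n_devices: int) -> np.ndarray:
    """Simulate a 2-layer MLP sharded over n_devices.

    w_up is split by columns and w_down by rows, so each device computes its
    partial output independently (the activation is elementwise, so it can be
    applied per shard). The only communication needed is a final sum across
    devices, which stands in for an all-reduce on real hardware.
    """
    ups = shard_columns(w_up, n_devices)
    downs = shard_rows(w_down, n_devices)
    partials = [np.maximum(x @ u, 0.0) @ d for u, d in zip(ups, downs)]
    return sum(partials)  # the "all-reduce"
```

Pairing a column split with a row split keeps the hidden activations local to each device, which is why only one reduction per MLP is needed.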