mistral_v0_2.model.decoder_block
- mistral_v0_2.model.decoder_block.convert_decoder_block_params(decoder_block)
Converts decoder block parameters from a MistralDecoderLayer (PyTorch tensors) to DecoderBlockParams (JAX arrays).
- Parameters:
decoder_block (MistralDecoderLayer) – The PyTorch MistralDecoderLayer whose parameters are converted.
- Returns:
The converted decoder block parameters.
- Return type:
DecoderBlockParams
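As an illustration of the kind of per-tensor conversion described above, here is a minimal sketch. The helper name `pt2jax` is hypothetical, and the transpose of the linear weight is an assumption (PyTorch stores `nn.Linear` weights as `(out_features, in_features)`); the layout used by `DecoderBlockParams` is not stated in this page.

```python
import jax
import jax.numpy as jnp
import torch


def pt2jax(x: torch.Tensor) -> jax.Array:
    """Hypothetical helper: copy a PyTorch tensor into a JAX array."""
    # detach() drops the autograd graph; numpy() exposes the CPU buffer,
    # and jnp.asarray copies it into a JAX array.
    return jnp.asarray(x.detach().numpy())


# A stand-in for one projection inside a decoder layer (float32, CPU).
linear = torch.nn.Linear(4, 6, bias=False)

# Assumption: the JAX side wants (in_features, out_features), so transpose.
w = pt2jax(linear.weight.T)
```

This sketch uses float32 tensors on the CPU; tensors in `bfloat16` or on an accelerator would need a cast or a `.cpu()` call before `numpy()`.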
- mistral_v0_2.model.decoder_block.forward_decoder_block(params, seq, qk_mask, rotary_values, kv_cache_cur, kv_cache_pre)
Executes the forward pass of a decoder block using the specified parameters and input sequence.
- Parameters:
params (DecoderBlockParams) – The decoder block parameters.
seq (Array) – The input sequences to the decoder block.
qk_mask (Array) – The query-key mask for the attention mechanism, which determines which positions in the sequence may attend to each other.
rotary_values (RotaryValues) – The rotary positional embedding values.
kv_cache_cur (KVCache) – The current KVCache.
kv_cache_pre (KVCache) – The previous KVCache.
- Returns:
A tuple containing the output sequence of the decoder block and the updated current and previous KVCache.
- Return type:
tuple[Array, KVCache, KVCache]
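To make the role of `qk_mask` concrete, the following is a minimal single-head sketch of how a boolean mask can gate attention scores. It is not the library's implementation: the function `masked_attention`, the shapes, and the purely causal mask are assumptions chosen for illustration, and rotary embeddings and the KV cache are left out.

```python
import jax
import jax.numpy as jnp


def masked_attention(q, k, v, qk_mask):
    """Illustrative attention with a boolean query-key mask.

    q, k, v: (seq_len, d_head); qk_mask: (seq_len, seq_len), True = may attend.
    """
    scores = q @ k.T / jnp.sqrt(q.shape[-1])
    # Disallowed positions get -inf, so softmax assigns them zero weight.
    scores = jnp.where(qk_mask, scores, -jnp.inf)
    weights = jax.nn.softmax(scores, axis=-1)
    return weights @ v, weights


seq_len, d_head = 4, 8
key_q, key_k, key_v = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(key_q, (seq_len, d_head))
k = jax.random.normal(key_k, (seq_len, d_head))
v = jax.random.normal(key_v, (seq_len, d_head))

# Causal mask: position i may attend to positions 0..i.
qk_mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))
out, weights = masked_attention(q, k, v, qk_mask)
```

Every row of the mask must allow at least one position; a row that is entirely `False` would make the softmax produce NaN.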
- mistral_v0_2.model.decoder_block.shard_decoder_block_params(params)
Shards the DecoderBlockParams for distributed computing.
- Parameters:
params (DecoderBlockParams) – The decoder block parameters.
- Returns:
The decoder block parameters sharded for tensor parallelism, allowing distributed computation across multiple devices.
- Return type:
DecoderBlockParams
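For readers unfamiliar with tensor parallelism in JAX, here is a minimal sketch of sharding a single weight matrix across the available devices with `jax.sharding`. The mesh axis name `"tp"`, the weight shape, and the choice of sharding the last dimension are assumptions for illustration; which axes `shard_decoder_block_params` actually partitions is not stated in this page.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = jax.devices()
n_devices = len(devices)

# One-dimensional device mesh; "tp" is a hypothetical axis name.
mesh = Mesh(np.array(devices), ("tp",))

# Stand-in weight whose last dimension divides evenly across the devices.
w = jnp.ones((8, 4 * n_devices))

# Replicate the first dimension and split the last one along "tp".
sharded = jax.device_put(w, NamedSharding(mesh, P(None, "tp")))
```

On a single device this places the whole array on that device, so the sketch runs unchanged on a laptop and on a multi-device host.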