mistral_v0_2.model.attention
- mistral_v0_2.model.attention.convert_attention_params(self_attn)
Converts the attention parameters from the HuggingFace MistralAttention format, which stores PyTorch tensors, to jax.Array.
- Parameters:
self_attn (MistralAttention) – The attention parameters in MistralAttention.
- Returns:
The attention parameters converted to AttentionParams, held as JAX arrays.
- Return type:
AttentionParams
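The conversion described above can be sketched roughly as follows. This is a minimal illustration, not the module's actual code: the container `AttentionParamsSketch`, its field names, the helper `to_jax`, and the decision to transpose the weights are all assumptions made for the example. To avoid a PyTorch dependency, the weights are represented as NumPy arrays, as they would look after calling `.detach().cpu().numpy()` on each tensor.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp
import numpy as np


class AttentionParamsSketch(NamedTuple):
    """Hypothetical stand-in for AttentionParams; the real fields may differ."""
    q_proj: jax.Array
    k_proj: jax.Array
    v_proj: jax.Array
    o_proj: jax.Array


def to_jax(weight: np.ndarray) -> jax.Array:
    # torch.nn.Linear stores weights as (out_features, in_features).
    # Transposing here (an assumption of this sketch) lets callers write x @ w.
    return jnp.asarray(weight.T)


# Toy sizes and random weights standing in for exported PyTorch tensors.
rng = np.random.default_rng(0)
d_model, d_kv = 8, 4
shapes = {
    "q_proj": (d_model, d_model),
    "k_proj": (d_kv, d_model),
    "v_proj": (d_kv, d_model),
    "o_proj": (d_model, d_model),
}
exported = {
    name: rng.standard_normal(shape).astype(np.float32)
    for name, shape in shapes.items()
}

params = AttentionParamsSketch(**{name: to_jax(w) for name, w in exported.items()})
```

After this runs, every field of `params` is a `jax.Array` laid out as `(in_features, out_features)`.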
- mistral_v0_2.model.attention.forward_attention(params, seq, qk_mask, rotary_values, kv_cache_cur, kv_cache_pre)
Performs the forward pass of the attention mechanism.
This function executes the attention mechanism on the input sequence seq using the provided attention parameters.
- Parameters:
params (AttentionParams) – The attention parameters.
seq (Array) – The input sequences on which attention is to be applied.
qk_mask (Array) – The qk mask for the attention mechanism, determining which parts of the sequence are allowed to attend to each other.
rotary_values (RotaryValues) – Rotary positional embeddings values.
kv_cache_cur (KVCache) – The current KVCache.
kv_cache_pre (KVCache) – The previous KVCache.
- Returns:
A tuple containing the output sequence after applying attention, and the updated current and previous KVCache.
- Return type:
tuple[Array, KVCache, KVCache]
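To illustrate the role of `qk_mask`, here is a minimal single-head sketch of masked scaled dot-product attention in JAX. It is a simplification written for this page, not the module's implementation: it omits the rotary embeddings, the KV cache, and the multi-head projections, and the function name `masked_attention` and the toy shapes are assumptions.

```python
import jax
import jax.numpy as jnp


def masked_attention(q, k, v, qk_mask):
    """Single-head attention.

    q: (q_len, d_k), k: (kv_len, d_k), v: (kv_len, d_v)
    qk_mask: (q_len, kv_len) boolean; True where a query may attend to a key.
    """
    scores = q @ k.T / jnp.sqrt(q.shape[-1])
    # Masked-out positions get -inf so softmax assigns them zero weight.
    scores = jnp.where(qk_mask, scores, -jnp.inf)
    weights = jax.nn.softmax(scores, axis=-1)
    return weights @ v


key_q, key_k, key_v = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(key_q, (4, 8))
k = jax.random.normal(key_k, (4, 8))
v = jax.random.normal(key_v, (4, 8))

# Causal mask: each position attends to itself and earlier positions only.
causal_mask = jnp.tril(jnp.ones((4, 4), dtype=bool))
out = masked_attention(q, k, v, causal_mask)
```

With a causal mask, the first query can only see the first key, so the first output row equals the first row of `v`.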
- mistral_v0_2.model.attention.shard_attention_params(params)
Shards the attention parameters for distributed computation.
- Parameters:
params (AttentionParams) – The attention parameters.
- Returns:
The attention parameters modified with tensor parallelism, allowing for distributed computation across multiple devices.
- Return type:
AttentionParams
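One way tensor-parallel sharding could look with the public `jax.sharding` API is sketched below. How this module actually partitions its parameters is not stated here, so the mesh axis name `"tp"`, the helper `shard_along_axis`, and the choice to split along a head axis are assumptions for illustration only.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# One-dimensional device mesh; works with a single device as well.
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("tp",))


def shard_along_axis(x, axis):
    """Place x on the mesh, splitting the given axis across the 'tp' devices."""
    spec = [None] * x.ndim
    spec[axis] = "tp"
    return jax.device_put(x, NamedSharding(mesh, P(*spec)))


# Hypothetical projection weight shaped (d_model, n_heads, d_head).
# n_heads is chosen to be divisible by the number of devices.
n_heads = 2 * len(devices)
q_proj = jnp.ones((16, n_heads, 4))
q_sharded = shard_along_axis(q_proj, axis=1)
```

Sharding changes where the data lives, not its logical shape or values, so `q_sharded` can be used wherever `q_proj` was.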
- mistral_v0_2.model.attention.test_forward_attention(model)
Tests the forward attention mechanism.
This function checks the JAX attention implementation for correctness by comparing it against the given PyTorch model.
- Parameters:
model (MistralForCausalLM) – PyTorch Mistral model to compare with this implementation.
- Returns:
None.
- Return type:
None
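A comparison test of this kind typically computes the same quantity with a reference implementation and with the JAX code, then checks that they agree within a tolerance. The sketch below shows that pattern on a softmax, using NumPy as the reference instead of PyTorch. The function `reference_softmax` and the tolerance are assumptions, and nothing here reflects the actual test's contents.

```python
import jax
import jax.numpy as jnp
import numpy as np


def reference_softmax(x: np.ndarray) -> np.ndarray:
    # Numerically stable softmax, used here as the reference implementation.
    shifted = x - x.max(axis=-1, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=-1, keepdims=True)


x = np.random.default_rng(0).standard_normal((3, 5)).astype(np.float32)

expected = reference_softmax(x)
actual = jax.nn.softmax(jnp.asarray(x), axis=-1)

# Floating-point results rarely match bit for bit, so compare with a tolerance.
matches = bool(jnp.allclose(actual, expected, atol=1e-5))
```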