Source code for mistral_v0_2.model.attention

import math

import einops as op
from einshard import einshard
import jax
from jax import Array
import jax.numpy as jnp
import torch
from transformers import MistralForCausalLM
from transformers.models.mistral.modeling_mistral import MistralAttention

from .kvcache import KVCache
from .rotary_embedding import RotaryValues, make_rotary_values, forward_rotary_embedding
from ..lib.array_conversion import pt2jax

# TODO: eliminate this
# Hard-coded dimensions used by the weight reshapes and the attention scaling in this module.
d_model = 4096  # size of the model axis of every projection weight
n_rep_kv = 4  # number of query-head groups that share each key/value head
n_heads_kv = 8  # number of key/value heads
d_k = 128  # per-head size of queries and keys
d_v = 128  # per-head size of values

# Attention weights in the order (q_proj, k_proj, v_proj, o_proj), as returned by `convert_attention_params`.
AttentionParams = tuple[Array, Array, Array, Array]

def convert_attention_params(self_attn: MistralAttention) -> AttentionParams:
    """
    Convert the projection weights of a HuggingFace ``MistralAttention`` layer into JAX arrays.

    Every PyTorch weight is transposed, passed through ``pt2jax`` and reshaped so that the
    head axes become explicit. The resulting layouts are:

    - ``q_proj``: ``(d_model, n_rep_kv, n_heads_kv, d_k)``
    - ``k_proj``: ``(d_model, n_heads_kv, d_k)``
    - ``v_proj``: ``(d_model, n_heads_kv, d_v)``
    - ``o_proj``: ``(n_rep_kv, n_heads_kv, d_v, d_model)``

    Args:
        self_attn (MistralAttention): The attention layer whose ``q_proj``, ``k_proj``,
            ``v_proj`` and ``o_proj`` weights are converted.

    Returns:
        AttentionParams: The tuple ``(q_proj, k_proj, v_proj, o_proj)`` of converted weights.
    """
    def _transposed_weight(linear):
        # Transpose on the PyTorch side first, then convert to a JAX array.
        return pt2jax(linear.weight.data.T)

    q_flat = _transposed_weight(self_attn.q_proj)
    k_flat = _transposed_weight(self_attn.k_proj)
    v_flat = _transposed_weight(self_attn.v_proj)
    o_flat = _transposed_weight(self_attn.o_proj)

    # The query output axis is split as (kv head, repetition, d_k); the two head axes are then
    # swapped so the repetition axis comes before the kv-head axis.
    q_grouped = q_flat.reshape(d_model, n_heads_kv, n_rep_kv, d_k)
    q_proj_jax = q_grouped.transpose(0, 2, 1, 3)

    k_proj_jax = k_flat.reshape(d_model, n_heads_kv, d_k)
    v_proj_jax = v_flat.reshape(d_model, n_heads_kv, d_v)

    # The output projection gets the same head split on its input axis, with the same swap.
    o_grouped = o_flat.reshape(n_heads_kv, n_rep_kv, d_v, d_model)
    o_proj_jax = o_grouped.transpose(1, 0, 2, 3)

    return q_proj_jax, k_proj_jax, v_proj_jax, o_proj_jax
def convert_back_attention_params():
    """
    Placeholder that is not implemented yet.

    Presumably intended as the inverse of ``convert_attention_params`` (confirm with the
    author); at the moment it takes no arguments, does nothing and returns ``None``.
    """
    return None
def shard_attention_params(params: AttentionParams) -> AttentionParams:
    """
    Shard the attention weights for distributed computation.

    Each weight is passed to ``einshard`` together with a pattern in which the starred axis
    is the per-head feature axis (``k`` for the query/key projections, ``v`` for the
    value/output projections).

    Args:
        params (AttentionParams): The tuple ``(q_proj, k_proj, v_proj, o_proj)``.

    Returns:
        AttentionParams: The same four weights, in the same order, after ``einshard``.
    """
    q_proj, k_proj, v_proj, o_proj = params

    # An earlier variant starred the head axis `h` instead (e.g. 'm r h k -> m r h* k');
    # the patterns below star the feature axis.
    weight_and_pattern = (
        (q_proj, 'm r h k -> m r h k*'),
        (k_proj, 'm h k -> m h k*'),
        (v_proj, 'm h v -> m h v*'),
        (o_proj, 'r h v m -> r h v* m'),
    )
    q_sharded, k_sharded, v_sharded, o_sharded = (
        einshard(weight, pattern) for weight, pattern in weight_and_pattern
    )
    return q_sharded, k_sharded, v_sharded, o_sharded
def forward_attention(params: AttentionParams, seq: Array, qk_mask: Array, rotary_values: RotaryValues, kv_cache_cur: KVCache, kv_cache_pre: KVCache) -> tuple[Array, KVCache, KVCache]:
    """
    Run the forward pass of one self-attention layer.

    The input is projected to queries, keys and values, rotary embeddings are applied to the
    queries and keys, cached keys/values from ``kv_cache_pre`` are prepended when available,
    and scaled dot-product attention followed by the output projection is computed.

    Args:
        params (AttentionParams): The tuple ``(q_proj, k_proj, v_proj, o_proj)``.
        seq (Array): The input sequences; used both as the query source and the key/value
            source (self-attention).
        qk_mask (Array): Boolean mask passed to the softmax as ``where``; it selects which
            query/key pairs take part in the attention.
        rotary_values (RotaryValues): Rotary positional embedding values.
        kv_cache_cur (KVCache): Cache being built during the current pass, or ``None`` when
            this is the first layer. Index 0 of its first axis holds keys, index 1 values;
            its second axis grows by one entry per call.
        kv_cache_pre (KVCache): Cache from the previous pass, or ``None`` when there is none.

    Returns:
        tuple[Array, KVCache, KVCache]: The attention output, the current cache extended with
        this layer's keys and values, and ``kv_cache_pre`` unchanged.
    """
    w_q, w_k, w_v, w_o = params

    # Projections. `s` indexes query positions and `d` key/value positions; in self-attention
    # both come from `seq`.
    #   q: (batch, n_rep_kv, n_heads_kv, s, d_k)
    #   k: (batch, n_heads_kv, d, d_k)
    #   v: (batch, n_heads_kv, d, d_v)
    q = op.einsum(seq, w_q, 'b s m, m r h k -> b r h s k')
    k = op.einsum(seq, w_k, 'b d m, m h k -> b h d k')
    v = op.einsum(seq, w_v, 'b d m, m h v -> b h d v')

    # Add position information to queries and keys before attention.
    q = forward_rotary_embedding(q, rotary_values=rotary_values)
    k = forward_rotary_embedding(k, rotary_values=rotary_values)

    is_first_layer = kv_cache_cur is None

    # KVCache to optimize generation: prepend this layer's keys/values from the previous pass.
    if kv_cache_pre is not None:
        # The number of layers already stored in the current cache is this layer's index.
        layer_n = 0 if is_first_layer else kv_cache_cur.shape[1]
        k = jnp.concatenate((kv_cache_pre[0, layer_n], k), axis=-2)
        v = jnp.concatenate((kv_cache_pre[1, layer_n], v), axis=-2)

    # Append this layer's (possibly extended) keys and values to the current cache.
    k_entry = k[None]
    v_entry = v[None]
    if is_first_layer:
        k_layers = k_entry
        v_layers = v_entry
    else:
        k_layers = jnp.concatenate((kv_cache_cur[0], k_entry), axis=0)
        v_layers = jnp.concatenate((kv_cache_cur[1], v_entry), axis=0)
    kv_cache_cur = jnp.stack((k_layers, v_layers), axis=0)

    # Scaled dot-product attention, equation (1) in section 3.2.1 of the original Transformer
    # paper. Scores have shape (batch, n_rep_kv, n_heads_kv, s, d).
    scale = math.sqrt(d_k)
    scores = jnp.einsum('brhsk,bhdk->brhsd', q, k) / scale
    # NOTE(review): `initial` is passed together with `where`; confirm that the pinned JAX
    # version still accepts this keyword.
    weights = jax.nn.softmax(scores, where=qk_mask, initial=0.)
    context = jnp.einsum('brhsd,bhdv->brhsv', weights, v)

    # Merge all heads back into the model dimension.
    out = jnp.einsum('brhsv,rhvm->bsm', context, w_o)
    return out, kv_cache_cur, kv_cache_pre
def test_forward_attention(model: MistralForCausalLM) -> None:
    """
    Check the JAX attention forward pass against the PyTorch reference.

    A random input is run through the first layer's PyTorch self-attention with a causal
    mask, and through ``forward_attention`` with the converted weights; the two outputs must
    agree within an absolute tolerance of ``1e-5``.

    Args:
        model (MistralForCausalLM): PyTorch Mistral model providing the reference layer.

    Returns:
        None.
    """
    batch_size, seq_len = 1, 6
    device = model.device
    reference_attn = model.model.layers[0].self_attn

    # PyTorch side: random hidden states and a lower-triangular (causal) mask.
    hidden_pt = torch.rand(batch_size, seq_len, d_model, device=device)
    causal_bool_pt = torch.tril(torch.ones(batch_size, 1, seq_len, seq_len, dtype=torch.bool, device=device))
    # The PyTorch layer is given an additive mask: 0 where attention is allowed, -inf elsewhere.
    causal_additive_pt = torch.where(causal_bool_pt, 0., -torch.inf)
    expected_pt = reference_attn(hidden_pt, attention_mask=causal_additive_pt)[0]

    # JAX side: same weights and input, boolean mask, no KV cache.
    jax_params = convert_attention_params(reference_attn)
    hidden_jax = pt2jax(hidden_pt)
    causal_bool_jax = pt2jax(causal_bool_pt)
    rotary_values = make_rotary_values(batch_size, seq_len)
    actual_jax, _, _ = forward_attention(jax_params, hidden_jax, causal_bool_jax, rotary_values, None, None)

    assert jnp.allclose(actual_jax, pt2jax(expected_pt), atol=1e-5)