# Source code for mistral_v0_2.model.rotary_embedding

from typing import NamedTuple

import einops as op
import jax
from jax import Array
import jax.numpy as jnp

# Value that `make_rotary_values` passes as `d_k` to `_make_weights`; it sets the
# size of the last axis of the generated sin/cos tables.
# TODO: eliminate this module-level constant
d_k = 128

# TODO: Mostly taken from https://github.com/kingoflolz/mesh-transformer-jax/blob/master/mesh_transformer/layers.py
# and https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L92
def _make_weights(seq_len: int, d_k: int) -> tuple[Array, Array]:
    """Build the sine and cosine lookup tables for the rotary embedding.

    Args:
        seq_len: Number of positions; one row is produced for each position
            in ``0 .. seq_len - 1``.
        d_k: Controls the table width. One inverse frequency is generated for
            every even index in ``range(0, d_k, 2)``.

    Returns:
        A ``(sin_val, cos_val)`` tuple. Each table has ``seq_len`` rows and
        twice as many columns as there are inverse frequencies (``d_k``
        columns when ``d_k`` is even).
    """
    def _tile_twice(table: Array) -> Array:
        # Lay two copies of the table side by side along the last axis.
        return op.repeat(table, 'L K -> L (i K)', i=2)

    exponents = jnp.arange(0, d_k, 2) / d_k
    inv_freq = 1. / (1000000 ** exponents)
    positions = jnp.arange(seq_len)
    # Outer product: angles[l, j] = positions[l] * inv_freq[j].
    angles = op.einsum(positions, inv_freq, 'L, j -> L j')
    sin_val = _tile_twice(jnp.sin(angles))
    cos_val = _tile_twice(jnp.cos(angles))
    return sin_val, cos_val

def _rotate_half(x: Array) -> Array:
    """Map ``[x1, x2]`` to ``[-x2, x1]`` along the last axis.

    The last axis is viewed as two equal halves ``x1`` and ``x2``; the halves
    are swapped and the half that ends up first is negated.

    Args:
        x: Array whose last axis can be split into two equal halves.

    Returns:
        An array with the same shape as ``x``.
    """
    # (..., n) -> (..., 2, n // 2): expose the two halves on their own axis.
    halves = op.rearrange(x, '... (i x) -> ... i x', i=2)
    # Swap the halves, then flip the sign of the one now in front.
    swapped = jnp.flip(halves, axis=-2)
    rotated = swapped.at[..., 0, :].multiply(-1)
    # (..., 2, n // 2) -> (..., n): fold the halves back into one axis.
    return op.rearrange(rotated, '... i x -> ... (i x)')

class RotaryValues(NamedTuple):
    """Pair of sine and cosine tables consumed by ``forward_rotary_embedding``.

    Field order matters: functions in this module unpack instances
    positionally as ``sin_val, cos_val``.
    """
    sin_val: Array  # sine table; forward_rotary_embedding asserts it is float32
    cos_val: Array  # cosine table; forward_rotary_embedding asserts it is float32

def forward_rotary_embedding(m: Array, *, rotary_values: RotaryValues) -> Array:
    """Apply the rotary embedding to ``m``.

    Computes ``m * cos + rotate_half(m) * sin``. Each product is cast back to
    ``m``'s dtype before the two are added.

    Args:
        m: Array with axes ``B ... L K`` (any number of axes may sit between
            ``B`` and ``L``).
        rotary_values: ``(sin_val, cos_val)`` tables with axes ``B L K``.
            Both must be ``float32``.

    Returns:
        An array with the same axes and dtype as ``m``.
    """
    sin_val, cos_val = rotary_values
    for table in (sin_val, cos_val):
        assert table.dtype == jnp.float32

    pattern = 'B ... L K, B L K -> B ... L K'
    out_dtype = m.dtype
    cos_term = op.einsum(m, cos_val, pattern).astype(out_dtype)
    sin_term = op.einsum(_rotate_half(m), sin_val, pattern).astype(out_dtype)
    return cos_term + sin_term

def make_rotary_values(batch_size: int, seq_len: int) -> RotaryValues:
    """
    Generates sine and cosine values for rotary positional embeddings based on sequence length.

    The tables are built once by ``_make_weights`` using the module-level
    ``d_k`` and then copied ``batch_size`` times along a new leading axis.

    Args:
        batch_size (int): The number of sequences in a batch.
        seq_len (int): The length of every sequence in a batch.

    Returns:
        RotaryValues: Rotary embedding values with sine values, and cosine values.
        Each table has ``batch_size`` as its leading axis, followed by the
        axes returned by ``_make_weights``.
    """
    sin_val, cos_val = _make_weights(seq_len, d_k)
    # sin_val[None] adds a leading axis of size 1, which is then repeated per batch entry.
    sin_val = jnp.repeat(sin_val[None], batch_size, axis=0)
    cos_val = jnp.repeat(cos_val[None], batch_size, axis=0)
    return RotaryValues(sin_val, cos_val)
def get_rotary_values_at_position(rotary_values: RotaryValues, position: Array) -> RotaryValues:
    """
    Extracts the rotary positional embedding values for a specific position across all sequences in a batch.

    Args:
        rotary_values (RotaryValues): The rotary values from which to extract the positional embeddings.
        position (Array): The position for which to extract the rotary values.
            It is used to index the second axis of both tables.

    Returns:
        RotaryValues: Rotary embedding values for the specified position, with a
        new axis inserted at index 1 of each table after the lookup.
    """
    sin_val, cos_val = rotary_values
    # Index the second axis at `position`, then re-insert an axis at index 1.
    sin_val = sin_val[:, position][:, None]
    cos_val = cos_val[:, position][:, None]
    return RotaryValues(sin_val, cos_val)