Source code for mistral_v0_2.model.rms_norm
from einshard import einshard
from jax import Array
import jax.numpy as jnp
import torch
from transformers import MistralForCausalLM
from transformers.models.mistral.modeling_mistral import MistralRMSNorm
from ..lib.array_conversion import jax2pt, pt2jax
# TODO: eliminate these hard-coded Mistral v0.2 hyperparameters.
d_model = 4_096  # hidden width of the model
rms_norm_eps = 1.0e-5  # added to the mean square before the square root in forward_rms_norm
# The RMS norm parameters are a single JAX array (the per-channel scale).
RMSNormParams = Array
[docs]
def convert_rms_norm_params(rms_norm: MistralRMSNorm) -> RMSNormParams:
    """
    Extracts the scale of a PyTorch ``MistralRMSNorm`` as JAX RMSNormParams.

    Args:
        rms_norm (MistralRMSNorm): The PyTorch rms norm layer whose ``weight`` is read.

    Returns:
        RMSNormParams: ``rms_norm.weight`` converted with ``pt2jax`` so it can be used by the JAX forward pass.
    """
    weight_pt = rms_norm.weight
    params = pt2jax(weight_pt)
    return params
def convert_back_rms_norm_params(rms_norm: RMSNormParams) -> MistralRMSNorm:
    """
    Builds a PyTorch ``MistralRMSNorm`` from JAX RMSNormParams.

    Args:
        rms_norm (RMSNormParams): The JAX scale array. Its leading dimension is used as the layer's hidden size.

    Returns:
        MistralRMSNorm: A new layer constructed with epsilon ``rms_norm_eps``. Its ``weight`` is replaced by the array converted with ``jax2pt``.
    """
    hidden_size = rms_norm.shape[0]
    layer = MistralRMSNorm(hidden_size, rms_norm_eps)
    weight_pt = jax2pt(rms_norm)
    layer.weight = torch.nn.Parameter(weight_pt)
    return layer
[docs]
def shard_rms_norm_params(params: RMSNormParams) -> RMSNormParams:
    """
    Distributes RMSNormParams across devices as a replica.

    Args:
        params (RMSNormParams): The rms norm parameters to place.

    Returns:
        RMSNormParams: The rms norm parameters laid out for distributed computation across multiple devices.
    """
    # einshard pattern: keep all existing axes ('...') and add a leading '*',
    # which the original docstring describes as producing a replica on every device.
    replicate_pattern = '... -> * ...'
    return einshard(params, replicate_pattern)
# Taken from https://github.com/ayaka14732/llama-2-jax/blob/main/lib/llama/rms_norm.py
[docs]
def forward_rms_norm(params: RMSNormParams, x: Array) -> Array:
"""
Executes the forward pass of MLP.
Args:
params (RMSNormParams): The rms norm parameters.
x (Array): The input array.
Returns:
Array: The output after rms norm.
"""
x_rms = jnp.sqrt((x * x).mean(axis=-1, keepdims=True) + rms_norm_eps)
y = x / x_rms * params
return y
[docs]
def test_forward_rms_norm(model: MistralForCausalLM) -> None:
    """
    Tests the JAX rms norm against the PyTorch reference layer.

    Feeds the same random input through ``model.model.norm`` and through
    ``forward_rms_norm``, then asserts the outputs agree to ``atol=1e-5``.

    Args:
        model (MistralForCausalLM): PyTorch Mistral model to compare with this implementation.

    Returns:
        None.
    """
    rms_norm = model.model.norm
    # Take the hidden size from the layer itself rather than hard-coding 4096,
    # so the test does not silently depend on one model width.
    hidden_size = rms_norm.weight.shape[0]
    seq_pt = torch.rand(2, 14, hidden_size)
    # The reference output is only compared, never differentiated.
    with torch.no_grad():
        out_pt = rms_norm(seq_pt)
    out_pt_to_jax = pt2jax(out_pt)
    seq_jax = pt2jax(seq_pt)
    params = convert_rms_norm_params(rms_norm)
    out_jax = forward_rms_norm(params, seq_jax)
    assert jnp.allclose(out_pt_to_jax, out_jax, atol=1e-5)