mistral_v0_2.model.rms_norm

mistral_v0_2.model.rms_norm.convert_rms_norm_params(rms_norm)

Converts the parameters of a PyTorch RMS norm layer into an RMSNormParams object compatible with JAX.

Parameters:

rms_norm (MistralRMSNorm) – The PyTorch RMS norm layer from which to extract the weights.

Returns:

The RMS norm parameters extracted from the PyTorch layer, formatted for use in JAX operations.

Return type:

RMSNormParams
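The conversion step can be sketched as follows, assuming the PyTorch weight has already been pulled out as a NumPy array. `RMSNormWeights` and `weights_from_numpy` are hypothetical names for illustration; the module's real RMSNormParams may be structured differently (for example, as a bare array). The bfloat16 target dtype is also an assumption.

```python
from typing import NamedTuple

import jax.numpy as jnp
import numpy as np


class RMSNormWeights(NamedTuple):
    """Hypothetical stand-in for RMSNormParams: a single learned scale vector."""
    weight: jnp.ndarray


def weights_from_numpy(weight_np: np.ndarray, dtype=jnp.bfloat16) -> RMSNormWeights:
    # Copy the host array into a JAX array, casting to the target dtype.
    return RMSNormWeights(weight=jnp.asarray(weight_np, dtype=dtype))


# With PyTorch available, the NumPy array would typically be obtained with
#   weight_np = rms_norm.weight.detach().float().cpu().numpy()
# where .float() is needed because Tensor.numpy() does not support bfloat16.
```

Casting through float32 on the PyTorch side and back to bfloat16 in JAX is lossless for weights that were stored in bfloat16 to begin with.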

mistral_v0_2.model.rms_norm.forward_rms_norm(params, x)

Executes the forward pass of the RMS norm.

Parameters:
  • params (RMSNormParams) – The rms norm parameters.

  • x (Array) – The input array.

Returns:

The RMS-normalized output, with the same shape as x.

Return type:

Array
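A minimal sketch of an RMS norm forward pass in JAX, following the usual formula y = w * x / sqrt(mean(x^2) + eps) over the last axis and mirroring the Hugging Face MistralRMSNorm habit of computing in float32 before casting back. The names `rms_norm` and `RMS_NORM_EPS` are illustrative, the function takes the weight array directly rather than an RMSNormParams, and eps = 1e-5 is assumed (read `rms_norm_eps` from the model config in practice).

```python
import jax
import jax.numpy as jnp

RMS_NORM_EPS = 1e-5  # assumed value; check rms_norm_eps in the model config


def rms_norm(weight: jax.Array, x: jax.Array, eps: float = RMS_NORM_EPS) -> jax.Array:
    input_dtype = x.dtype
    # Accumulate in float32 so bfloat16 inputs do not lose precision.
    x32 = x.astype(jnp.float32)
    # Mean of squares over the hidden (last) axis.
    variance = jnp.mean(x32 * x32, axis=-1, keepdims=True)
    normed = x32 * jax.lax.rsqrt(variance + eps)
    # Cast back to the input dtype, then apply the learned per-channel scale.
    return weight * normed.astype(input_dtype)
```

For example, `rms_norm(jnp.ones(2), jnp.array([[3.0, 4.0]]))` divides by sqrt(12.5) and gives approximately `[[0.8485, 1.1314]]`. Unlike LayerNorm, no mean is subtracted and there is no bias term.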

mistral_v0_2.model.rms_norm.shard_rms_norm_params(params)

Shards the RMSNormParams for distributed computation.

Parameters:

params (RMSNormParams) – The RMS norm parameters to shard.

Returns:

The RMS norm parameters, replicated for distributed computation across multiple devices.

Return type:

RMSNormParams
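Replication of this kind can be sketched with `jax.sharding`: an empty `PartitionSpec` partitions no axis, so every device in the mesh receives a full copy. `replicate_params` and the mesh axis name are hypothetical; the module's actual sharding helper may use a different mechanism.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec


def replicate_params(params, devices=None):
    """Place a pytree of arrays on every device as full replicas (sketch)."""
    devices = jax.devices() if devices is None else devices
    # One-dimensional mesh over all available devices.
    mesh = Mesh(np.array(devices), axis_names=("devices",))
    # Empty PartitionSpec: no axis is split, so each device holds everything.
    replicated = NamedSharding(mesh, PartitionSpec())
    # device_put applies the single sharding to every leaf of the pytree.
    return jax.device_put(params, replicated)
```

Replicating rather than splitting is a reasonable choice here because the RMS norm weight is a single vector of hidden size (4096 for Mistral 7B), so copying it everywhere is cheap and avoids cross-device communication when it is applied.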

mistral_v0_2.model.rms_norm.test_forward_rms_norm(model)

Tests the RMS norm implementation by comparing its output with that of the PyTorch Mistral model.

Parameters:

model (MistralForCausalLM) – PyTorch Mistral model to compare with this implementation.

Returns:

None

Return type:

None
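The comparison pattern behind such a test can be sketched without PyTorch by substituting a float64 NumPy reference for the PyTorch layer. This is an assumption for illustration: the documented test compares against MistralRMSNorm from the given model, and `check_rms_norm`, the tolerance, and eps = 1e-5 are all hypothetical choices.

```python
import jax
import jax.numpy as jnp
import numpy as np

EPS = 1e-5  # assumed; should match rms_norm_eps of the reference model


def rms_norm_jax(weight, x):
    # Implementation under test: float32 accumulation, then scale.
    x32 = x.astype(jnp.float32)
    normed = x32 * jax.lax.rsqrt(jnp.mean(x32 * x32, axis=-1, keepdims=True) + EPS)
    return weight * normed.astype(x.dtype)


def rms_norm_reference(weight, x):
    # float64 NumPy reference standing in for the PyTorch layer.
    x64 = np.asarray(x, dtype=np.float64)
    rms = np.sqrt(np.mean(x64 ** 2, axis=-1, keepdims=True) + EPS)
    return np.asarray(weight, dtype=np.float64) * x64 / rms


def check_rms_norm(seed=0, hidden=64, atol=1e-4):
    rng = np.random.default_rng(seed)
    x = rng.standard_normal((2, 3, hidden)).astype(np.float32)
    w = rng.standard_normal(hidden).astype(np.float32)
    got = np.asarray(rms_norm_jax(jnp.asarray(w), jnp.asarray(x)))
    want = rms_norm_reference(w, x)
    # Fail with the largest absolute deviation to ease debugging.
    assert np.allclose(got, want, atol=atol), np.max(np.abs(got - want))
```

With the real PyTorch layer, the same shape applies: feed identical inputs to both implementations and compare with a tolerance suited to the dtype (looser for bfloat16).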