mistral_v0_2.model.rms_norm
- mistral_v0_2.model.rms_norm.convert_rms_norm_params(rms_norm)
Converts the parameters of a PyTorch RMS norm layer into an RMSNormParams instance compatible with JAX.
- Parameters:
rms_norm (MistralRMSNorm) – The PyTorch RMS norm layer from which to extract the weights.
- Returns:
The RMS norm parameters extracted from the PyTorch layer, formatted for use in JAX operations.
- Return type:
RMSNormParams
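A minimal sketch of what the conversion might involve, assuming (hypothetically) that RMSNormParams holds only the learned scale vector and that the PyTorch layer exposes it as a `weight` tensor, as Hugging Face's `MistralRMSNorm` does. The `RMSNormParams` NamedTuple and the name `convert_rms_norm_params_sketch` are illustrative stand-ins; the module's real fields may differ (for example, it may also carry epsilon).

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class RMSNormParams(NamedTuple):
    """Hypothetical stand-in: assumes the params hold only the learned scale."""
    weight: jax.Array


def convert_rms_norm_params_sketch(rms_norm) -> RMSNormParams:
    """Copy the scale vector of a PyTorch RMS norm layer into a JAX array.

    Assumes `rms_norm.weight` is a torch tensor (e.g. an nn.Parameter).
    """
    # detach() drops autograd tracking so that .numpy() is permitted;
    # cpu() moves the tensor to host memory if it lives on a GPU.
    weight = rms_norm.weight.detach().cpu().numpy()
    # Wrap the host array as a JAX array.
    return RMSNormParams(weight=jnp.asarray(weight))
```

Going through `.numpy()` keeps the sketch free of any torch-specific JAX interop; the copy is cheap because the weight is a single vector with one entry per feature.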
- mistral_v0_2.model.rms_norm.forward_rms_norm(params, x)
Executes the forward pass of RMS normalization.
- Parameters:
params (RMSNormParams) – The rms norm parameters.
x (Array) – The input array.
- Returns:
The output after applying RMS normalization.
- Return type:
Array
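RMS normalization divides each vector along the last axis by its root mean square and multiplies the result by a learned per-feature scale: y = w * x / sqrt(mean(x^2) + eps). The sketch below assumes the same hypothetical RMSNormParams stand-in (a single `weight` field), an assumed eps of 1e-5, and float32 intermediate math modeled on Hugging Face's `MistralRMSNorm`; the function name `forward_rms_norm_sketch` is illustrative, not the module's API.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class RMSNormParams(NamedTuple):
    """Hypothetical stand-in: assumes the params hold only the learned scale."""
    weight: jax.Array


def forward_rms_norm_sketch(
    params: RMSNormParams, x: jax.Array, eps: float = 1e-5
) -> jax.Array:
    """RMS-normalize `x` over its last axis, then scale by `params.weight`.

    `eps` is an assumed default, not taken from the documented module.
    """
    input_dtype = x.dtype
    # Do the reduction in float32 so low-precision inputs do not lose accuracy.
    x32 = x.astype(jnp.float32)
    # Mean of squares over the feature axis, kept for broadcasting.
    mean_sq = jnp.mean(jnp.square(x32), axis=-1, keepdims=True)
    # Multiply by 1 / sqrt(mean_sq + eps) to get unit root mean square.
    normed = x32 * jax.lax.rsqrt(mean_sq + eps)
    # Cast back to the input dtype, then apply the per-feature scale.
    return params.weight * normed.astype(input_dtype)
```

With a weight of all ones, every output row has a root mean square of (almost exactly) 1; the eps term only matters when a row is close to zero.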
- mistral_v0_2.model.rms_norm.shard_rms_norm_params(params)
Shards the RMSNormParams for distributed computation.
- Parameters:
params (RMSNormParams) – The RMS norm parameters to shard.
- Returns:
The RMS norm parameters, replicated across multiple devices for distributed computation.
- Return type:
RMSNormParams
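The return value is described as a replica across devices. One way to get that in JAX is to place every leaf with a fully replicated `NamedSharding`. This sketch assumes the module uses the `jax.sharding` API (it might instead use pmap-style replication), and the `RMSNormParams` stand-in, the mesh axis name, and `shard_rms_norm_params_sketch` are all hypothetical.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec


class RMSNormParams(NamedTuple):
    """Hypothetical stand-in: assumes the params hold only the learned scale."""
    weight: jax.Array


def shard_rms_norm_params_sketch(params: RMSNormParams) -> RMSNormParams:
    """Replicate every leaf of `params` on all visible devices."""
    # A 1-D mesh over every visible device; the axis name is arbitrary.
    mesh = Mesh(np.array(jax.devices()), axis_names=("devices",))
    # An empty PartitionSpec splits no axis, i.e. full replication.
    replicated = NamedSharding(mesh, PartitionSpec())
    # device_put maps over the pytree, so the NamedTuple structure is kept.
    return jax.device_put(params, replicated)


# Usage: replicate a small scale vector.
params = shard_rms_norm_params_sketch(RMSNormParams(weight=jnp.ones(8)))
```

Replication suits this layer because the scale vector holds only one value per feature, so keeping a full copy on each device costs little and avoids cross-device communication in the forward pass.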