mistral_v0_2.model.mlp_layer

mistral_v0_2.model.mlp_layer.convert_mlp_layer_params(mlp_layer)

Converts the parameters of a PyTorch MLP layer into an MLPLayerParams object compatible with JAX.

Parameters:

mlp_layer (MistralMLP) – The PyTorch MLP layer from which to extract the weights.

Returns:

The MLP layer parameters extracted from the PyTorch layer and formatted for use with JAX operations.

Return type:

MLPLayerParams
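The conversion can be pictured with the following minimal sketch. It assumes that `MistralMLP` exposes `gate_proj`, `up_proj` and `down_proj` linear layers (as in Hugging Face `transformers`), that `torch.nn.Linear` stores its weight as `(out_features, in_features)`, and that the JAX side wants `(in_features, out_features)` so that `x @ w` works directly. The field names of `MLPLayerParams` and the helper `to_jax` are hypothetical, not the module's actual definitions.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp
import numpy as np


class MLPLayerParams(NamedTuple):
    # Hypothetical field layout; the real MLPLayerParams may differ.
    gate_proj: jax.Array  # (d_model, d_ff)
    up_proj: jax.Array    # (d_model, d_ff)
    down_proj: jax.Array  # (d_ff, d_model)


def to_jax(w) -> jax.Array:
    """Hypothetical helper: turn a torch weight (or array-like) into a transposed JAX array."""
    if hasattr(w, "detach"):
        # torch.Tensor: bfloat16 tensors cannot call .numpy() directly, so upcast to float32 first.
        w = w.detach().float().cpu().numpy()
    else:
        w = np.asarray(w)
    # nn.Linear stores (out, in); transpose to (in, out) for `x @ w` in JAX.
    return jnp.asarray(w.T)


def convert_mlp_layer_params_sketch(mlp_layer) -> MLPLayerParams:
    """Sketch of the conversion, assuming the attribute names used by MistralMLP."""
    return MLPLayerParams(
        gate_proj=to_jax(mlp_layer.gate_proj.weight),
        up_proj=to_jax(mlp_layer.up_proj.weight),
        down_proj=to_jax(mlp_layer.down_proj.weight),
    )
```

Because the helper only needs `.detach()` or array-like input, the sketch can be exercised without PyTorch by passing any object with the three `.weight` attributes.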

mistral_v0_2.model.mlp_layer.forward_mlp_layer(params, seq)

Executes the forward pass of the MLP layer.

Parameters:
  • params (MLPLayerParams) – The MLP layer parameters.

  • seq (Array) – The input sequences.

Returns:

The output of the MLP layer.

Return type:

Array
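Mistral's feed-forward block is a gated (SwiGLU-style) MLP: the input is projected twice, one branch passes through SiLU and gates the other elementwise, and the result is projected back to the model dimension. The sketch below shows that computation in JAX. It assumes weights laid out as `(in, out)` and uses hypothetical `MLPLayerParams` field names; it illustrates the technique rather than reproducing this module's code.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class MLPLayerParams(NamedTuple):
    # Hypothetical field layout; the real MLPLayerParams may differ.
    gate_proj: jax.Array  # (d_model, d_ff)
    up_proj: jax.Array    # (d_model, d_ff)
    down_proj: jax.Array  # (d_ff, d_model)


@jax.jit
def forward_mlp_layer_sketch(params: MLPLayerParams, seq: jax.Array) -> jax.Array:
    """Gated MLP: down(silu(seq @ gate) * (seq @ up)). seq: (..., d_model)."""
    gate = jax.nn.silu(seq @ params.gate_proj)  # activated gating branch, (..., d_ff)
    up = seq @ params.up_proj                   # linear branch, (..., d_ff)
    return (gate * up) @ params.down_proj       # back to (..., d_model)
```

A `NamedTuple` is a JAX pytree, so the whole parameter bundle can be passed straight through `jax.jit`.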

mistral_v0_2.model.mlp_layer.shard_mlp_layer_params(params)

Shards the MLP layer parameters for distributed computation.

Parameters:

params (MLPLayerParams) – The MLP layer parameters to shard.

Returns:

The MLP layer parameters, placed across multiple devices for distributed computation.

Return type:

MLPLayerParams
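One common way to shard a gated MLP is Megatron-style tensor parallelism: split `gate_proj` and `up_proj` along their `d_ff` (output) axis and `down_proj` along its `d_ff` (input) axis, so the elementwise gating needs no communication and only the final matmul's partial sums must be reduced. The sketch below expresses that layout with `jax.sharding`. The mesh axis name `"model"`, the explicit `mesh` argument and the `MLPLayerParams` field names are assumptions; the documented function takes only `params` and may instead replicate the weights (e.g. with an empty `PartitionSpec()`).

```python
from typing import NamedTuple

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


class MLPLayerParams(NamedTuple):
    # Hypothetical field layout; the real MLPLayerParams may differ.
    gate_proj: jax.Array  # (d_model, d_ff)
    up_proj: jax.Array    # (d_model, d_ff)
    down_proj: jax.Array  # (d_ff, d_model)


def shard_mlp_layer_params_sketch(params: MLPLayerParams, mesh: Mesh) -> MLPLayerParams:
    """Tensor-parallel placement; d_ff must be divisible by the size of the "model" axis."""
    col = NamedSharding(mesh, P(None, "model"))  # split the d_ff output columns
    row = NamedSharding(mesh, P("model", None))  # split the d_ff input rows
    return MLPLayerParams(
        gate_proj=jax.device_put(params.gate_proj, col),
        up_proj=jax.device_put(params.up_proj, col),
        down_proj=jax.device_put(params.down_proj, row),
    )


# A 1-D mesh over every visible device (a single CPU device also works).
mesh = Mesh(np.array(jax.devices()), ("model",))
```

Under `jax.jit`, the forward pass can consume these sharded arrays unchanged; XLA inserts the reduction after `down_proj` automatically.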