mistral_v0_2.model.mlp_layer
- mistral_v0_2.model.mlp_layer.convert_mlp_layer_params(mlp_layer)
Converts the parameters of a PyTorch MLP layer into an MLPLayerParams object compatible with JAX.
- Parameters:
mlp_layer (MistralMLP) – The PyTorch MLP layer from which to extract the weights.
- Returns:
The MLP layer parameters extracted from the PyTorch layer and formatted for compatibility with JAX operations.
- Return type:
MLPLayerParams
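The conversion described above might look roughly like the following sketch. It is not the library's implementation: `MLPLayerParamsSketch` and `convert_linear_weight` are hypothetical names, the three fields are assumed to mirror the `gate_proj`, `up_proj`, and `down_proj` submodules of Hugging Face's `MistralMLP`, and NumPy arrays stand in for the PyTorch weight tensors. The transpose is also an assumption, chosen so that `x @ w` works without a further transpose.

```python
from typing import NamedTuple

import jax.numpy as jnp
import numpy as np


class MLPLayerParamsSketch(NamedTuple):
    # Hypothetical layout; the real MLPLayerParams may differ.
    gate_proj: jnp.ndarray  # (d_model, d_ff)
    up_proj: jnp.ndarray    # (d_model, d_ff)
    down_proj: jnp.ndarray  # (d_ff, d_model)


def convert_linear_weight(weight: np.ndarray) -> jnp.ndarray:
    # torch.nn.Linear stores its weight as (out_features, in_features).
    # Transposing gives (in_features, out_features), so that x @ w applies the layer.
    return jnp.asarray(weight).T


d_model, d_ff = 8, 16
rng = np.random.default_rng(0)

# Stand-ins for e.g. mlp_layer.gate_proj.weight.detach().numpy()
gate_w = rng.standard_normal((d_ff, d_model)).astype(np.float32)
up_w = rng.standard_normal((d_ff, d_model)).astype(np.float32)
down_w = rng.standard_normal((d_model, d_ff)).astype(np.float32)

params = MLPLayerParamsSketch(
    gate_proj=convert_linear_weight(gate_w),
    up_proj=convert_linear_weight(up_w),
    down_proj=convert_linear_weight(down_w),
)
```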
- mistral_v0_2.model.mlp_layer.forward_mlp_layer(params, seq)
Executes the forward pass of the MLP layer.
- Parameters:
params (MLPLayerParams) – The MLP layer parameters.
seq (Array) – The input sequences.
- Returns:
The output of the MLP layer.
- Return type:
Array
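As an illustration of what this forward pass might compute, the sketch below assumes the gated SiLU structure used by Hugging Face's `MistralMLP`, namely `down_proj(silu(gate_proj(x)) * up_proj(x))`. The function name `forward_mlp_sketch`, its flat argument list, and the `(in, out)` weight layout are assumptions for illustration, not the signature documented above.

```python
import jax
import jax.numpy as jnp


def forward_mlp_sketch(gate_proj, up_proj, down_proj, seq):
    # Gated MLP: the SiLU-activated gate branch modulates the up branch elementwise.
    gate = jax.nn.silu(seq @ gate_proj)  # (..., d_ff)
    up = seq @ up_proj                   # (..., d_ff)
    return (gate * up) @ down_proj       # (..., d_model)


d_model, d_ff = 8, 16
k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)

gate_proj = jax.random.normal(k1, (d_model, d_ff))
up_proj = jax.random.normal(k2, (d_model, d_ff))
down_proj = jax.random.normal(k3, (d_ff, d_model))

# Input of shape (batch, seq_len, d_model)
seq = jax.random.normal(k4, (2, 5, d_model))
out = forward_mlp_sketch(gate_proj, up_proj, down_proj, seq)
```

The output keeps the shape of the input, since the layer projects up to the hidden size and back down to the model dimension.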
- mistral_v0_2.model.mlp_layer.shard_mlp_layer_params(params)
Shards the MLP layer parameters for distributed computation.
- Parameters:
params (MLPLayerParams) – The MLP layer parameters to shard.
- Returns:
The MLP layer parameters, sharded for distributed computation across multiple devices.
- Return type:
MLPLayerParams
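The documentation does not say how the parameters are laid out across devices, so the following is only one plausible sketch using `jax.sharding`. It assumes a one-dimensional device mesh with a hypothetical axis name `"model"` and splits the hidden dimension of each weight across it. Passing an empty `PartitionSpec()` instead would replicate the weights on every device.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# One-dimensional mesh over all available devices (a single CPU device also works).
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("model",))

d_model = 8
d_ff = 16 * devices.size  # keep the hidden size divisible by the device count

gate_proj = jnp.ones((d_model, d_ff))
down_proj = jnp.ones((d_ff, d_model))

# Split the hidden (d_ff) dimension across the "model" axis:
# column-wise for the input projection, row-wise for the output projection.
gate_sharded = jax.device_put(gate_proj, NamedSharding(mesh, P(None, "model")))
down_sharded = jax.device_put(down_proj, NamedSharding(mesh, P("model", None)))
```

With this layout each device holds a `(d_model, 16)` slice of `gate_proj` and a `(16, d_model)` slice of `down_proj`, while the arrays keep their global shapes.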