Source code for mistral_v0_2.model.mlp_layer

from einshard import einshard
import jax
from jax import Array
from transformers import MistralForCausalLM
from transformers.models.mistral.modeling_mistral import MistralMLP

from ..lib.array_conversion import pt2jax

# Weights of one MLP layer as a 3-tuple, ordered (gate_proj, up_proj, down_proj);
# this is the order in which the functions of this module pack and unpack it.
MLPLayerParams = tuple[Array, Array, Array]

def convert_mlp_layer_params(mlp_layer: MistralMLP) -> MLPLayerParams:
    """
    Extract the projection weights of a PyTorch MLP layer as JAX arrays.

    The weight of each of the three projections is transposed (``.T``) and
    then passed through ``pt2jax``.

    Args:
        mlp_layer (MistralMLP): The PyTorch MLP layer to read the weights from.

    Returns:
        MLPLayerParams: The converted weights, ordered
            ``(gate_proj, up_proj, down_proj)``.
    """
    projections = (mlp_layer.gate_proj, mlp_layer.up_proj, mlp_layer.down_proj)
    # NOTE(review): presumably transposed so the arrays can be applied as
    # `seq @ weight`, which is how forward_mlp_layer uses them - confirm.
    gate_proj, up_proj, down_proj = (pt2jax(proj.weight.data.T) for proj in projections)
    return gate_proj, up_proj, down_proj
def convert_back_mlp_layer_params(mlp_layer: MLPLayerParams) -> MistralMLP:
    """
    Placeholder for converting JAX MLP layer parameters back to PyTorch.

    Not implemented yet: the function performs no conversion and currently
    returns ``None`` despite the annotated return type.

    Args:
        mlp_layer (MLPLayerParams): The MLP layer parameters (currently unused).

    Returns:
        MistralMLP: Intended to be the rebuilt PyTorch layer; ``None`` for now.
    """
    # TODO: handle config, e.g. `mlp_layer_pt = MistralMLP(config_pt)`
    return None
def shard_mlp_layer_params(params: MLPLayerParams) -> MLPLayerParams:
    """
    Apply ``einshard`` to each weight of an MLP layer.

    In every sharding expression the ``f`` axis is the one marked with ``*``:
    the second axis of ``gate_proj`` and ``up_proj``, and the first axis of
    ``down_proj``.

    Args:
        params (MLPLayerParams): The MLP layer parameters, ordered
            ``(gate_proj, up_proj, down_proj)``.

    Returns:
        MLPLayerParams: The arrays returned by ``einshard``, in the same order.
    """
    w_gate, w_up, w_down = params
    # NOTE(review): `*` presumably marks the axis einshard partitions across
    # the available devices - confirm against the einshard documentation.
    return (
        einshard(w_gate, 'm f -> m f*'),
        einshard(w_up, 'm f -> m f*'),
        einshard(w_down, 'f m -> f* m'),
    )
def forward_mlp_layer(params: MLPLayerParams, seq: Array) -> Array:
    """
    Run the forward pass of the SiLU-gated MLP layer.

    Computes ``(silu(seq @ gate_proj) * (seq @ up_proj)) @ down_proj``.

    Args:
        params (MLPLayerParams): The MLP layer parameters, ordered
            ``(gate_proj, up_proj, down_proj)``.
        seq (Array): The input sequences; must be right-multipliable by
            ``gate_proj`` and ``up_proj``.

    Returns:
        Array: The output of the MLP layer.
    """
    w_gate, w_up, w_down = params
    gate = jax.nn.silu(seq @ w_gate)
    lifted = seq @ w_up
    # The activated gate scales the up-projection element-wise before the
    # result is projected back down.
    return (gate * lifted) @ w_down
def test_forward_mlp_layer(model: MistralForCausalLM) -> None:
    """
    Placeholder test for ``forward_mlp_layer``.

    No checks are implemented yet; the function ignores ``model`` and
    returns ``None``.

    Args:
        model (MistralForCausalLM): The PyTorch model (currently unused).
    """