Source code for mistral_v0_2.model.embedding
from einshard import einshard
from jax import Array
import jax.numpy as jnp
import torch
from torch.nn import Embedding as TorchEmbedding
from transformers import MistralForCausalLM
from ..lib.array_conversion import pt2jax
EmbeddingParams = Array
[docs]
def convert_embedding_params(embedding: TorchEmbedding) -> EmbeddingParams:
"""
Converts PyTorch embedding parameters to a EmbeddingParams compatible with JAX.
Args:
embedding (TorchEmbedding): The PyTorch embedding layer from which to extract the weights.
Returns:
EmbeddingParams: The embedding parameters extracted from the PyTorch layer and formatted for compatibility with JAX operations.
"""
return pt2jax(embedding.weight.data)
def convert_back_embedding_params():
pass
[docs]
def shard_embedding_params(params: EmbeddingParams) -> EmbeddingParams:
"""
Shard the EmbeddingParams params for distributed computing.
Args:
params (EmbeddingParams): The EmbeddingParams parameters.
Returns:
EmbeddingParams: The decoder embedding parameters replica for distributed computation across multiple devices.
"""
return einshard(params, '... -> * ...')
[docs]
def forward_embedding(params: EmbeddingParams, input_ids: Array) -> Array:
"""
Get the embedding with input IDS.
Args:
params (EmbeddingParams): The embedding parameters.
input_ids (Array): An array of input IDS to look up the embedding.
Returns:
Array: The embedding Array of input IDS.
"""
return params[input_ids]
[docs]
def test_forward_embedding(model: MistralForCausalLM) -> None:
"""
Tests the embedding parameters.
Args:
model (MistralForCausalLM): PyTorch Mistral model to compare with this implementation.
Returns:
None.
"""
embedding_pt = model.model.embed_tokens
input_ids_pt = torch.tensor([1, 20, 3, 5, 2, 7], dtype=torch.int32)
result_pt = embedding_pt(input_ids_pt)
result_pt_to_jax = pt2jax(result_pt)
params = convert_embedding_params(embedding_pt)
input_ids_jax = pt2jax(input_ids_pt)
result_jax = forward_embedding(params, input_ids_jax)
assert jnp.allclose(result_pt_to_jax, result_jax, atol=1e-5)