mistral_v0_2.model.embedding#

mistral_v0_2.model.embedding.convert_embedding_params(embedding)[source]#

Converts PyTorch embedding parameters to an EmbeddingParams object compatible with JAX.

Parameters:

embedding (TorchEmbedding) – The PyTorch embedding layer from which to extract the weights.

Returns:

The embedding parameters extracted from the PyTorch layer and formatted for compatibility with JAX operations.

Return type:

EmbeddingParams
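A minimal sketch of what such a conversion might look like. It assumes that TorchEmbedding refers to torch.nn.Embedding and that EmbeddingParams is simply the weight matrix held as a JAX array; neither is confirmed by this page, and the name convert_embedding_params_sketch is hypothetical.

```python
import jax.numpy as jnp
import torch


def convert_embedding_params_sketch(embedding: torch.nn.Embedding):
    """Hypothetical stand-in for convert_embedding_params.

    Assumption: EmbeddingParams is just the (vocab_size, d_model) weight
    matrix as a JAX array.
    """
    # detach() drops the autograd graph; cpu().numpy() gives a NumPy view
    # that jnp.asarray then copies into a JAX array.
    return jnp.asarray(embedding.weight.detach().cpu().numpy())


# Tiny embedding layer for illustration only (not Mistral's real sizes).
embedding = torch.nn.Embedding(num_embeddings=8, embedding_dim=4)
params = convert_embedding_params_sketch(embedding)
```

Going through NumPy keeps the sketch simple; a float32 layer converts directly, while other dtypes such as bfloat16 would need a cast first.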

mistral_v0_2.model.embedding.forward_embedding(params, input_ids)[source]#

Looks up the embeddings for the given input IDs.

Parameters:
  • params (EmbeddingParams) – The embedding parameters.

  • input_ids (Array) – An array of input IDs whose embeddings are looked up.

Returns:

The embeddings of the input IDs, as an Array.

Return type:

Array
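An embedding lookup of this kind can be sketched as plain array indexing. The example assumes EmbeddingParams is a (vocab_size, d_model) array; forward_embedding_sketch is a hypothetical name, not the library function.

```python
import jax.numpy as jnp


def forward_embedding_sketch(params, input_ids):
    """Hypothetical stand-in for forward_embedding.

    Assumption: params is a (vocab_size, d_model) array, so indexing with
    integer IDs returns one row per ID.
    """
    return params[input_ids]


# Toy sizes for illustration only.
vocab_size, d_model = 8, 4
params = jnp.arange(vocab_size * d_model, dtype=jnp.float32).reshape(vocab_size, d_model)

# A batch of 2 sequences with 3 token IDs each.
input_ids = jnp.array([[1, 3, 5], [0, 2, 7]])
out = forward_embedding_sketch(params, input_ids)
```

Under these assumptions the result has the shape of input_ids with a trailing d_model axis, here (2, 3, 4).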

mistral_v0_2.model.embedding.shard_embedding_params(params)[source]#

Shards the EmbeddingParams for distributed computing.

Parameters:

params (EmbeddingParams) – The embedding parameters to shard.

Returns:

A replica of the decoder embedding parameters for distributed computation across multiple devices.

Return type:

EmbeddingParams
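One way to replicate an array across devices in JAX is to place it with a sharding whose partition spec is empty. The sketch below takes that approach; whether shard_embedding_params does the same internally is an assumption, and the mesh axis name "d" and the function name are hypothetical.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Build a 1-D mesh over every available device (a single CPU also works).
mesh = Mesh(np.array(jax.devices()), axis_names=("d",))

# An empty PartitionSpec partitions no axis, so each device holds a full copy.
replicated = NamedSharding(mesh, PartitionSpec())


def shard_embedding_params_sketch(params):
    """Hypothetical stand-in: replicate params on all devices in the mesh."""
    return jax.device_put(params, replicated)


# Toy (vocab_size, d_model) array for illustration only.
params = jnp.ones((8, 4), dtype=jnp.float32)
sharded = shard_embedding_params_sketch(params)
```

Replication trades memory for simplicity: every device can perform the lookup locally without communication.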

mistral_v0_2.model.embedding.test_forward_embedding(model)[source]#

Tests the embedding forward pass against the PyTorch model.

Parameters:

model (MistralForCausalLM) – PyTorch Mistral model to compare with this implementation.

Returns:

None.

Return type:

None
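A comparison of this kind can be sketched with a small standalone torch.nn.Embedding in place of a full MistralForCausalLM. The layer sizes, the seed, and the use of np.allclose are illustrative assumptions, not the checks the library performs.

```python
import jax.numpy as jnp
import numpy as np
import torch

torch.manual_seed(0)

# Small stand-in for the model's embedding layer (assumption: toy sizes).
embedding = torch.nn.Embedding(num_embeddings=16, embedding_dim=4)
input_ids_pt = torch.tensor([[1, 5, 9], [0, 2, 15]])

# Reference output from PyTorch.
with torch.no_grad():
    expected = embedding(input_ids_pt).numpy()

# JAX side: weight matrix as an array, lookup by integer indexing.
params = jnp.asarray(embedding.weight.detach().numpy())
actual = params[jnp.asarray(input_ids_pt.numpy())]

# The two implementations should agree numerically.
matches = np.allclose(np.asarray(actual), expected)
```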