mistral_v0_2.model.embedding
- mistral_v0_2.model.embedding.convert_embedding_params(embedding)
Converts the parameters of a PyTorch embedding layer into an EmbeddingParams instance compatible with JAX.
- Parameters:
embedding (TorchEmbedding) – The PyTorch embedding layer from which to extract the weights.
- Returns:
The embedding parameters extracted from the PyTorch layer and formatted for compatibility with JAX operations.
- Return type:
EmbeddingParams
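A minimal sketch of what this conversion might look like. It assumes that `TorchEmbedding` is `torch.nn.Embedding` and that `EmbeddingParams` is a NamedTuple holding a single `(vocab_size, embed_dim)` matrix in a field called `weight`; the field name and `convert_embedding_params_sketch` are hypothetical and are not this module's actual definitions.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp
import torch


class EmbeddingParams(NamedTuple):
    # Hypothetical layout: one (vocab_size, embed_dim) weight matrix.
    weight: jax.Array


def convert_embedding_params_sketch(embedding: torch.nn.Embedding) -> EmbeddingParams:
    # Detach from autograd, upcast to float32 (NumPy has no native bfloat16),
    # and move the tensor to host memory as a NumPy array.
    weight_np = embedding.weight.detach().to(torch.float32).cpu().numpy()
    # Copy the host array into a JAX array on the default device.
    return EmbeddingParams(weight=jnp.asarray(weight_np))


torch.manual_seed(0)
layer = torch.nn.Embedding(num_embeddings=8, embedding_dim=4)
params = convert_embedding_params_sketch(layer)
print(params.weight.shape)  # (8, 4)
```

Going through NumPy keeps the sketch framework-agnostic; a real implementation may instead preserve the original dtype (for example bfloat16) through a different path.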
- mistral_v0_2.model.embedding.forward_embedding(params, input_ids)
Looks up the embeddings for the given input IDs.
- Parameters:
params (EmbeddingParams) – The embedding parameters.
input_ids (Array) – An array of token IDs whose embeddings are looked up.
- Returns:
An array containing the embedding vector for each input ID.
- Return type:
Array
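An embedding lookup is a row gather from the weight matrix. The following is a minimal sketch, assuming (hypothetically) that `EmbeddingParams` stores a `(vocab_size, embed_dim)` matrix in a `weight` field; `forward_embedding_sketch` is an illustrative name, not this module's implementation.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class EmbeddingParams(NamedTuple):
    # Hypothetical layout: one (vocab_size, embed_dim) weight matrix.
    weight: jax.Array


@jax.jit
def forward_embedding_sketch(params: EmbeddingParams, input_ids: jax.Array) -> jax.Array:
    # Integer indexing gathers one row per ID:
    # output shape is input_ids.shape + (embed_dim,).
    return params.weight[input_ids]


weight = jnp.arange(12, dtype=jnp.float32).reshape(4, 3)  # vocab_size=4, embed_dim=3
params = EmbeddingParams(weight=weight)
input_ids = jnp.array([[0, 3], [2, 2]])  # batch of 2 sequences, length 2
out = forward_embedding_sketch(params, input_ids)
print(out.shape)  # (2, 2, 3)
```

Because a NamedTuple is a JAX pytree, the parameters can be passed straight into a `jax.jit`-compiled function.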
- mistral_v0_2.model.embedding.shard_embedding_params(params)
Shards the EmbeddingParams for distributed computation.
- Parameters:
params (EmbeddingParams) – The embedding parameters to shard.
- Returns:
The decoder's embedding parameters, replicated for distributed computation across multiple devices.
- Return type:
EmbeddingParams
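Since the return value is described as a replica, a minimal sketch places a full copy of the parameters on every device using `jax.sharding`. It assumes a hypothetical `EmbeddingParams` NamedTuple with a `weight` field; the mesh axis name `"devices"` and `shard_embedding_params_sketch` are also illustrative, not this module's actual choices.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec


class EmbeddingParams(NamedTuple):
    # Hypothetical layout: one (vocab_size, embed_dim) weight matrix.
    weight: jax.Array


def shard_embedding_params_sketch(params: EmbeddingParams) -> EmbeddingParams:
    # A 1-D mesh spanning every available device (axis name is hypothetical).
    mesh = Mesh(np.array(jax.devices()), axis_names=("devices",))
    # An empty PartitionSpec leaves every axis unpartitioned,
    # so each device holds a full copy of every leaf.
    replicated = NamedSharding(mesh, PartitionSpec())
    return jax.tree_util.tree_map(lambda x: jax.device_put(x, replicated), params)


params = EmbeddingParams(weight=jnp.ones((8, 4), dtype=jnp.float32))
sharded = shard_embedding_params_sketch(params)
print(sharded.weight.sharding.is_fully_replicated)  # True
```

Replicating the table trades memory for simplicity. Splitting it across devices instead would use a non-empty spec such as `PartitionSpec("devices", None)`, which requires the vocabulary size to divide evenly by the device count.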