# Source code for mistral_v0_2.model.decoder
from jax import Array
from torch.nn import ModuleList as TorchModuleList
from .decoder_block import DecoderBlockParams, convert_decoder_block_params, forward_decoder_block, shard_decoder_block_params
from .kvcache import KVCache
from .rotary_embedding import RotaryValues
# Parameters of the whole decoder: a list holding one ``DecoderBlockParams`` per decoder block.
DecoderParams = list[DecoderBlockParams]
# [docs]
def convert_decoder_params(layers: TorchModuleList) -> DecoderParams:
    """
    Convert the decoder's PyTorch layers (TorchModuleList) into DecoderParams (JAX Array).

    Every layer is converted with ``convert_decoder_block_params`` and the
    results are collected in the same order as the input layers.

    Args:
        layers (TorchModuleList): The PyTorch decoder layers to convert.

    Returns:
        DecoderParams: One converted ``DecoderBlockParams`` per input layer.
    """
    converted_blocks = []
    for torch_layer in layers:
        converted_blocks.append(convert_decoder_block_params(torch_layer))
    return converted_blocks
def convert_back_decoder_params():
    """
    Placeholder that is not implemented yet.

    NOTE(review): presumably intended as the inverse of
    ``convert_decoder_params`` (JAX back to PyTorch) -- confirm before implementing.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError()
# [docs]
def shard_decoder_params(layers: DecoderParams) -> DecoderParams:
    """
    Shard the decoder parameters for distributed computing.

    Every block's parameters are passed through ``shard_decoder_block_params``
    and the results are collected in the same order as the input.

    Args:
        layers (DecoderParams): The decoder parameters, one entry per decoder block.

    Returns:
        DecoderParams: The decoder parameters modified with tensor parallelism, allowing for distributed computation across multiple devices.
    """
    sharded_blocks = []
    for block_params in layers:
        sharded_blocks.append(shard_decoder_block_params(block_params))
    return sharded_blocks
# [docs]
def forward_decoder(params: DecoderParams, seq: Array, qk_mask: Array, rotary_values: RotaryValues, kv_cache_pre: KVCache) -> tuple[Array, KVCache]:
    """
    Run the forward pass through every decoder block, in order.

    Each block is given the sequence produced by the block before it, together
    with the "current" and "previous" KV caches returned by that block. The
    first block starts with no current cache (``None``) and the caller's
    ``kv_cache_pre``.

    Args:
        params (DecoderParams): The decoder parameters, one entry per decoder block.
        seq (Array): The input sequences to the first decoder block.
        qk_mask (Array): The qk mask for the attention mechanism, determining which parts of the sequence are allowed to attend to each other.
        rotary_values (RotaryValues): Rotary positional embeddings values.
        kv_cache_pre (KVCache): The previous KVCache.

    Returns:
        tuple[Array, KVCache]: The output sequence after all decoder blocks, and
        the current KVCache returned by the last block (to be used as the
        previous KVCache of the next call). The cache is ``None`` when
        ``params`` is empty.
    """
    # TODO: jax.lax.scan
    hidden = seq
    cache_current = None  # no block has produced a cache yet
    cache_previous = kv_cache_pre
    for block_params in params:
        hidden, cache_current, cache_previous = forward_decoder_block(
            block_params, hidden, qk_mask, rotary_values, cache_current, cache_previous
        )
    # The cache built up across all blocks is what the caller carries forward.
    return hidden, cache_current
def test_forward_decoder():
    """Placeholder test for ``forward_decoder``; it currently performs no checks."""
    return None