# Source code for mistral_v0_2.lib.collate_fn
from typing import Any
from jax import Array
import jax.numpy as jnp
from transformers import AutoTokenizer, MistralForCausalLM
# Tuple layouts used as return types by the collate functions in this module.
# Training batch: (seq_ids, seq_mask, labels, labels_mask)
DataTrain = tuple[Array, Array, Array, Array]
# Testing batch: (seq_ids, seq_mask, labels)
DataTest = tuple[Array, Array, Array]
def raw_collate_fn(tokenizer: AutoTokenizer, max_length: int, batch: list[tuple[str, Any]]) -> DataTrain:
    """
    Tokenize a batch of (text, label) pairs into fixed-length training arrays.

    Each text and each label (converted with ``str``) is tokenized on its own and
    then truncated or right-padded with ``tokenizer.pad_token_id`` so that every
    row holds exactly ``max_length`` token IDs. This function adds no BOS/EOS
    tokens itself; the IDs are used exactly as the tokenizer returns them.

    Args:
        tokenizer (AutoTokenizer): Callable tokenizer; must expose ``pad_token_id`` and
            return an object with ``input_ids`` when called on a string.
        max_length (int): Number of token IDs per row after truncation/padding.
        batch (list[tuple[str, Any]]): ``(text, label)`` pairs; may be empty.

    Returns:
        DataTrain: A tuple of four JAX arrays, each of shape ``(len(batch), max_length)``:
            - Token IDs of the sequences (``uint16``).
            - Mask for the sequences: ``True`` for real tokens, ``False`` for padding.
            - Token IDs of the labels (``uint16``).
            - Mask for the labels: ``True`` for real tokens, ``False`` for padding.
    """
    # NOTE(review): the original author's note says the tokenizer call already
    # prepends BOS (id 1), which is why none is added here -- confirm for the
    # tokenizer configuration in use.
    pad_id = tokenizer.pad_token_id

    def _fit(ids: list[int]) -> tuple[list[int], list[bool]]:
        """Truncate or right-pad ``ids`` to ``max_length``; return (ids, mask)."""
        ids = ids[:max_length]
        n_pad = max_length - len(ids)
        return ids + [pad_id] * n_pad, [True] * len(ids) + [False] * n_pad

    seq_rows, seq_masks, label_rows, label_masks = [], [], [], []
    for text, label in batch:
        ids, mask = _fit(tokenizer(text, return_attention_mask=False).input_ids)
        seq_rows.append(ids)
        seq_masks.append(mask)

        # Labels are typed ``Any`` (e.g. numbers), so stringify before tokenizing.
        ids, mask = _fit(tokenizer(str(label), return_attention_mask=False).input_ids)
        label_rows.append(ids)
        label_masks.append(mask)

    # The explicit reshape is a no-op for non-empty batches; for an empty batch it
    # yields (0, max_length) instead of a 1-D array of shape (0,).
    shape = (len(batch), max_length)
    # uint16 assumes every token ID is <= 65535 -- TODO confirm against the vocabulary size.
    seq_ids = jnp.array(seq_rows, dtype=jnp.uint16).reshape(shape)
    seq_mask = jnp.array(seq_masks, dtype=jnp.bool_).reshape(shape)
    labels_ids = jnp.array(label_rows, dtype=jnp.uint16).reshape(shape)
    labels_mask = jnp.array(label_masks, dtype=jnp.bool_).reshape(shape)
    return seq_ids, seq_mask, labels_ids, labels_mask
def test_collate_fn(tokenizer: AutoTokenizer, max_length: int, batch: list[tuple[str, Any]]) -> DataTest:
    """
    Tokenize a batch of (text, label) pairs into fixed-length testing arrays.

    Each text is tokenized, wrapped as ``[bos] + ids + [eos]``, and then truncated
    or right-padded with ``tokenizer.pad_token_id`` to exactly ``max_length``
    token IDs (truncation may drop the trailing EOS). Each label is converted with
    ``str``, tokenized, and truncated/padded the same way, without extra BOS/EOS.

    Args:
        tokenizer (AutoTokenizer): Callable tokenizer; must expose ``bos_token_id``,
            ``eos_token_id`` and ``pad_token_id`` and return an object with
            ``input_ids`` when called on a string.
        max_length (int): Number of token IDs per row after truncation/padding.
        batch (list[tuple[str, Any]]): ``(text, label)`` pairs; may be empty.

    Returns:
        DataTest: A tuple of three JAX arrays, each of shape ``(len(batch), max_length)``:
            - Token IDs of the sequences (``uint16``).
            - Mask for the sequences: ``True`` for real tokens, ``False`` for padding.
            - Token IDs of the labels (``uint16``), padded with ``pad_token_id``.
    """
    bos_id = tokenizer.bos_token_id
    eos_id = tokenizer.eos_token_id
    pad_id = tokenizer.pad_token_id

    def _fit(ids: list[int]) -> tuple[list[int], list[bool]]:
        """Truncate or right-pad ``ids`` to ``max_length``; return (ids, mask)."""
        ids = ids[:max_length]
        n_pad = max_length - len(ids)
        return ids + [pad_id] * n_pad, [True] * len(ids) + [False] * n_pad

    seq_rows, seq_masks, label_rows = [], [], []
    for text, label in batch:
        prompt_ids = tokenizer(text, return_attention_mask=False).input_ids
        # NOTE(review): if the tokenizer is configured to add BOS itself, this
        # produces a duplicate BOS -- confirm against the tokenizer settings.
        ids, mask = _fit([bos_id] + prompt_ids + [eos_id])
        seq_rows.append(ids)
        seq_masks.append(mask)

        # Labels are typed ``Any``; stringify before tokenizing. Only the IDs are
        # returned for test data, so the label mask is discarded.
        label_ids, _ = _fit(tokenizer(str(label), return_attention_mask=False).input_ids)
        label_rows.append(label_ids)

    # The explicit reshape is a no-op for non-empty batches; for an empty batch it
    # yields (0, max_length) instead of a 1-D array of shape (0,).
    shape = (len(batch), max_length)
    # uint16 assumes every token ID is <= 65535 -- TODO confirm against the vocabulary size.
    seq_ids = jnp.array(seq_rows, dtype=jnp.uint16).reshape(shape)
    seq_mask = jnp.array(seq_masks, dtype=jnp.bool_).reshape(shape)
    labels = jnp.array(label_rows, dtype=jnp.uint16).reshape(shape)
    return seq_ids, seq_mask, labels


# Despite its ``test_`` prefix this is a data helper, not a test case; the flag
# keeps pytest from collecting it when it is imported into a test module.
test_collate_fn.__test__ = False