mistral_v0_2.lib.collate_fn

mistral_v0_2.lib.collate_fn.raw_collate_fn(tokenizer, max_length, batch)

Prepares a batch of training data by tokenizing and padding its sequences and labels. Padding is handled internally, so callers pass unpadded text.

Parameters:
  • tokenizer (AutoTokenizer) – The tokenizer used for converting text to token IDs.

  • max_length (int) – The maximum length of the sequence after tokenization.

  • batch (list[tuple[str, Any]]) – A list of tuples, where each tuple contains a text string and its associated label.

Returns:

A tuple containing training data with four JAX arrays:
  • The token IDs of the sequences.

  • The attention mask for the sequences (indicating real tokens vs padded tokens).

  • The token IDs for the labels.

  • The attention mask for the labels.

Return type:

DataTrain
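The shape of the returned tuple can be illustrated with a minimal sketch. This is not the library's implementation. The whitespace-based `toy_tokenize` helper, the `ToyDataTrain` tuple, the pad ID of 0, right-side padding, truncation to `max_length`, and converting labels with `str` are all assumptions made for illustration, and plain Python lists stand in for the JAX arrays.

```python
from typing import NamedTuple

PAD_ID = 0  # assumed pad token ID; a real tokenizer defines its own


class ToyDataTrain(NamedTuple):
    # Hypothetical stand-in for DataTrain; nested lists replace JAX arrays.
    seq: list
    seq_mask: list
    labels: list
    labels_mask: list


def toy_tokenize(text, vocab):
    # Hypothetical tokenizer: one ID per whitespace-separated word.
    # IDs start at 1 so that 0 stays reserved for padding.
    return [vocab.setdefault(word, len(vocab) + 1) for word in text.split()]


def pad_to(ids, max_length):
    # Truncate to max_length, then right-pad with PAD_ID.
    # The mask is True for real tokens and False for padding.
    ids = ids[:max_length]
    n_pad = max_length - len(ids)
    return ids + [PAD_ID] * n_pad, [True] * len(ids) + [False] * n_pad


def toy_collate_fn(vocab, max_length, batch):
    # Same argument order as raw_collate_fn, with a dict-based toy vocabulary
    # in place of the tokenizer.
    seq, seq_mask, labels, labels_mask = [], [], [], []
    for text, label in batch:
        ids, mask = pad_to(toy_tokenize(text, vocab), max_length)
        label_ids, label_mask = pad_to(toy_tokenize(str(label), vocab), max_length)
        seq.append(ids)
        seq_mask.append(mask)
        labels.append(label_ids)
        labels_mask.append(label_mask)
    return ToyDataTrain(seq, seq_mask, labels, labels_mask)


vocab = {}
batch = [("two plus two", "four"), ("one", "one")]
data = toy_collate_fn(vocab, 4, batch)
```

Every row in each of the four fields has length `max_length`, and each mask row marks which positions hold real tokens rather than padding.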

mistral_v0_2.lib.collate_fn.test_collate_fn(tokenizer, max_length, batch)

Prepares a batch of testing data by tokenizing and padding its sequences and labels. Padding is handled internally, so callers pass unpadded text.

Parameters:
  • tokenizer (AutoTokenizer) – The tokenizer used for converting text to token IDs.

  • max_length (int) – The maximum length of the sequence after tokenization.

  • batch (list[tuple[str, Any]]) – A list of tuples, where each tuple contains a text string and its associated label.

Returns:

A tuple containing testing data with three JAX arrays:
  • The token IDs of the sequences.

  • The attention mask for the sequences (indicating real tokens vs padded tokens).

  • The token IDs for the labels.

Return type:

DataTest
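Both functions take `batch` as their last argument, so the first two arguments can be bound ahead of time with `functools.partial`, giving a callable that accepts only a batch. Whether the library intends this usage is an assumption. The sketch below uses a hypothetical `stand_in_collate_fn` with the same argument order, and `str.split` in place of a real tokenizer, so that it runs without the package installed.

```python
from functools import partial


def stand_in_collate_fn(tokenizer, max_length, batch):
    # Hypothetical placeholder with the same argument order as test_collate_fn.
    # It only tokenizes and truncates; it does not pad or build masks.
    return [tokenizer(text)[:max_length] for text, _label in batch]


# Bind tokenizer and max_length; the result takes a single `batch` argument.
collate = partial(stand_in_collate_fn, str.split, 2)

out = collate([("a b c", 0), ("d", 1)])
```

With the real functions, the same pattern would bind an `AutoTokenizer` instance and the chosen `max_length` in place of `str.split` and `2`.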