mistral_v0_2.lib.array_conversion
- mistral_v0_2.lib.array_conversion.jax2np(arr)
Converts a JAX array into a NumPy array.
- Parameters:
arr (Array) – JAX array to convert.
- Returns:
Converted NumPy array.
- Return type:
np.ndarray
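A minimal sketch of a copying JAX-to-NumPy conversion that uses only public JAX and NumPy calls. The helper name `jax_to_numpy_copy` is hypothetical, and this is not necessarily how `jax2np` is implemented.

```python
import jax.numpy as jnp
import numpy as np


def jax_to_numpy_copy(arr):
    """Hypothetical helper: copy a JAX array into a new NumPy array."""
    # np.array (unlike np.asarray) copies by default, so the result owns its
    # memory and can be modified without affecting the JAX array.
    return np.array(arr)


x = jnp.arange(6, dtype=jnp.float32).reshape(2, 3)
y = jax_to_numpy_copy(x)
y[0, 0] = 100.0  # safe: y is an independent, writable copy
```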
- mistral_v0_2.lib.array_conversion.jax2np_noncopy(arr)
Converts a JAX array into a NumPy array, avoiding a copy of the underlying data where possible.
- Parameters:
arr (Array) – JAX array to convert.
- Returns:
Converted NumPy array.
- Return type:
np.ndarray
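A minimal sketch of a conversion that avoids forcing a copy, assuming `np.asarray` is an acceptable way to express "no copy where possible". The helper name `jax_to_numpy_view` is hypothetical and not taken from the library.

```python
import jax.numpy as jnp
import numpy as np


def jax_to_numpy_view(arr):
    """Hypothetical helper: convert a JAX array to NumPy without forcing a copy."""
    # np.asarray does not force a copy; on a CPU backend the result may share
    # memory with the JAX buffer, so treat it as read-only.
    return np.asarray(arr)


x = jnp.arange(6, dtype=jnp.float32).reshape(2, 3)
y = jax_to_numpy_view(x)
```

Because JAX arrays are immutable, a result that may alias JAX memory should not be written to; use a copying conversion when a mutable array is needed.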
- mistral_v0_2.lib.array_conversion.jax2pt(arr)
Converts a JAX array into a PyTorch tensor.
- Parameters:
arr (Array) – JAX array to convert.
- Returns:
Converted tensor.
- Return type:
torch.Tensor
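A minimal sketch of a copying JAX-to-PyTorch conversion that routes through NumPy. The helper name `jax_to_torch_copy` is hypothetical; the sketch assumes dtypes that both NumPy and PyTorch support natively and does not handle bfloat16.

```python
import jax.numpy as jnp
import numpy as np
import torch


def jax_to_torch_copy(arr):
    """Hypothetical helper: copy a JAX array into a new PyTorch tensor."""
    # np.array makes a writable host copy; torch.from_numpy then wraps that
    # copy, so the tensor never shares memory with the JAX array.
    # Dtypes NumPy lacks natively (e.g. bfloat16) are not handled here.
    return torch.from_numpy(np.array(arr))


x = jnp.arange(6, dtype=jnp.float32).reshape(2, 3)
t = jax_to_torch_copy(x)
```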
- mistral_v0_2.lib.array_conversion.jax2pt_noncopy(arr)
Converts a JAX array into a PyTorch tensor, avoiding a copy of the underlying data where possible.
- Parameters:
arr (Array) – JAX array to convert.
- Returns:
Converted tensor.
- Return type:
torch.Tensor
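One common way to hand a buffer between frameworks without copying is DLPack. The sketch below assumes a JAX version whose arrays implement the DLPack protocol (`__dlpack__`), which `torch.from_dlpack` accepts directly; the helper name `jax_to_torch_dlpack` is hypothetical, and `jax2pt_noncopy` may use a different mechanism.

```python
import jax.numpy as jnp
import torch


def jax_to_torch_dlpack(arr):
    """Hypothetical helper: convert a JAX array to a PyTorch tensor via DLPack."""
    # torch.from_dlpack accepts objects implementing __dlpack__, so the tensor
    # can reuse the JAX buffer instead of copying it. JAX arrays are
    # immutable, so do not modify the resulting tensor in place.
    return torch.from_dlpack(arr)


x = jnp.arange(6, dtype=jnp.float32).reshape(2, 3)
t = jax_to_torch_dlpack(x)
```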
- mistral_v0_2.lib.array_conversion.np2jax(arr)
Converts a NumPy array into a JAX array.
- Parameters:
arr (np.ndarray) – NumPy array to convert.
- Returns:
Converted jax.Array.
- Return type:
Array
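A minimal sketch of a NumPy-to-JAX conversion using `jnp.asarray`. The helper name `numpy_to_jax` is hypothetical and not taken from the library. Note that in JAX's default 32-bit mode, 64-bit inputs such as float64 arrays are converted to their 32-bit counterparts.

```python
import jax
import jax.numpy as jnp
import numpy as np


def numpy_to_jax(arr):
    """Hypothetical helper: convert a NumPy array to a JAX array."""
    # jnp.asarray places the data on JAX's default device. Without
    # jax_enable_x64, 64-bit dtypes are downcast (e.g. float64 -> float32).
    return jnp.asarray(arr)


a = np.arange(6, dtype=np.float32).reshape(2, 3)
x = numpy_to_jax(a)
```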
- mistral_v0_2.lib.array_conversion.np2pt(arr)
Converts a NumPy array into a PyTorch tensor.
- Parameters:
arr (np.ndarray) – NumPy array to convert.
- Returns:
Converted tensor.
- Return type:
torch.Tensor
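A minimal sketch of a NumPy-to-PyTorch conversion. The helper name `numpy_to_torch` is hypothetical, and the choice of `torch.from_numpy` (which shares memory) is an assumption, since the documentation above does not say whether `np2pt` copies.

```python
import numpy as np
import torch


def numpy_to_torch(arr):
    """Hypothetical helper: wrap a NumPy array as a PyTorch tensor."""
    # torch.from_numpy shares memory with `arr`; use torch.tensor(arr)
    # instead if an independent copy is needed.
    return torch.from_numpy(arr)


a = np.arange(6, dtype=np.float32).reshape(2, 3)
t = numpy_to_torch(a)
a[0, 0] = 100.0  # visible through t, because the memory is shared
print(t[0, 0].item())  # 100.0
```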
- mistral_v0_2.lib.array_conversion.pt2jax(arr)
Converts a PyTorch tensor into a JAX array. The tensor is first converted to a NumPy array, which is then converted to a JAX array.
- Parameters:
arr (torch.Tensor) – PyTorch tensor to convert.
- Returns:
Converted jax.Array.
- Return type:
Array
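The two-step route described for `pt2jax` (PyTorch, then NumPy, then JAX) can be sketched as follows. The helper name `torch_to_jax` is hypothetical, and the `detach()`/`cpu()` calls are assumptions added so the sketch also accepts tensors that require grad or live on a GPU; bfloat16 tensors are not handled.

```python
import jax
import jax.numpy as jnp
import numpy as np
import torch


def torch_to_jax(t):
    """Hypothetical helper: PyTorch tensor -> NumPy array -> JAX array."""
    # .numpy() raises for tensors that require grad or live on a GPU, so
    # detach() and cpu() are applied first. NumPy has no native bfloat16,
    # so bfloat16 tensors are not supported by this sketch.
    arr = t.detach().cpu().numpy()
    return jnp.asarray(arr)


t = torch.arange(6, dtype=torch.float32).reshape(2, 3).requires_grad_()
x = torch_to_jax(t)
```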
- mistral_v0_2.lib.array_conversion.pt2np(arr)
Converts a PyTorch tensor into a NumPy array.
- Parameters:
arr (torch.Tensor) – PyTorch tensor to convert.
- Returns:
Converted NumPy array.
- Return type:
np.ndarray
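A minimal sketch of a PyTorch-to-NumPy conversion. The helper name `torch_to_numpy` is hypothetical, and this is not necessarily how `pt2np` is implemented.

```python
import numpy as np
import torch


def torch_to_numpy(t):
    """Hypothetical helper: convert a PyTorch tensor to a NumPy array."""
    # detach() drops autograd tracking and cpu() moves GPU data to the host.
    # For a CPU tensor the result shares memory with the tensor; call .copy()
    # on it if an independent array is needed.
    return t.detach().cpu().numpy()


t = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
a = torch_to_numpy(t)
```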