mistral_v0_2.lib.array_conversion


mistral_v0_2.lib.array_conversion.jax2np(arr)

Converts a JAX array into a NumPy array.

Parameters:

arr (Array) – JAX array to convert.

Returns:

Converted NumPy array.

Return type:

np.ndarray

mistral_v0_2.lib.array_conversion.jax2np_noncopy(arr)

Converts a JAX array into a NumPy array, avoiding a copy of the underlying data where possible.

Parameters:

arr (Array) – JAX array to convert.

Returns:

Converted NumPy array.

Return type:

np.ndarray
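The difference between a copying and a non-copying conversion can be sketched with plain NumPy calls. This is a minimal illustration, assuming the copying variant behaves like `np.array` and the non-copying variant like `np.asarray`. The helper names `to_numpy_copy` and `to_numpy_view` are hypothetical and are not taken from this module.

```python
import jax.numpy as jnp
import numpy as np


def to_numpy_copy(arr):
    # Hypothetical helper: np.array copies by default, so the result
    # owns a fresh buffer that can be modified freely.
    return np.array(arr)


def to_numpy_view(arr):
    # Hypothetical helper: np.asarray lets NumPy reuse the existing host
    # buffer when it can, so the result should be treated as read-only.
    return np.asarray(arr)


x = jnp.arange(6, dtype=jnp.float32).reshape(2, 3)
copied = to_numpy_copy(x)
viewed = to_numpy_view(x)

# Writing to the copy is safe and leaves the JAX array untouched.
copied[0, 0] = 42.0
print(copied[0, 0], float(x[0, 0]))
```

A non-copying result is best used for reading only; when the data must be modified in place, the copying conversion is the safer choice.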

mistral_v0_2.lib.array_conversion.jax2pt(arr)

Converts a JAX array into a PyTorch tensor.

Parameters:

arr (Array) – JAX array to convert.

Returns:

Converted tensor.

Return type:

torch.Tensor
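One plausible route from JAX to PyTorch goes through a NumPy copy on the host. The sketch below assumes that route for illustration only. The helper name `jax_to_torch_via_numpy` is hypothetical, and this module's own implementation may differ.

```python
import jax.numpy as jnp
import numpy as np
import torch


def jax_to_torch_via_numpy(arr):
    # Hypothetical helper. np.array makes a writable host copy, so the
    # resulting tensor never aliases memory owned by JAX.
    host = np.array(arr)
    # torch.from_numpy wraps that copy without copying a second time.
    return torch.from_numpy(host)


x = jnp.ones((2, 3), dtype=jnp.float32)
t = jax_to_torch_via_numpy(x)
print(t.shape, t.dtype)
```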

mistral_v0_2.lib.array_conversion.jax2pt_noncopy(arr)

Converts a JAX array into a PyTorch tensor, avoiding a copy of the underlying data where possible.

Parameters:

arr (Array) – JAX array to convert.

Returns:

Converted tensor.

Return type:

torch.Tensor

mistral_v0_2.lib.array_conversion.np2jax(arr)

Converts a NumPy array into a JAX array.

Parameters:

arr (np.ndarray) – NumPy array to convert.

Returns:

Converted jax.Array.

Return type:

Array

mistral_v0_2.lib.array_conversion.np2pt(arr)

Converts a NumPy array into a PyTorch tensor.

Parameters:

arr (np.ndarray) – NumPy array to convert.

Returns:

Converted tensor.

Return type:

torch.Tensor
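PyTorch offers both a memory-sharing and a copying way to build a tensor from a NumPy array, and the choice affects whether later writes to the array show up in the tensor. The sketch below contrasts the two using standard PyTorch calls. Which one `np2pt` uses is not stated here, so treat this as background rather than a description of the function.

```python
import numpy as np
import torch

a = np.arange(4, dtype=np.float32)

# torch.from_numpy shares memory with the NumPy array.
shared = torch.from_numpy(a)
# torch.tensor always copies the data into a new tensor.
copied = torch.tensor(a)

# Mutating the NumPy array is visible through the shared tensor only.
a[0] = 10.0
print(shared[0].item(), copied[0].item())
```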

mistral_v0_2.lib.array_conversion.pt2jax(arr)

Converts a PyTorch tensor into a JAX array. The tensor is first converted to a NumPy array, which is then converted to a JAX array.

Parameters:

arr (torch.Tensor) – PyTorch tensor to convert.

Returns:

Converted jax.Array.

Return type:

Array
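The two-step route described for `pt2jax` (PyTorch tensor to NumPy array, then NumPy array to JAX array) can be sketched as follows. This is a minimal illustration under assumptions: the helper name `torch_to_jax_via_numpy` is hypothetical, and the `detach()` and `cpu()` calls are added here so the sketch also works for tensors that require gradients or live on an accelerator.

```python
import jax
import jax.numpy as jnp
import numpy as np
import torch


def torch_to_jax_via_numpy(t: torch.Tensor) -> jax.Array:
    # Step 1: move the tensor to host memory and expose it as a NumPy array.
    host = t.detach().cpu().numpy()
    # Step 2: hand the NumPy array to JAX, which places it on its default device.
    return jnp.asarray(host)


t = torch.arange(6, dtype=torch.float32).reshape(2, 3)
y = torch_to_jax_via_numpy(t)
print(y.shape, y.dtype)
```

Note that `Tensor.numpy()` raises a `TypeError` for `torch.bfloat16` tensors, so a route through NumPy needs such tensors cast to a supported dtype first.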

mistral_v0_2.lib.array_conversion.pt2np(arr)

Converts a PyTorch tensor into a NumPy array.

Parameters:

arr (torch.Tensor) – PyTorch tensor to convert.

Returns:

Converted NumPy array.

Return type:

np.ndarray