Source code for mistral_v0_2.lib.array_conversion
from jax import Array
import jax.numpy as jnp
import numpy as np
import torch
# PyTorch -> NumPy -> JAX
# JAX -> NumPy -> PyTorch
[docs]
def pt2np(arr: torch.Tensor) -> np.ndarray:
    '''
    Converts a PyTorch tensor into a NumPy array.

    The tensor is detached from the autograd graph and moved to the CPU
    first, so tensors that require grad or live on an accelerator are
    accepted. ``torch.bfloat16`` tensors, which ``Tensor.numpy()`` cannot
    convert directly, are returned as NumPy arrays of JAX's bfloat16 dtype
    with their bit pattern preserved.

    Args:
        arr (torch.Tensor): PyTorch tensor to convert.

    Returns:
        np.ndarray: Converted NumPy array. For a tensor already on the CPU
        the result may share memory with ``arr``.
    '''
    # `detach()` lets `.numpy()` succeed for tensors with requires_grad=True
    # independent of the ambient grad mode or the installed PyTorch version.
    tensor = arr.detach().cpu()
    if tensor.dtype == torch.bfloat16:
        # NumPy has no built-in bfloat16: reinterpret the 16-bit payload as
        # int16, then view those same bytes as JAX's bfloat16 NumPy dtype.
        return tensor.view(torch.int16).numpy().view(np.dtype(jnp.bfloat16))
    return tensor.numpy()
[docs]
def np2jax(arr: np.ndarray) -> Array:
    '''
    Wraps a NumPy array as a JAX array.

    Args:
        arr (np.ndarray): NumPy array to convert.

    Returns:
        Array: The same data as a ``jax.Array`` (built with ``jnp.asarray``).
    '''
    converted: Array = jnp.asarray(arr)
    return converted
[docs]
def pt2jax(arr: torch.Tensor) -> Array:
    '''
    Converts a PyTorch tensor into a JAX array.

    The conversion is routed through NumPy: ``pt2np`` first turns the
    tensor into a NumPy array, and ``np2jax`` then turns that array into a
    ``jax.Array``. Both steps run with autograd disabled.

    Args:
        arr (torch.Tensor): PyTorch tensor to convert.

    Returns:
        Array: Converted jax.Array.
    '''
    with torch.no_grad():
        host_array = pt2np(arr)
        jax_array = np2jax(host_array)
    return jax_array
[docs]
def jax2np(arr: Array) -> np.ndarray:
    '''
    Converts a JAX array into an independent NumPy array.

    The data is always copied, so the result owns its own memory and is
    writable even when the intermediate NumPy view of ``arr`` is not.

    Args:
        arr (Array): JAX array to convert.

    Returns:
        np.ndarray: A fresh NumPy copy of ``arr``.
    '''
    host_view = np.asarray(arr)
    return host_view.copy()
[docs]
def jax2np_noncopy(arr: Array) -> np.ndarray:
    '''
    Converts a JAX array into a NumPy array, avoiding a copy when possible.

    ``np.asarray`` is used with no explicit copy, so the result may be a
    view of the JAX array's host buffer. It may therefore be read-only.

    Args:
        arr (Array): JAX array to convert.

    Returns:
        np.ndarray: NumPy array holding (or viewing) the data of ``arr``.
    '''
    result = np.asarray(arr)
    return result
[docs]
def np2pt(arr: np.ndarray) -> torch.Tensor:
    '''
    Converts a NumPy array into a PyTorch tensor without copying.

    The returned tensor shares memory with ``arr`` (``torch.from_numpy``),
    so in-place changes to one are visible through the other. Arrays of
    JAX's bfloat16 NumPy dtype, which ``torch.from_numpy`` rejects, are
    reinterpreted bit-for-bit as ``torch.bfloat16``; they are still not copied.

    Args:
        arr (np.ndarray): NumPy array to convert.

    Returns:
        torch.Tensor: Tensor backed by the same memory as ``arr``.
    '''
    if arr.dtype == np.dtype(jnp.bfloat16):
        # Route through int16 (same 16-bit width) so torch can wrap the
        # buffer, then reinterpret those bits as bfloat16 without a copy.
        return torch.from_numpy(arr.view(np.int16)).view(torch.bfloat16)
    return torch.from_numpy(arr)
[docs]
def jax2pt(arr: Array) -> torch.Tensor:
    '''
    Converts a JAX array into a PyTorch tensor.

    ``jax2np`` first copies the JAX data into a fresh NumPy array, and
    ``np2pt`` then wraps that copy. The resulting tensor therefore does not
    share memory with ``arr``.

    Args:
        arr (Array): JAX array to convert.

    Returns:
        torch.Tensor: Converted tensor.
    '''
    host_copy = jax2np(arr)
    return np2pt(host_copy)
[docs]
def jax2pt_noncopy(arr: Array) -> torch.Tensor:
    '''
    Converts a JAX array into a PyTorch tensor, avoiding copies when possible.

    ``jax2np_noncopy`` obtains a NumPy array that may view the JAX buffer
    directly, and ``np2pt`` wraps it without copying. If that NumPy view is
    not writable, PyTorch may emit a warning about non-writable arrays.

    Args:
        arr (Array): JAX array to convert.

    Returns:
        torch.Tensor: Converted tensor, possibly sharing memory with ``arr``.
    '''
    host_view = jax2np_noncopy(arr)
    return np2pt(host_view)