Jax(Random、Numpy)常用函数
目录Jax
vmap
Array
reshape
Random
PRNGKey
uniform
normal
split
choice
Numpy
expand_dims
linspace
jax.numpy.linalg
dot
matmul
arange
interp
tile
reshape
Jax
jit
jax.jit(fun, in_shardings=UnspecifiedValue, out_shardings=UnspecifiedValue, static_argnums=None, static_argnames=None, donate_argnums=None, donate_argnames=None, keep_unused=False, device=None, backend=None, inline=False, abstracted_axes=None)
注:jax.jit 是 JAX 中的一个装饰器,用于将 Python 函数编译为高效的机器代码,以提高运行速度。JIT(Just-In-Time)编译可以加速函数的执行,尤其是在循环或需要多次调用。
>>>jax.jit(lambda x,y : x + y)
<PjitFunction of <function <lambda> at 0x7ea7b402f130>>
>>>jax.jit(lambda x,y : x + y)(1,2) #process jitfunc -> lambda fun
Array(3, dtype=int32, weak_type=True)
>>>@jax.jit
def fun(x,y):
return x + y
>>>fun
<PjitFunction of <function fun at 0x7ea7b402f5b0>>
>>>fun(1,2)
Array(3, dtype=int32, weak_type=True) vmap
jax.vmap(fun, in_axes=0, out_axes=0, axis_name=None, axis_size=None, spmd_axis_name=None)
注:对函数举行向量化处理,通常用于批量处理数据,而不需要显式地编写循环,函数映射调用,区别于pmap,vmap单个设备(CPU或GPU)上处理批量数据,pmap在多个设备(GPU或TPU)上并行处理数据(分布式)
>>>f_xy = lambda x,y : x + y
>>>x = jax.numpy.array([,
])# shape (2, 2)
>>>y = jax.numpy.array([,
])# shape (2, 2)
# in this x and y array, axis 0 is row , axis 1 is col, ref shape index
# in x and y, axis -1 is shape[-1] , axis -2 is shape[-2]
>>>jax.vmap(f_xy,in_axes=(0,0))(x,y) # default out_axes = 0,row ouput
# x row + y row , need x row dim equal y row dim
Array([[ 6,8],
], dtype=int32)
>>>jax.vmap(f_xy,in_axes=(0,0),out_axes=1)(x,y) #show output by col
Array([[ 6,8],
], dtype=int32)
>>>jax.vmap(f_xy,in_axes=(0,1))(x,y)
# x row + y col , need x row's dim equal y col's dim
Array([[ 6,9],
[ 9, 12]], dtype=int32)
>>>jax.vmap(f_xy,in_axes=(0,1),out_axes=1)(x,y) #show output by col
Array([[ 6,9],
[ 9, 12]], dtype=int32)
>>>jax.vmap(f_xy,in_axes=(None,0))(x,y) #no vector x by row or col, x is block
# x block + y row vector, x shape (2,2) , y shape(2,2), need x row equal y row
# return shape(y_dim_2,x_dim_1,x_dim2)
Array([[[ 6,8],
[ 8, 10]],
[[ 8, 10],
]], dtype=int32) ref:Learning about JAX :axes in vmap()
Array
reshape
abstract Array.reshape(*args, order='C')
注:Array对象的实例方法,引用jax.numpy.reshape函数
Random
PRNGKey
jax.random.PRNGKey(seed, *, impl=None)#
注:创建一个 PRNG key,作为天生随机数的种子Seed
eg:
>>>jax.random.PRNGKey(0)
Array(, dtype=uint32) uniform
jax.random.uniform(key, shape=(), dtype=<class 'float'>, minval=0.0, maxval=1.0)
注:在给定的外形(shape)和数据类型(dtype)下,从 [minval, maxval) 区间内采样均匀分布的随机值
>>>k = jax.random.PRNGKey(0)
>>>jax.random.uniform(k,shape=(1,))
Array(, dtype=float32) normal
normal(key, shape=(), dtype=<class 'float'>)
注:在给定的外形shape和浮点数据类型dtype下,采样标准正态分布的随机值
>>>k = jax.random.PRNGKey(0)
>>>jax.random.normal(k,shape=(1,))
Array([-0.20584226], dtype=float32) split
jax.random.split(key, num=2)
注:用于天生伪随机数天生器(PRNG)状态的函数。它允许你从一个现有的 PRNG 状态中天生多个新的状态,从而实现随机数的可重复性和并行性。
>>>k = jax.random.PRNGKey(1)
>>>k1,k2 = jax.random.split(k)
>>>k1
Array(, dtype=uint32)
>>>k2
Array(, dtype=uint32) choice
jax.random.choice(key, a, shape=(), replace=True, p=None, axis=0)
注:从给定数组a中按shape天生随机样本,区别于numpy.random.choice函数。default choice one elem。
>>>k = jax.random.PRNGKey(0)
>>>a = jax.numpy.array()
>>>jax.random.choice(k,a,(10,)) # random no seq
Array(, dtype=int32)
>>>jax.random.choice(k,a,(2,5))
Array([,
], dtype=int32) Numpy
expand_dims
expand_dims(a, axis)
注:为数组a的维度axis增加1维度
>>>arr = jax.numpy.array()
>>>arr.shape
(3,)
>>>jax.numpy.expand_dims(arr,axis=0)
Array([], dtype=int32)
>>>jax.numpy.expand_dims(arr,axis=0).shape
(1, 3)
>>>jax.numpy.expand_dims(arr,axis=1)
Array([,
,
], dtype=int32)
>>>jax.numpy.expand_dims(arr,axis=1).shape
(3, 1) linspace
linspace(start: ArrayLike, stop: ArrayLike, num: int = 50, endpoint: bool = True, retstep: Literal = False, dtype: DTypeLike | None = None, axis: int = 0, *, device: xc.Device | Sharding | None = None) → Array
注:在给定区间内返回均匀隔断的数字
>>>jax.numpy.linspace(0,1,5)
Array(, dtype=float32) jax.numpy.linalg
jax.numpy.linalg 是 JAX 库中用于线性代数操作的模块,对应numpy.linalg库实现
cholesky
jax.numpy.linalg.cholesky(a, *, upper=False)
注:计算一个正定矩阵A的 Cholesky 分解,得到满足A=L@L.T等式的下三角或上三角矩阵L,@为Python1.5界说的矩阵乘运算(jax.numpy.matmul),L.T为L转置矩阵https://latex.csdn.net/eq?L%5E%7BT%7D。
>>> d = jax.numpy.array([,
])
>>>jax.numpy.linalg.cholesky(d)
Array([,
], dtype=float32)
>>>L = jax.numpy.linalg.cholesky(d)
>>>L@L.T
Array([,
], dtype=float32) eigvalsh
jax.numpy.linalg.eigvalsh(a, UPLO='L')
注:计算 Hermitian 对称矩阵的特征值。对于一个给定的方阵 A,其特征值 λ 和特征向量 v满足以下关系Av=λv。cholesky分解矩阵需满足特征值>0。
>>>jax.numpy.linalg.eigvalsh(jax.numpy.array([,
[-1,1]]))
Array(, dtype=float32) cond
jax.numpy.linalg.cond(x, p=None)
注:用于计算矩阵的条件数(condition number),这是衡量矩阵在数值计算中稳定性的重要指标。高条件数警示需要审慎对待矩阵的计算,尤其是在求解线性方程或举行其他数值计算时,如cholesky分解。
>>>jax.numpy.linalg.cond(jax.numpy.array([,
]))
Array(3., dtype=float32) allclose
jax.numpy.allclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False)
注:检查两个数组的元素是否在容差范围内近似相称,cholesky分解矩阵需满足对称性。
>>>A=jax.numpy.array([,
])
>>>jax.numpy.allclose(A,A.T)
Array(True, dtype=bool)
# A 为对称矩阵 dot
dot(a, b, *, precision=None, preferred_element_type=None)
注:用于计算两个数组的点积(dot product),对于一维数组,它计算的是向量的内积;对于二维数组(矩阵),它计算的是矩阵乘积;对于更高维度的数组,它执行的是逐元素的点积,并在最后一个轴上举行求和
[*]对于一维数组(向量):numpy.dot(a, b) 计算的是向量 a 和 b 的点积,结果是一个标量。
[*]对于二维数组(矩阵):numpy.dot(A, B) 计算的是矩阵 A 和 B 的乘积,其中 A 的列数必须与 B 的行数相称。结果是一个新的矩阵。
[*]对于更高维度的数组:numpy.dot() 可以举行更复杂的广播和求和运算,但通常用于计算张量积(tensor product)的某个维度上的和。
>>>jax.numpy.dot(jax.numpy.array(),2)
Array(, dtype=int32)
>>>jax.numpy.dot(jax.numpy.array(),jax.numpy.array())
Array(14, dtype=int32)
>>>jax.numpy.dot(jax.numpy.array([,
]),
jax.numpy.array())
Array(, dtype=int32)
>>>jax.numpy.dot(jax.numpy.array([,
]),
jax.numpy.array([,
]))
Array([[ 9, 12],
], dtype=int32)
>>>a = jax.numpy.zeros((1,3,2))
>>>b = jax.numpy.zeros((1,2,4))
>>>jax.numpy.dot(a,b).shape
(1, 3, 1, 4) #matmul ret (1,3,4) matmul
matmul(a, b, *, precision=None, preferred_element_type=None)#
注:于执行矩阵乘法,也称为 @ 运算符(在 Python 3.5+ 中引入),对于一维数组(向量),它计算的是内积(与 dot 相同);对于二维数组(矩阵),它计算的是矩阵乘积(与 dot 相同);对于更高维度的数组,它执行的是逐元素的矩阵乘法,并生存其他轴
[*]对于一维数组(向量):numpy.matmul(a, b) 通常不被界说为向量之间的运算,除非 a 是一个二维数组(表现多个向量)的单个行或列,而且 b 的外形与之兼容。
[*]对于二维数组(矩阵):numpy.matmul(A, B) 计算的是矩阵 A 和 B 的乘积,其中 A 的列数必须与 B 的行数相称。这与 numpy.dot() 对于二维数组的行为相同。
[*]对于更高维度的数组:numpy.matmul() 遵循爱因斯坦求和约定(Einstein summation convention)的特定规则,允许在差别维度的数组之间执行矩阵乘法。这包括批处理矩阵乘法,其中每个批次独立地举行乘法运算。
>>>jax.numpy.matmul(jax.numpy.array(),jax.numpy.array())
Array(14, dtype=int32)
>>>jax.numpy.matmul(jax.numpy.array([,
]),
jax.numpy.array())
Array(, dtype=int32)
>>>jax.numpy.matmul(jax.numpy.array([,
]),
jax.numpy.array([,
]))
Array([[ 9, 12],
], dtype=int32)
>>>a = jax.numpy.zeros((1,3,2))
>>>b = jax.numpy.zeros((1,2,4))
>>>jax.numpy.matmul(a,b).shape
(1, 3, 4) #dot ret (1,3,1,4) arange
jax.numpy.arange(start, stop=None, step=None, dtype=None, *, device=None)
注:default step 为1,在区间[start,stop)天生步长为1的数组,类似range函数
>>>jax.numpy.arange(0,10,1)
Array(, dtype=int32) interp
interp(x, xp, fp, left=None, right=None, period=None)
注:在xp点列表中线性插值x,线性插值满足https://latex.csdn.net/eq?y%3Dy_%7Bi%7D+%5Cfrac%7By_%7Bi+1%7D-y_%7Bi%7D%7D%7Bx_%7Bi+1%7D-x_%7Bi%7D%7D%28x-x_%7Bi%7D%29%2Cx%5Cepsilon%20%5Bx_%7Bi%7D%2Cx_%7Bi+1%7D%29,xi和xi+1表现xp数组相邻两点,插值x位于两点区间之间,xp点对于y值为fp,线性插值为保持符合fp = fun(xp)两点区间斜率的增量
>>>xp = jax.numpy.arange(0,10,1)
>>>fp = jax.numpy.array(range(0,10,1)) * 2
>>>x = jax.numpy.array()
>>>jax.numpy.interp(x,xp,fp)
Array(, dtype=float32) tile
jax.numpy.tile(A, reps)
注:将A数组按reps重复化天生新Array
a = jax.numpy.array()
>>>jax.numpy.tile(a,2)
Array(, dtype=int32)
>>>jax.numpy.tile(a,(2,))
Array(, dtype=int32)
>>>jax.numpy.tile(a,(1,1))
Array([], dtype=int32)
>>>jax.numpy.tile(a,(2,1)) # repeat axis 0 (row) by 2, repeat axis 1 (col) by 1
Array([,
], dtype=int32) reshape
jax.numpy.reshape(a, shape=None, order='C', *, newshape=Deprecated, copy=None)
注:从界说Array a的shape外形为shape元组(),支持-1,推断dim数值
>>>a = jax.numpy.array([,
])
>>>jax.numpy.reshape(a,6) # equal reshape(a,(6,))
Array(, dtype=int32)
>>>jax.numpy.reshape(a,-1) # equal reshape(a,6)-1 is inferred to be 3
Array(, dtype=int32)
>>>jax.numpy.reshape(a,(-1,2)) # equal reshape(a,(3,2)) , -1 is inferred to be 3
Array([,
,
], dtype=int32)
>>>jax.numpy.reshape(a,(1,-1)) # not (n,) inferred to 2 d
Array([], dtype=int32) meshgrid
jax.numpy.meshgrid(*xi, copy=True, sparse=False, indexing='xy')
注:创建坐标矩阵,将一维坐标向量xi(自变量x、y)转换为对应的二维坐标向量或矩阵,实用于计算网格点上的函数值(因变量z),默认indexing='xy'输出笛卡尔坐标(row为vector),indexing='ij'输出矩阵坐标(col为vector)
>>>x = jax.numpy.array()
>>>y = jax.numpy.array()
>>>jax.numpy.meshgrid(x,y) #default indexing='xy'
,
], dtype=int32),
Array([,
], dtype=int32)]
>>>jax.numpy.meshgrid(x,y,indexing='ij')
,
,
], dtype=int32),
Array([,
,
], dtype=int32)]
>>>xv,yv = jax.numpy.meshgrid(x,y,indexing='xy')
>>>xv
Array([,
], dtype=int32)
>>>yv
Array([,
], dtype=int32)
>>>xv.ravel()
Array(, dtype=int32)
>>>yv.ravel()
Array(, dtype=int32)
#Array.ravel return a view of array (no memory),flatten return a copy of array
自变量x shape(3,) 自变量y shape(2,),对应平面6个点, 对应值因变量z shape为(6,) 6个数值
,二维坐标可视化代码:
import jax
import matplotlib.pyplot as plt
x = jax.numpy.array()
y = jax.numpy.array()
xv,yv = jax.numpy.meshgrid(x,y,indexing='xy')
z = xv + yv
plt.scatter(xv.flatten(), yv.flatten(), c=z, cmap='viridis') #use xv , yv also show similar graph
plt.colorbar(label='u')
plt.xlim(0, 4)
plt.ylim(3, 6)
plt.xlabel('X axis')
plt.ylabel('Y axis')
plt.title('Grid Units Visualization')
plt.show() https://i-blog.csdnimg.cn/direct/672c574309e64780917afe82c2d324bf.png
尝试将点变多:
import jax
import matplotlib.pyplot as plt
x = jax.numpy.linspace(0,10,100)
y = jax.numpy.linspace(0,10,100)
xv,yv = jax.numpy.meshgrid(x,y,indexing='xy')
z = xv + yv
plt.scatter(xv.flatten(), yv.flatten(), c=z, cmap='viridis')
plt.colorbar(label='z')
plt.xlabel('X axis')
plt.ylabel('Y axis')
plt.title('Grid Units Visualization')
plt.show() https://i-blog.csdnimg.cn/direct/583df6258a7b4fa19e4e331d56195275.png
eye
jax.numpy.eye(N, M=None, k=0, dtype=None, *, device=None)
注:用于创建单位矩阵的函数。单位矩阵是一种方阵,其主对角线上的元素为 1,别的元素为 0。
>>>jax.numpy.eye(3)
Array([,
,
], dtype=float32)
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。
页:
[1]