Jax(Random、Numpy)常用函数

打印 上一主题 下一主题

主题 701|帖子 701|积分 2103

目录
Jax
 vmap
Array
reshape
Random
PRNGKey
uniform
normal
split
 choice
Numpy
expand_dims
linspace
jax.numpy.linalg[pkg]
dot
matmul
arange
interp 
tile
reshape


Jax

jit

jax.jit(funin_shardings=UnspecifiedValueout_shardings=UnspecifiedValuestatic_argnums=Nonestatic_argnames=Nonedonate_argnums=Nonedonate_argnames=Nonekeep_unused=Falsedevice=Nonebackend=Noneinline=Falseabstracted_axes=None)[source]
注:jax.jit 是 JAX 中的一个装饰器,用于将 Python 函数编译为高效的机器代码,以提高运行速度。JIT(Just-In-Time)编译可以加速函数的执行,尤其是在循环或需要多次调用。
  1. >>>jax.jit(lambda x,y : x + y)
  2. <PjitFunction of <function <lambda> at 0x7ea7b402f130>>
  3. >>>jax.jit(lambda x,y : x + y)(1,2) #process jitfunc -> lambda fun
  4. Array(3, dtype=int32, weak_type=True)
  5. >>>@jax.jit
  6.    def fun(x,y):
  7.         return x + y
  8. >>>fun
  9. <PjitFunction of <function fun at 0x7ea7b402f5b0>>
  10. >>>fun(1,2)
  11. Array(3, dtype=int32, weak_type=True)
复制代码
 vmap

jax.vmap(funin_axes=0out_axes=0axis_name=Noneaxis_size=Nonespmd_axis_name=None)[source]
注:对函数举行向量化处理,通常用于批量处理数据,而不需要显式地编写循环,函数映射调用,区别于pmap,vmap单个设备(CPU或GPU)上处理批量数据,pmap在多个设备(GPU或TPU)上并行处理数据(分布式)
  1. >>>f_xy = lambda x,y : x + y
  2. >>>x = jax.numpy.array([[1, 2],
  3.                         [3, 4]])  # shape (2, 2)
  4. >>>y = jax.numpy.array([[5, 6],
  5.                         [7, 8]])  # shape (2, 2)
  6. # in this x and y array, axis 0 is row , axis 1 is col, ref shape index
  7. # in x and y, axis -1 is shape[-1] , axis -2 is shape[-2]
  8. >>>jax.vmap(f_xy,in_axes=(0,0))(x,y)      # default out_axes = 0,row ouput
  9. # x row + y row , need x row dim equal y row dim
  10. Array([[ 6,  8],
  11.        [10, 12]], dtype=int32)
  12. >>>jax.vmap(f_xy,in_axes=(0,0),out_axes=1)(x,y) #show output by col
  13. Array([[ 6,  8],
  14.        [10, 12]], dtype=int32)
  15. >>>jax.vmap(f_xy,in_axes=(0,1))(x,y)
  16. # x row + y col , need x row's dim equal y col's dim
  17. Array([[ 6,  9],
  18.        [ 9, 12]], dtype=int32)
  19. >>>jax.vmap(f_xy,in_axes=(0,1),out_axes=1)(x,y) #show output by col
  20. Array([[ 6,  9],
  21.        [ 9, 12]], dtype=int32)
  22. >>>jax.vmap(f_xy,in_axes=(None,0))(x,y) #no vector x by row or col, x is block
  23. # x block + y row vector, x shape (2,2) , y shape(2,2), need x row equal y row
  24. # return shape(y_dim_2,x_dim_1,x_dim2)
  25. Array([[[ 6,  8],
  26.         [ 8, 10]],
  27.        [[ 8, 10],
  28.         [10, 12]]], dtype=int32)
复制代码
refearning about JAX :axes in vmap()
Array

reshape

abstract Array.reshape(*argsorder='C')[source]
注:Array对象的实例方法,引用jax.numpy.reshape函数
Random

PRNGKey

jax.random.PRNGKey(seed*impl=None)[source]#
注:创建一个 PRNG key,作为天生随机数的种子Seed
eg:       
  1. >>>jax.random.PRNGKey(0)
  2. Array([0, 0], dtype=uint32)
复制代码
uniform

jax.random.uniform(keyshape=()dtype=<class 'float'>minval=0.0maxval=1.0)[source]
注:在给定的外形(shape)和数据类型(dtype)下,从 [minval, maxval) 区间内采样均匀分布的随机值
  1. >>>k = jax.random.PRNGKey(0)
  2. >>>jax.random.uniform(k,shape=(1,))
  3. Array([0.41845703], dtype=float32)
复制代码
normal

normal(keyshape=()dtype=<class 'float'>)[source]
注:在给定的外形shape和浮点数据类型dtype下,采样标准正态分布的随机值
  1. >>>k = jax.random.PRNGKey(0)
  2. >>>jax.random.normal(k,shape=(1,))
  3. Array([-0.20584226], dtype=float32)
复制代码
split

jax.random.split(keynum=2)[source]
注:用于天生伪随机数天生器(PRNG)状态的函数。它允许你从一个现有的 PRNG 状态中天生多个新的状态,从而实现随机数的可重复性和并行性。 
  1. >>>k = jax.random.PRNGKey(1)
  2. >>>k1,k2 = jax.random.split(k)
  3. >>>k1
  4. Array([2441914641, 1384938218], dtype=uint32)
  5. >>>k2
  6. Array([3819641963, 2025898573], dtype=uint32)
复制代码
 choice

jax.random.choice(keyashape=()replace=Truep=Noneaxis=0)[source]
注:从给定数组a中按shape天生随机样本,区别于numpy.random.choice函数。default choice one elem。
  1. >>>k = jax.random.PRNGKey(0)
  2. >>>a = jax.numpy.array([1,2,3,4,5,6,7,8,9,0])
  3. >>>jax.random.choice(k,a,(10,)) # random no seq
  4. Array([9, 6, 8, 7, 8, 4, 1, 2, 3, 3], dtype=int32)
  5. >>>jax.random.choice(k,a,(2,5))
  6. Array([[9, 6, 8, 7, 8],
  7.        [4, 1, 2, 3, 3]], dtype=int32)
复制代码
Numpy

expand_dims

expand_dims(aaxis)[source]
注:为数组a的维度axis增加1维度
  1. >>>arr = jax.numpy.array([1,2,3])
  2. >>>arr.shape
  3. (3,)
  4. >>>jax.numpy.expand_dims(arr,axis=0)
  5. Array([[1, 2, 3]], dtype=int32)
  6. >>>jax.numpy.expand_dims(arr,axis=0).shape
  7. (1, 3)
  8. >>>jax.numpy.expand_dims(arr,axis=1)
  9. Array([[1],
  10.        [2],
  11.        [3]], dtype=int32)
  12. >>>jax.numpy.expand_dims(arr,axis=1).shape
  13. (3, 1)
复制代码
linspace

linspace(start: ArrayLikestop: ArrayLikenum: int = 50endpoint: bool = Trueretstep: Literal[False] = Falsedtype: DTypeLike | None = Noneaxis: int = 0*device: xc.Device | Sharding | None = None) → Array[source]
注:在给定区间[start,stop]内返回均匀隔断的数字
  1. >>>jax.numpy.linspace(0,1,5)
  2. Array([0.  , 0.25, 0.5 , 0.75, 1.  ], dtype=float32)
复制代码
jax.numpy.linalg[pkg]

jax.numpy.linalg 是 JAX 库中用于线性代数操作的模块,对应numpy.linalg库实现
cholesky

        jax.numpy.linalg.cholesky(a*upper=False)[source]
        注:计算一个正定矩阵A的 Cholesky 分解,得到满足A=L@L.T等式的下三角或上三角矩阵L,@为Python1.5界说的矩阵乘运算(jax.numpy.matmul),L.T为L转置矩阵

  1. >>> d = jax.numpy.array([[2. , 1.],
  2.                          [1. , 2.]])
  3. >>>jax.numpy.linalg.cholesky(d)
  4. Array([[1.4142135 , 0.        ],
  5.        [0.70710677, 1.2247449 ]], dtype=float32)
  6. >>>L = jax.numpy.linalg.cholesky(d)
  7. >>>L@L.T
  8. Array([[1.9999999 , 0.99999994],
  9.        [0.99999994, 2.        ]], dtype=float32)
复制代码
eigvalsh

jax.numpy.linalg.eigvalsh(aUPLO='L')[source]
注:计算 Hermitian 对称矩阵的特征值。对于一个给定的方阵 A,其特征值 λ 和特征向量 v满足以下关系Av=λv。cholesky分解矩阵需满足特征值>0。
  1. >>>jax.numpy.linalg.eigvalsh(jax.numpy.array([[1,-1],
  2.                                               [-1,1]]))
  3. Array([0., 2.], dtype=float32)
复制代码
 cond

jax.numpy.linalg.cond(xp=None)[source]
注:用于计算矩阵的条件数(condition number),这是衡量矩阵在数值计算中稳定性的重要指标。高条件数警示需要审慎对待矩阵的计算,尤其是在求解线性方程或举行其他数值计算时,如cholesky分解。
  1. >>>jax.numpy.linalg.cond(jax.numpy.array([[1,2],
  2.                                           [2,1]]))
  3. Array(3., dtype=float32)
复制代码
allclose

jax.numpy.allclose(abrtol=1e-05atol=1e-08equal_nan=False)[source]
注:检查两个数组的元素是否在容差范围内近似相称,cholesky分解矩阵需满足对称性。
  1. >>>A=jax.numpy.array([[4, 2],
  2.                       [2, 3]])
  3. >>>jax.numpy.allclose(A,A.T)
  4. Array(True, dtype=bool)
  5. # A 为对称矩阵
复制代码
dot

dot(ab*precision=Nonepreferred_element_type=None)[source]
注:用于计算两个数组的点积(dot product),对于一维数组,它计算的是向量的内积;对于二维数组(矩阵),它计算的是矩阵乘积;对于更高维度的数组,它执行的是逐元素的点积,并在最后一个轴上举行求和
   

  • 对于一维数组(向量):numpy.dot(a, b) 计算的是向量 a 和 b 的点积,结果是一个标量。
  • 对于二维数组(矩阵):numpy.dot(A, B) 计算的是矩阵 A 和 B 的乘积,其中 A 的列数必须与 B 的行数相称。结果是一个新的矩阵。
  • 对于更高维度的数组:numpy.dot() 可以举行更复杂的广播和求和运算,但通常用于计算张量积(tensor product)的某个维度上的和。
  1. >>>jax.numpy.dot(jax.numpy.array([1,2,3]),2)
  2. Array([2, 4, 6], dtype=int32)
  3. >>>jax.numpy.dot(jax.numpy.array([1,2,3]),jax.numpy.array([1,2,3]))
  4. Array(14, dtype=int32)
  5. >>>jax.numpy.dot(jax.numpy.array([[1,2,3],
  6.                                   [4,5,6]]),
  7.                   jax.numpy.array([1,2,3]))
  8. Array([14, 32], dtype=int32)
  9. >>>jax.numpy.dot(jax.numpy.array([[1,2],
  10.                                   [4,5]]),
  11.                  jax.numpy.array([[1,2],
  12.                                   [4,5]]))
  13. Array([[ 9, 12],
  14.        [24, 33]], dtype=int32)
  15. >>>a = jax.numpy.zeros((1,3,2))
  16. >>>b = jax.numpy.zeros((1,2,4))
  17. >>>jax.numpy.dot(a,b).shape
  18. (1, 3, 1, 4) #matmul ret (1,3,4)
复制代码
matmul

matmul(ab*precision=Nonepreferred_element_type=None)[source]#
注:于执行矩阵乘法,也称为 @ 运算符(在 Python 3.5+ 中引入),对于一维数组(向量),它计算的是内积(与 dot 相同);对于二维数组(矩阵),它计算的是矩阵乘积(与 dot 相同);对于更高维度的数组,它执行的是逐元素的矩阵乘法,并生存其他轴
   

  • 对于一维数组(向量):numpy.matmul(a, b) 通常不被界说为向量之间的运算,除非 a 是一个二维数组(表现多个向量)的单个行或列,而且 b 的外形与之兼容。
  • 对于二维数组(矩阵):numpy.matmul(A, B) 计算的是矩阵 A 和 B 的乘积,其中 A 的列数必须与 B 的行数相称。这与 numpy.dot() 对于二维数组的行为相同。
  • 对于更高维度的数组:numpy.matmul() 遵循爱因斯坦求和约定(Einstein summation convention)的特定规则,允许在差别维度的数组之间执行矩阵乘法。这包括批处理矩阵乘法,其中每个批次独立地举行乘法运算。
  1. >>>jax.numpy.matmul(jax.numpy.array([1,2,3]),jax.numpy.array([1,2,3]))
  2. Array(14, dtype=int32)
  3. >>>jax.numpy.matmul(jax.numpy.array([[1,2,3],
  4.                                      [4,5,6]]),
  5.                      jax.numpy.array([1,2,3]))
  6. Array([14, 32], dtype=int32)
  7. >>>jax.numpy.matmul(jax.numpy.array([[1,2],
  8.                                      [4,5]]),
  9.                     jax.numpy.array([[1,2],
  10.                                      [4,5]]))
  11. Array([[ 9, 12],
  12.        [24, 33]], dtype=int32)
  13. >>>a = jax.numpy.zeros((1,3,2))
  14. >>>b = jax.numpy.zeros((1,2,4))
  15. >>>jax.numpy.matmul(a,b).shape
  16. (1, 3, 4) #dot ret (1,3,1,4)
复制代码
arange

jax.numpy.arange(startstop=Nonestep=Nonedtype=None*device=None)[source]
注:default step 为1,在区间[start,stop)天生步长为1的数组,类似range函数
  1. >>>jax.numpy.arange(0,10,1)
  2. Array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int32)
复制代码
interp 

interp(xxpfpleft=Noneright=Noneperiod=None)[source]
注:在xp点列表中线性插值x,线性插值满足
,xi和xi+1表现xp数组相邻两点,插值x位于两点区间之间,xp点对于y值为fp,线性插值为保持符合fp = fun(xp)两点区间斜率的增量
  1. >>>xp = jax.numpy.arange(0,10,1)
  2. >>>fp = jax.numpy.array(range(0,10,1)) * 2
  3. >>>x = jax.numpy.array([1,2,3])
  4. >>>jax.numpy.interp(x,xp,fp)
  5. Array([2., 4., 6.], dtype=float32)
复制代码
tile

jax.numpy.tile(Areps)[source]
注:将A数组按reps重复化天生新Array
  1. a = jax.numpy.array([1,2,3])
  2. >>>jax.numpy.tile(a,2)
  3. Array([1, 2, 3, 1, 2, 3], dtype=int32)
  4. >>>jax.numpy.tile(a,(2,))
  5. Array([1, 2, 3, 1, 2, 3], dtype=int32)
  6. >>>jax.numpy.tile(a,(1,1))
  7. Array([[1, 2, 3]], dtype=int32)
  8. >>>jax.numpy.tile(a,(2,1)) # repeat axis 0 (row) by 2, repeat axis 1 (col) by 1
  9. Array([[1, 2, 3],
  10.        [1, 2, 3]], dtype=int32)
复制代码
reshape

jax.numpy.reshape(ashape=Noneorder='C'*newshape=Deprecatedcopy=None)[source]
注:从界说Array a的shape外形为shape元组(),支持-1,推断dim数值
  1. >>>a = jax.numpy.array([[1, 2, 3],
  2.                         [4, 5, 6]])
  3. >>>jax.numpy.reshape(a,6) # equal reshape(a,(6,))
  4. Array([1, 2, 3, 4, 5, 6], dtype=int32)
  5. >>>jax.numpy.reshape(a,-1) # equal reshape(a,6)  -1 is inferred to be 3
  6. Array([1, 2, 3, 4, 5, 6], dtype=int32)
  7. >>>jax.numpy.reshape(a,(-1,2)) # equal reshape(a,(3,2)) , -1 is inferred to be 3
  8. Array([[1, 2],
  9.        [3, 4],
  10.        [5, 6]], dtype=int32)
  11. >>>jax.numpy.reshape(a,(1,-1)) # not (n,) inferred to 2 d
  12. Array([[1, 2, 3, 4, 5, 6]], dtype=int32)
复制代码
meshgrid

jax.numpy.meshgrid(*xicopy=Truesparse=Falseindexing='xy')[source]
注:创建坐标矩阵,将一维坐标向量xi(自变量x、y)转换为对应的二维坐标向量或矩阵,实用于计算网格点上的函数值(因变量z),默认indexing='xy'输出笛卡尔坐标(row为vector),indexing='ij'输出矩阵坐标(col为vector)
  1. >>>x = jax.numpy.array([1,2,3])
  2. >>>y = jax.numpy.array([4,5])
  3. >>>jax.numpy.meshgrid(x,y) #default indexing='xy'
  4. [Array([[1, 2, 3],
  5.         [1, 2, 3]], dtype=int32),
  6. Array([[4, 4, 4],
  7.         [5, 5, 5]], dtype=int32)]
  8. >>>jax.numpy.meshgrid(x,y,indexing='ij')
  9. [Array([[1, 1],
  10.         [2, 2],
  11.         [3, 3]], dtype=int32),
  12. Array([[4, 5],
  13.         [4, 5],
  14.         [4, 5]], dtype=int32)]
  15. >>>xv,yv = jax.numpy.meshgrid(x,y,indexing='xy')
  16. >>>xv
  17. Array([[1, 2, 3],
  18.        [1, 2, 3]], dtype=int32)
  19. >>>yv
  20. Array([[4, 4, 4],
  21.        [5, 5, 5]], dtype=int32)
  22. >>>xv.ravel()
  23. Array([1, 2, 3, 1, 2, 3], dtype=int32)
  24. >>>yv.ravel()
  25. Array([4, 4, 4, 5, 5, 5], dtype=int32)
  26. #Array.ravel return a view of array (no memory),  flatten return a copy of array
复制代码
 自变量x shape(3,) 自变量y shape(2,),对应平面6个点, 对应值因变量z shape为(6,) 6个数值
,二维坐标可视化代码:
  1. import jax
  2. import matplotlib.pyplot as plt
  3. x = jax.numpy.array([1,2,3])
  4. y = jax.numpy.array([4,5])
  5. xv,yv = jax.numpy.meshgrid(x,y,indexing='xy')
  6. z = xv + yv
  7. plt.scatter(xv.flatten(), yv.flatten(), c=z, cmap='viridis') #use xv , yv also show similar graph
  8. plt.colorbar(label='u')
  9. plt.xlim(0, 4)
  10. plt.ylim(3, 6)
  11. plt.xlabel('X axis')
  12. plt.ylabel('Y axis')
  13. plt.title('Grid Units Visualization')
  14. plt.show()
复制代码

尝试将点变多:
  1. import jax
  2. import matplotlib.pyplot as plt
  3. x = jax.numpy.linspace(0,10,100)
  4. y = jax.numpy.linspace(0,10,100)
  5. xv,yv = jax.numpy.meshgrid(x,y,indexing='xy')
  6. z = xv + yv
  7. plt.scatter(xv.flatten(), yv.flatten(), c=z, cmap='viridis')
  8. plt.colorbar(label='z')
  9. plt.xlabel('X axis')
  10. plt.ylabel('Y axis')
  11. plt.title('Grid Units Visualization')
  12. plt.show()
复制代码

 eye

jax.numpy.eye(NM=Nonek=0dtype=None*device=None)[source]
注:用于创建单位矩阵的函数。单位矩阵是一种方阵,其主对角线上的元素为 1,别的元素为 0。
  1. >>>jax.numpy.eye(3)
  2. Array([[1., 0., 0.],
  3.        [0., 1., 0.],
  4.        [0., 0., 1.]], dtype=float32)
复制代码


免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。

本帖子中包含更多资源

您需要 登录 才可以下载或查看,没有账号?立即注册

x
回复

使用道具 举报

0 个回复

倒序浏览

快速回复

您需要登录后才可以回帖 登录 or 立即注册

本版积分规则

我可以不吃啊

金牌会员
这个人很懒什么都没写!

标签云

快速回复 返回顶部 返回列表