kron积盘算mask类别矩阵

打印 上一主题 下一主题

主题 1043|帖子 1043|积分 3129

1. 生成类别矩阵如下


2. pytorch 代码

  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. torch.set_printoptions(precision=3, sci_mode=False)
  5. if __name__ == "__main__":
  6.     run_code = 0
  7.     a_matrix = torch.arange(4).reshape(2, 2) + 1
  8.     b_matrix = torch.ones((2, 2))
  9.     print(f"a_matrix=\n{a_matrix}")
  10.     print(f"b_matrix=\n{b_matrix}")
  11.     c_matrix = torch.kron(input=a_matrix, other=b_matrix)
  12.     print(f"c_matrix=\n{c_matrix}")
  13.     d_matrix = torch.arange(9).reshape(3, 3) + 1
  14.     e_matrix = torch.ones((2, 2))
  15.     f_matrix = torch.kron(input=d_matrix, other=e_matrix)
  16.     print(f"d_matrix=\n{d_matrix}")
  17.     print(f"e_matrix=\n{e_matrix}")
  18.     print(f"f_matrix=\n{f_matrix}")
  19.     g_matrix = f_matrix[1:-1, 1:-1]
  20.     print(f"g_matrix=\n{g_matrix}")
复制代码


  • 效果:
  1. a_matrix=
  2. tensor([[1, 2],
  3.         [3, 4]])
  4. b_matrix=
  5. tensor([[1., 1.],
  6.         [1., 1.]])
  7. c_matrix=
  8. tensor([[1., 1., 2., 2.],
  9.         [1., 1., 2., 2.],
  10.         [3., 3., 4., 4.],
  11.         [3., 3., 4., 4.]])
  12. d_matrix=
  13. tensor([[1, 2, 3],
  14.         [4, 5, 6],
  15.         [7, 8, 9]])
  16. e_matrix=
  17. tensor([[1., 1.],
  18.         [1., 1.]])
  19. f_matrix=
  20. tensor([[1., 1., 2., 2., 3., 3.],
  21.         [1., 1., 2., 2., 3., 3.],
  22.         [4., 4., 5., 5., 6., 6.],
  23.         [4., 4., 5., 5., 6., 6.],
  24.         [7., 7., 8., 8., 9., 9.],
  25.         [7., 7., 8., 8., 9., 9.]])
  26. g_matrix=
  27. tensor([[1., 2., 2., 3.],
  28.         [4., 5., 5., 6.],
  29.         [4., 5., 5., 6.],
  30.         [7., 8., 8., 9.]])
复制代码
3. 循环移动矩阵



  • excel 表现

  • pytorch 源码
  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. import math
  5. torch.set_printoptions(precision=3, sci_mode=False)
  6. class WindowMatrix(object):
  7.     def __init__(self, num_patch=4, size=2):
  8.         self.num_patch = num_patch
  9.         self.size = size
  10.         self.width = self.num_patch
  11.         self.height = self.size * self.size
  12.         self._result = torch.zeros((self.width, self.height))
  13.     @property
  14.     def result(self):
  15.         a_size = int(math.sqrt(self.num_patch))
  16.         a_matrix = torch.arange(self.num_patch).reshape(a_size, a_size) + 1
  17.         b_matrix = torch.ones(self.size, self.size)
  18.         self._result = torch.kron(input=a_matrix, other=b_matrix)
  19.         return self._result
  20. class ShiftedWindowMatrix(object):
  21.     def __init__(self, num_patch=9, size=2):
  22.         self.num_patch = num_patch
  23.         self.size = size
  24.         self.width = self.num_patch
  25.         self.height = self.size * self.size
  26.         self._result = torch.zeros((self.width, self.height))
  27.     @property
  28.     def result(self):
  29.         a_size = int(math.sqrt(self.num_patch))
  30.         a_matrix = torch.arange(self.num_patch).reshape(a_size, a_size) + 1
  31.         b_matrix = torch.ones(self.size, self.size)
  32.         my_result = torch.kron(input=a_matrix, other=b_matrix)
  33.         self._result = my_result[1:-1, 1:-1]
  34.         return self._result
  35. class RollShiftedWindowMatrix(object):
  36.     def __init__(self, num_patch=9, size=2):
  37.         self.num_patch = num_patch
  38.         self.size = size
  39.         self.width = self.num_patch
  40.         self.height = self.size * self.size
  41.         self._result = torch.zeros((self.width, self.height))
  42.     @property
  43.     def result(self):
  44.         a_size = int(math.sqrt(self.num_patch))
  45.         a_matrix = torch.arange(self.num_patch).reshape(a_size, a_size) + 1
  46.         b_matrix = torch.ones(self.size, self.size)
  47.         my_result = torch.kron(input=a_matrix, other=b_matrix)
  48.         my_result = my_result[1:-1, 1:-1]
  49.         roll_result = torch.roll(input=my_result, shifts=(-1, -1), dims=(-1, -2))
  50.         self._result = roll_result
  51.         return self._result
  52. class BackRollShiftedWindowMatrix(object):
  53.     def __init__(self, num_patch=9, size=2):
  54.         self.num_patch = num_patch
  55.         self.size = size
  56.         self.width = self.num_patch
  57.         self.height = self.size * self.size
  58.         self._result = torch.zeros((self.width, self.height))
  59.     @property
  60.     def result(self):
  61.         a_size = int(math.sqrt(self.num_patch))
  62.         a_matrix = torch.arange(self.num_patch).reshape(a_size, a_size) + 1
  63.         b_matrix = torch.ones(self.size, self.size)
  64.         my_result = torch.kron(input=a_matrix, other=b_matrix)
  65.         my_result = my_result[1:-1, 1:-1]
  66.         roll_result = torch.roll(input=my_result, shifts=(-1, -1), dims=(-1, -2))
  67.         print(f"roll_result=\n{roll_result}")
  68.         roll_result = torch.roll(input=roll_result, shifts=(1, 1), dims=(-1, -2))
  69.         self._result = roll_result
  70.         return self._result
  71. if __name__ == "__main__":
  72.     run_code = 0
  73.     my_window_matrix = WindowMatrix()
  74.     my_window_matrix_result = my_window_matrix.result
  75.     print(f"my_window_matrix_result=\n{my_window_matrix_result}")
  76.     shifted_window_matrix = ShiftedWindowMatrix()
  77.     shifed_window_matrix_result = shifted_window_matrix.result
  78.     print(f"shifed_window_matrix_result=\n{shifed_window_matrix_result}")
  79.     roll_shifted_window_matrix = RollShiftedWindowMatrix()
  80.     roll_shifed_window_matrix_result = roll_shifted_window_matrix.result
  81.     print(f"roll_shifed_window_matrix_result=\n{roll_shifed_window_matrix_result}")
  82.     Back_roll_shifted_window_matrix = BackRollShiftedWindowMatrix()
  83.     back_roll_shifed_window_matrix_result = Back_roll_shifted_window_matrix.result
  84.     print(f"back_roll_shifed_window_matrix_result=\n{back_roll_shifed_window_matrix_result}")
复制代码


  • 效果:
  1. my_window_matrix_result=
  2. tensor([[1., 1., 2., 2.],
  3.         [1., 1., 2., 2.],
  4.         [3., 3., 4., 4.],
  5.         [3., 3., 4., 4.]])
  6. shifed_window_matrix_result=
  7. tensor([[1., 2., 2., 3.],
  8.         [4., 5., 5., 6.],
  9.         [4., 5., 5., 6.],
  10.         [7., 8., 8., 9.]])
  11. roll_shifed_window_matrix_result=
  12. tensor([[5., 5., 6., 4.],
  13.         [5., 5., 6., 4.],
  14.         [8., 8., 9., 7.],
  15.         [2., 2., 3., 1.]])
  16. roll_result=
  17. tensor([[5., 5., 6., 4.],
  18.         [5., 5., 6., 4.],
  19.         [8., 8., 9., 7.],
  20.         [2., 2., 3., 1.]])
  21. back_roll_shifed_window_matrix_result=
  22. tensor([[1., 2., 2., 3.],
  23.         [4., 5., 5., 6.],
  24.         [4., 5., 5., 6.],
  25.         [7., 8., 8., 9.]])
复制代码
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。

本帖子中包含更多资源

您需要 登录 才可以下载或查看,没有账号?立即注册

x
回复

使用道具 举报

0 个回复

倒序浏览

快速回复

您需要登录后才可以回帖 登录 or 立即注册

本版积分规则

反转基因福娃

论坛元老
这个人很懒什么都没写!
快速回复 返回顶部 返回列表