1. 生成类别矩阵如下
2. PyTorch 代码
import torch
import torch.nn as nn  # unused in this snippet; kept from the original
import torch.nn.functional as F  # unused in this snippet; kept from the original

torch.set_printoptions(precision=3, sci_mode=False)

if __name__ == "__main__":
    # The Kronecker product with an all-ones 2x2 block repeats every entry of
    # a_matrix as a 2x2 tile, giving one label value per window. The result is
    # float because b_matrix (torch.ones) is float.
    a_matrix = torch.arange(4).reshape(2, 2) + 1
    b_matrix = torch.ones((2, 2))
    print(f"a_matrix=\n{a_matrix}")
    print(f"b_matrix=\n{b_matrix}")
    c_matrix = torch.kron(input=a_matrix, other=b_matrix)
    print(f"c_matrix=\n{c_matrix}")

    # A 3x3 grid of labels expanded into 2x2 tiles gives a 6x6 matrix.
    # Cropping one row/column from every edge leaves a 4x4 matrix whose tile
    # boundaries are offset by one position.
    d_matrix = torch.arange(9).reshape(3, 3) + 1
    e_matrix = torch.ones((2, 2))
    f_matrix = torch.kron(input=d_matrix, other=e_matrix)
    print(f"d_matrix=\n{d_matrix}")
    print(f"e_matrix=\n{e_matrix}")
    print(f"f_matrix=\n{f_matrix}")
    g_matrix = f_matrix[1:-1, 1:-1]
    print(f"g_matrix=\n{g_matrix}")
输出结果：
a_matrix=
tensor([[1, 2],
        [3, 4]])
b_matrix=
tensor([[1., 1.],
        [1., 1.]])
c_matrix=
tensor([[1., 1., 2., 2.],
        [1., 1., 2., 2.],
        [3., 3., 4., 4.],
        [3., 3., 4., 4.]])
d_matrix=
tensor([[1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]])
e_matrix=
tensor([[1., 1.],
        [1., 1.]])
f_matrix=
tensor([[1., 1., 2., 2., 3., 3.],
        [1., 1., 2., 2., 3., 3.],
        [4., 4., 5., 5., 6., 6.],
        [4., 4., 5., 5., 6., 6.],
        [7., 7., 8., 8., 9., 9.],
        [7., 7., 8., 8., 9., 9.]])
g_matrix=
tensor([[1., 2., 2., 3.],
        [4., 5., 5., 6.],
        [4., 5., 5., 6.],
        [7., 8., 8., 9.]])
3. 循环移动矩阵
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

# Show tensors with three decimal places and never in scientific notation.
torch.set_printoptions(sci_mode=False, precision=3)
class WindowMatrix(object):
    """Label matrix for a regular (non-shifted) window partition.

    ``num_patch`` windows are arranged on a square grid, numbered from 1, and
    each label is expanded into a ``size`` x ``size`` tile via a Kronecker
    product with an all-ones block.
    """

    def __init__(self, num_patch=4, size=2):
        self.num_patch = num_patch  # number of windows; must be a perfect square
        self.size = size  # side length of each window tile
        # NOTE(review): width/height and this zero placeholder do not match the
        # shape of ``result``, which is (sqrt(num_patch) * size) square. They
        # are kept unchanged for backward compatibility.
        self.width = self.num_patch
        self.height = self.size * self.size
        self._result = torch.zeros((self.width, self.height))

    @staticmethod
    def _grid_side(num_patch):
        """Return the side length of the square window grid.

        Raises:
            ValueError: if ``num_patch`` is negative or not a perfect square.
        """
        # math.isqrt is exact, unlike int(math.sqrt(...)) on large values.
        side = math.isqrt(int(num_patch))
        if side * side != num_patch:
            raise ValueError(f"num_patch must be a perfect square, got {num_patch}")
        return side

    @property
    def result(self):
        """Return the label matrix and also store it in ``self._result``."""
        side = self._grid_side(self.num_patch)
        labels = torch.arange(self.num_patch).reshape(side, side) + 1
        tile = torch.ones(self.size, self.size)
        self._result = torch.kron(input=labels, other=tile)
        return self._result
class ShiftedWindowMatrix(object):
    """Label matrix for a window partition displaced by ``shift`` positions.

    A square grid of ``num_patch`` windows (labels starting at 1) is expanded
    into ``size`` x ``size`` tiles. Then ``shift`` rows and columns are cropped
    from every edge. With the defaults (9 windows, size 2, shift 1) the result
    is 4x4 with tile boundaries offset by one.
    """

    def __init__(self, num_patch=9, size=2, shift=1):
        self.num_patch = num_patch  # number of windows; must be a perfect square
        self.size = size  # side length of each window tile
        # Rows/columns cropped from each edge. The default 1 matches the
        # original hard-coded crop; pass e.g. size // 2 for larger windows.
        self.shift = shift
        # NOTE(review): width/height and this placeholder do not match the
        # shape of ``result``. They are kept for backward compatibility.
        self.width = self.num_patch
        self.height = self.size * self.size
        self._result = torch.zeros((self.width, self.height))

    @staticmethod
    def _grid_side(num_patch):
        """Return the side length of the square window grid.

        Raises:
            ValueError: if ``num_patch`` is negative or not a perfect square.
        """
        side = math.isqrt(int(num_patch))
        if side * side != num_patch:
            raise ValueError(f"num_patch must be a perfect square, got {num_patch}")
        return side

    @property
    def result(self):
        """Return the cropped label matrix and also store it in ``self._result``."""
        side = self._grid_side(self.num_patch)
        labels = torch.arange(self.num_patch).reshape(side, side) + 1
        tile = torch.ones(self.size, self.size)
        full = torch.kron(input=labels, other=tile)
        n = full.shape[0]
        # An explicit stop index keeps shift=0 meaningful ([0:-0] would be empty).
        lo, hi = self.shift, n - self.shift
        self._result = full[lo:hi, lo:hi]
        return self._result
class RollShiftedWindowMatrix(object):
    """Shifted-window label matrix after a cyclic roll by ``-shift``.

    Builds the cropped matrix described for ``ShiftedWindowMatrix`` and then
    rolls it by ``-shift`` along both axes with ``torch.roll``. The first
    ``shift`` rows and columns therefore wrap around to the bottom and right.
    """

    def __init__(self, num_patch=9, size=2, shift=1):
        self.num_patch = num_patch  # number of windows; must be a perfect square
        self.size = size  # side length of each window tile
        # Crop per edge and roll amount. The default 1 matches the original
        # hard-coded value.
        self.shift = shift
        # NOTE(review): width/height and this placeholder do not match the
        # shape of ``result``. They are kept for backward compatibility.
        self.width = self.num_patch
        self.height = self.size * self.size
        self._result = torch.zeros((self.width, self.height))

    @staticmethod
    def _grid_side(num_patch):
        """Return the side length of the square window grid.

        Raises:
            ValueError: if ``num_patch`` is negative or not a perfect square.
        """
        side = math.isqrt(int(num_patch))
        if side * side != num_patch:
            raise ValueError(f"num_patch must be a perfect square, got {num_patch}")
        return side

    @property
    def result(self):
        """Return the rolled label matrix and also store it in ``self._result``."""
        side = self._grid_side(self.num_patch)
        labels = torch.arange(self.num_patch).reshape(side, side) + 1
        tile = torch.ones(self.size, self.size)
        full = torch.kron(input=labels, other=tile)
        n = full.shape[0]
        lo, hi = self.shift, n - self.shift
        cropped = full[lo:hi, lo:hi]
        self._result = torch.roll(
            input=cropped, shifts=(-self.shift, -self.shift), dims=(-1, -2)
        )
        return self._result
class BackRollShiftedWindowMatrix(object):
    """Shifted-window label matrix rolled by ``-shift`` and then back by ``+shift``.

    Demonstrates that the reverse roll restores the cropped shifted-window
    matrix. When ``verbose`` is true, the intermediate rolled matrix is printed
    each time ``result`` is read.
    """

    def __init__(self, num_patch=9, size=2, shift=1, verbose=True):
        self.num_patch = num_patch  # number of windows; must be a perfect square
        self.size = size  # side length of each window tile
        # Crop per edge and roll amount. The default 1 matches the original
        # hard-coded value.
        self.shift = shift
        # Print the intermediate rolled matrix. Defaults to True to keep the
        # original output.
        self.verbose = verbose
        # NOTE(review): width/height and this placeholder do not match the
        # shape of ``result``. They are kept for backward compatibility.
        self.width = self.num_patch
        self.height = self.size * self.size
        self._result = torch.zeros((self.width, self.height))

    @staticmethod
    def _grid_side(num_patch):
        """Return the side length of the square window grid.

        Raises:
            ValueError: if ``num_patch`` is negative or not a perfect square.
        """
        side = math.isqrt(int(num_patch))
        if side * side != num_patch:
            raise ValueError(f"num_patch must be a perfect square, got {num_patch}")
        return side

    @property
    def result(self):
        """Return the back-rolled matrix and also store it in ``self._result``."""
        side = self._grid_side(self.num_patch)
        labels = torch.arange(self.num_patch).reshape(side, side) + 1
        tile = torch.ones(self.size, self.size)
        full = torch.kron(input=labels, other=tile)
        n = full.shape[0]
        lo, hi = self.shift, n - self.shift
        cropped = full[lo:hi, lo:hi]
        roll_result = torch.roll(
            input=cropped, shifts=(-self.shift, -self.shift), dims=(-1, -2)
        )
        if self.verbose:
            print(f"roll_result=\n{roll_result}")
        # Rolling by the opposite amount undoes the first roll.
        self._result = torch.roll(
            input=roll_result, shifts=(self.shift, self.shift), dims=(-1, -2)
        )
        return self._result
if __name__ == "__main__":
    # Printed labels are kept exactly as before so the output is unchanged.

    # Regular window partition (defaults: 4 windows of size 2).
    window_matrix = WindowMatrix()
    window_matrix_result = window_matrix.result
    print(f"my_window_matrix_result=\n{window_matrix_result}")

    # Windows offset by one position (cropped from a 3x3 grid of windows).
    shifted_window_matrix = ShiftedWindowMatrix()
    shifted_window_matrix_result = shifted_window_matrix.result
    print(f"shifed_window_matrix_result=\n{shifted_window_matrix_result}")

    # The shifted matrix cyclically rolled by -1 on both axes.
    roll_shifted_window_matrix = RollShiftedWindowMatrix()
    roll_shifted_window_matrix_result = roll_shifted_window_matrix.result
    print(f"roll_shifed_window_matrix_result=\n{roll_shifted_window_matrix_result}")

    # Rolling back by +1 restores the shifted matrix. Reading .result here also
    # prints the intermediate rolled matrix.
    back_roll_shifted_window_matrix = BackRollShiftedWindowMatrix()
    back_roll_shifted_window_matrix_result = back_roll_shifted_window_matrix.result
    print(f"back_roll_shifed_window_matrix_result=\n{back_roll_shifted_window_matrix_result}")
输出结果：
my_window_matrix_result=
tensor([[1., 1., 2., 2.],
        [1., 1., 2., 2.],
        [3., 3., 4., 4.],
        [3., 3., 4., 4.]])
shifed_window_matrix_result=
tensor([[1., 2., 2., 3.],
        [4., 5., 5., 6.],
        [4., 5., 5., 6.],
        [7., 8., 8., 9.]])
roll_shifed_window_matrix_result=
tensor([[5., 5., 6., 4.],
        [5., 5., 6., 4.],
        [8., 8., 9., 7.],
        [2., 2., 3., 1.]])
roll_result=
tensor([[5., 5., 6., 4.],
        [5., 5., 6., 4.],
        [8., 8., 9., 7.],
        [2., 2., 3., 1.]])
back_roll_shifed_window_matrix_result=
tensor([[1., 2., 2., 3.],
        [4., 5., 5., 6.],
        [4., 5., 5., 6.],
        [7., 8., 8., 9.]])
免责声明：如果侵犯了您的权益，请联系站长，我们会及时删除侵权内容，谢谢合作！更多信息请访问主页：qidao123.com：ToB企服之家，中国第一个企服评测及商务社交产业平台。