
Author: 涛声依旧在    Time: 2024-7-16 02:55
Title: samout structure further optimized, convergence further accelerated
Code

    import torch
    import numpy as np


    class MaxState(torch.nn.Module):
        def __init__(self, hidden_dim, heads, win):
            super(MaxState, self).__init__()

            assert hidden_dim % heads == 0, "Hidden size must be divisible by the number of heads."

            self.head_size = hidden_dim // heads
            self.head = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)
            self.state = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)
            self.head_num = heads
            self.win = win
            self.hidden = hidden_dim
            # upper-triangular mask that enforces causality inside each window
            self.mask = torch.triu(torch.ones([win, win])).to("cuda")
            self.layer_nor = torch.nn.LayerNorm(hidden_dim)

        def forward(self, input_data, state=None):
            # self.head.to("cuda")
            b, s, k, h, w = input_data.shape[0], input_data.shape[1], self.head_num, self.head_size, self.win

            window = torch.ones([1, w]).to("cuda")

            out = self.head(input_data)

            # tile every feature w times: [b, s, hidden] -> [b, s, hidden, w]
            out = out.unsqueeze(-1) @ window
            # reorder to [b, hidden, s, w]
            out = out.permute([0, 2, 1, 3])

            one_list = []
            if state is None:
                # -inf is the identity element for max, so the first window is unaffected
                state = torch.ones([out.shape[0], out.shape[1], 1, 1]) * float("-inf")
                state = state.to("cuda")

            # walk the sequence window by window, carrying the max state forward
            for i in range(0, s, w):
                state.reshape([state.shape[0], -1])  # note: result is not assigned, this line has no effect
                j = w + i
                one = out[:, :, i:j]
                _, _, r, c = one.shape
                if r != self.win:
                    one = torch.where(self.mask[:r, :] == 1, one, torch.Tensor([-float('inf')]).to("cuda"))
                else:
                    one = torch.where(self.mask == 1, one, torch.Tensor([-float('inf')]).to("cuda"))

                if i == 0:
                    one = torch.concat([one, state @ window], axis=2)
                    state, _ = torch.max(one, axis=2, keepdim=True)
                else:
                    state1, _ = torch.max(one, axis=2, keepdim=True)
                    # state = torch.sin(self.state(state1.reshape([state1.shape[0], -1]))*state.reshape([state.shape[0], -1]))
                    state1 = self.state(state1.permute([0, 3, 1, 2]).reshape([state1.shape[0], -1, state1.shape[1]]))
                    state = state1.permute([0, 2, 1]).unsqueeze(-2) + state
                    # state = state.reshape(state1.shape)
                    one = torch.concat([one, state], axis=2)
                    state, _ = torch.max(one, axis=2, keepdim=True)

                one = state.reshape([b, k, h, w])
                state = state[..., -1:]
                if r != self.win:
                    one = one[..., :r]
                one = one.permute([0, 3, 1, 2])
                one_list.append(one)

            out = torch.concat(one_list, 1)
            out = out.reshape([b, s, -1])

            return out, state


    class FeedForward(torch.nn.Module):
        def __init__(self, hidden_size):
            super(FeedForward, self).__init__()

            self.ffn1 = torch.nn.Linear(hidden_size, hidden_size * 2)
            self.ffn2 = torch.nn.Linear(hidden_size * 2, hidden_size)
            self.gate = torch.nn.Linear(hidden_size, hidden_size * 2)
            self.relu = torch.nn.ReLU()

        def forward(self, x):
            x1 = self.ffn1(x)
            # ReLU-gated feed-forward: the gate branch multiplies the main branch element-wise
            x2 = self.relu(self.gate(x))
            x = x1 * x2
            x = self.ffn2(x)

            return x


    class DecoderLayer(torch.nn.Module):
        def __init__(self, hidden_size, num_heads):
            super(DecoderLayer, self).__init__()
            # self.self_attention = MaskMultiHeadAttention(hidden_size, num_heads)
            self.self_attention = MaxState(hidden_size, num_heads, 8)
            self.ffn = FeedForward(hidden_size)
            self.layer_norm = torch.nn.LayerNorm(hidden_size)

        def forward(self, x, state=None, seq_len=None):
            x1, state = self.self_attention(x, state)
            x = self.layer_norm(self.ffn(x1) + x)  # Feed-Forward with residual connection

            return x, state


    class SamOut(torch.nn.Module):
        def __init__(self, voc_size, hidden_size, num_heads, num_layers):
            super(SamOut, self).__init__()

            self.em = torch.nn.Embedding(voc_size, hidden_size, padding_idx=3)
            self.pos = torch.nn.Embedding(1024, hidden_size)

            self.decoder_layers = torch.nn.ModuleList([DecoderLayer(hidden_size, num_heads) for _ in range(num_layers)])
            self.head = torch.nn.Linear(hidden_size, voc_size)
            self.head_state = torch.nn.Linear(hidden_size, num_layers)

            self.layer_nor = torch.nn.LayerNorm(hidden_size)
            self.down = torch.nn.ModuleList([torch.nn.Linear(2 * hidden_size, hidden_size) for _ in range(num_layers)])

        def forward(self, x, state=None, seq_len=None):
            x = self.em(x)

            # torch.arange replaces the deprecated torch.range; indices are kept on the GPU
            if x.shape[1] >= 1024:
                pos = self.pos(torch.arange(0, x.shape[1]).to("cuda") // 1024).unsqueeze(0)
                pos = self.pos(torch.arange(0, x.shape[1]).to("cuda") % 1024).unsqueeze(0) + pos
            else:
                pos = self.pos(torch.arange(0, x.shape[1]).to("cuda")).unsqueeze(0)

            if state is None:
                state = [None] * len(self.decoder_layers)

            i = 0
            for decoder_layer in self.decoder_layers:
                x1, state[i] = decoder_layer(self.down[i](torch.concat([torch.zeros([x.shape[0], 1, 1]).to("cuda") + pos, x], -1)), state[i])
                x = x1 + x
                i += 1

            state_data = self.head_state((torch.concat(state, -1).squeeze(-2)).permute([0, 2, 1]))

            return self.head(x), state, state_data


    if __name__ == '__main__':
        # the module hard-codes "cuda", so both the model and the input must live on the GPU
        net = SamOut(235, 256, 16, 4).to("cuda")
        net(torch.randint(0, 200, [2, 3000]).to("cuda"))
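The post itself does not include a training loop. Below is a minimal next-token training sketch that uses the SamOut class defined above; the optimizer, learning rate, batch shape and random data are my own placeholders, and a CUDA device is assumed because the model hard-codes "cuda".

    import torch

    # minimal next-token training sketch (placeholder data and hyper-parameters)
    net = SamOut(voc_size=235, hidden_size=256, num_heads=16, num_layers=4).to("cuda")
    optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
    loss_fn = torch.nn.CrossEntropyLoss(ignore_index=3)  # 3 is the padding index used by the embedding

    for step in range(10):
        tokens = torch.randint(0, 235, [2, 512]).to("cuda")  # random stand-in for real token ids
        logits, _, _ = net(tokens[:, :-1])                   # predict the next token from the prefix
        loss = loss_fn(logits.reshape(-1, 235), tokens[:, 1:].reshape(-1))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print(step, loss.item())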
Explanation

This code defines a PyTorch-based neural network model that consists of custom decoder layers and an output head for processing sequence data. Below is a step-by-step walkthrough of the code:
    import torch
    import numpy as np

Import PyTorch; numpy is imported here but is not actually used anywhere in the script.

    class MaxState(torch.nn.Module):
        def __init__(self, hidden_dim, heads, win):
            super(MaxState, self).__init__()

MaxState is the custom module that takes the place of standard self-attention. Its constructor receives the hidden dimension, the number of heads, and the window size.

            assert hidden_dim % heads == 0, "Hidden size must be divisible by the number of heads."

The hidden dimension must be divisible by the number of heads so that every head gets an equally sized slice of the features.

            self.head_size = hidden_dim // heads
            self.head = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)
            self.state = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)
            self.head_num = heads
            self.win = win
            self.hidden = hidden_dim
            self.mask = torch.triu(torch.ones([win, win])).to("cuda")
            self.layer_nor = torch.nn.LayerNorm(hidden_dim)

The constructor creates two bias-free linear projections (head and state), stores the per-head size, head count and window size, builds an upper-triangular [win, win] mask that enforces causality inside each window, and adds a LayerNorm.
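To make the masking concrete, here is what the upper-triangular mask looks like for a small window (an illustration with win=4, not part of the original code):

    import torch

    print(torch.triu(torch.ones([4, 4])))
    # tensor([[1., 1., 1., 1.],
    #         [0., 1., 1., 1.],
    #         [0., 0., 1., 1.],
    #         [0., 0., 0., 1.]])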

        def forward(self, input_data, state=None):
            # self.head.to("cuda")
            b, s, k, h, w = input_data.shape[0], input_data.shape[1], self.head_num, self.head_size, self.win

The forward pass unpacks the batch size b and sequence length s from the input, and the head count k, head size h and window size w from the module attributes.

            window = torch.ones([1, w]).to("cuda")

A row vector of w ones; it is used below as the right operand of a matrix product that broadcasts every feature across the window dimension.

            out = self.head(input_data)

The input is projected by the head linear layer.

            out = out.unsqueeze(-1) @ window

The outer product with the ones vector tiles each feature w times, turning [b, s, hidden] into [b, s, hidden, w].

            out = out.permute([0, 2, 1, 3])

The axes are reordered to [b, hidden, s, w] so that the windowed maximum below can be taken over chunks of the sequence axis.
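A quick standalone shape check of these two steps, using small dummy sizes of my own choosing rather than the model's real dimensions:

    import torch

    x = torch.randn(2, 6, 4)                   # [batch, seq, hidden]
    window = torch.ones(1, 3)                  # [1, win]
    tiled = x.unsqueeze(-1) @ window           # outer product tiles each feature win times
    print(tiled.shape)                         # torch.Size([2, 6, 4, 3])
    print(tiled.permute([0, 2, 1, 3]).shape)   # torch.Size([2, 4, 6, 3])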

            one_list = []
            if state is None:
                state = torch.ones([out.shape[0], out.shape[1], 1, 1]) * float("-inf")
                state = state.to("cuda")

one_list collects the per-window outputs. When no state is passed in, it is initialized to -inf, the identity element for the max operation, so the very first window is not affected by the carried state.

            for i in range(0, s, w):
                # ... (intermediate code omitted)

The loop walks the sequence in steps of w. Each window is masked with the upper-triangular mask (future positions are set to -inf), a maximum is taken over the window, the carried state is mixed in (through the state projection for every window after the first), the result is appended to one_list, and the last column becomes the new state.
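Ignoring the learned state projection, the core operation of this loop is a causal cumulative maximum over the sequence. A minimal sketch of that idea (my simplification, not the exact loop above):

    import torch

    x = torch.randn(2, 10, 8)             # [batch, seq, hidden]
    cum_max, _ = torch.cummax(x, dim=1)   # position t sees the max over positions 0..t
    print(cum_max.shape)                  # torch.Size([2, 10, 8])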

            return out, state

The concatenated window outputs are reshaped back to [b, s, hidden] and returned together with the final state.

    if __name__ == '__main__':
        net = SamOut(235, 256, 16, 4).to("cuda")
        net(torch.randint(0, 200, [2, 3000]).to("cuda"))

The demo builds a SamOut with a vocabulary of 235 tokens, hidden size 256, 16 heads and 4 decoder layers, moves it to the GPU (the code hard-codes "cuda"), and pushes a random token batch of shape [2, 3000] through it.
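Continuing from the demo above, the forward call returns a (logits, state, state_data) tuple; the logits shape follows directly from the configuration:

    logits, state, state_data = net(torch.randint(0, 200, [2, 3000]).to("cuda"))
    print(logits.shape)  # torch.Size([2, 3000, 235]) -> [batch, seq_len, voc_size]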





