小白也能读懂的ConvLSTM!(开源pytorch代码)

打印 上一主题 下一主题

主题 523|帖子 523|积分 1569

仅需要网络源码的可以直接跳到末了即可
1. 算法简介与应用场景

ConvLSTM(卷积是非期影象网络)是一种结合了卷积神经网络(CNN)和是非期影象网络(LSTM)上风的深度学习模子。它紧张用于处置惩罚时空数据,特殊适用于需要考虑空间特征和时间依赖关系的任务,如气象猜测、视频分析、交通流量猜测等。
在气象猜测中,ConvLSTM可以根据过去的气象数据(如降水、温度等)猜测未来的气候情况。在视频分析中,它可以帮助辨认视频中的活动或事件,利用时间序列的连续性和空间信息进行更准确的分析。
2. 算法原理

2.1 LSTM底子

在先容ConvLSTM之前,先让我们来回归一下什么是是非期影象网络(LSTM)。LSTM是一种特殊的循环神经网络(RNN),它通过引入门控机制办理了传统RNN在长序列练习中面临的梯度消散和爆炸题目。LSTM单元紧张包罗三个门:输入门、忘记门和输出门。这些门控制着信息在单元中的流动,从而有效地记住或忘记信息。
LSTM的焦点公式如下:


  • 忘记门
                                                                    f                                     t                                              =                                  σ                                  (                                               W                                     f                                              ⋅                                  [                                               h                                                   t                                        −                                        1                                                           ,                                               x                                     t                                              ]                                  +                                               b                                     f                                              )                                          f_t = \sigma(W_f \cdot [h_{t-1}, x_t] + b_f)                           ft​=σ(Wf​⋅[ht−1​,xt​]+bf​)
  • 输入门
                                                                    i                                     t                                              =                                  σ                                  (                                               W                                     i                                              ⋅                                  [                                               h                                                   t                                        −                                        1                                                           ,                                               x                                     t                                              ]                                  +                                               b                                     i                                              )                                          i_t = \sigma(W_i \cdot [h_{t-1}, x_t] + b_i)                           it​=σ(Wi​⋅[ht−1​,xt​]+bi​)
                                                                                  C                                        ~                                                  t                                              =                                  tanh                                  ⁡                                  (                                               W                                     C                                              ⋅                                  [                                               h                                                   t                                        −                                        1                                                           ,                                               x                                     t                                              ]                                  +                                               b                                     C                                              )                                          \tilde{C}_t = \tanh(W_C \cdot [h_{t-1}, x_t] + b_C)                           C~t​=tanh(WC​⋅[ht−1​,xt​]+bC​)
  • 单元状态更新
                                                                    C                                     t                                              =                                               f                                     t                                              ∗                                               C                                                   t                                        −                                        1                                                           +                                               i                                     t                                              ∗                                                             C                                        ~                                                  t                                                      C_t = f_t \ast C_{t-1} + i_t \ast \tilde{C}_t                           Ct​=ft​∗Ct−1​+it​∗C~t​
  • 输出门
                                                                    o                                     t                                              =                                  σ                                  (                                               W                                     o                                              ⋅                                  [                                               h                                                   t                                        −                                        1                                                           ,                                               x                                     t                                              ]                                  +                                               b                                     o                                              )                                          o_t = \sigma(W_o \cdot [h_{t-1}, x_t] + b_o)                           ot​=σ(Wo​⋅[ht−1​,xt​]+bo​)
                                                                    h                                     t                                              =                                               o                                     t                                              ∗                                  tanh                                  ⁡                                  (                                               C                                     t                                              )                                          h_t = o_t \ast \tanh(C_t)                           ht​=ot​∗tanh(Ct​)
这里,                                             C                            t                                       C_t                  Ct​ 是当前的单元状态,                                             h                            t                                       h_t                  ht​ 是当前的隐藏状态,                                             x                            t                                       x_t                  xt​ 是当前的输入。
2.2 ConvLSTM原理

ConvLSTM在LSTM的底子上引入了卷积操作。传统的LSTM使用全连接层处置惩罚输入数据,而ConvLSTM则采用卷积层来处置惩罚空间数据。这样,ConvLSTM能够更好地捕捉输入数据中的空间特征。

2.2.1 ConvLSTM的布局

ConvLSTM的单元布局与LSTM非常相似,但是在每个门的盘算中使用了卷积操作。具体来说,ConvLSTM的每个门的公式可以表现为:
                                                    i                               t                                      =                            σ                            (                                       W                                           x                                  i                                                 ∗                                       X                               t                                      +                                       W                                           h                                  i                                                 ∗                                       H                                           t                                  −                                  1                                                 +                                       W                                           c                                  i                                                 ∘                                       C                                           t                                  −                                  1                                                 +                                       b                               i                                      )                                  i_t = \sigma (W_{xi} * X_t + W_{hi} * H_{t-1} + W_{ci} \circ C_{t-1} + b_i)                     it​=σ(Wxi​∗Xt​+Whi​∗Ht−1​+Wci​∘Ct−1​+bi​)
                                                    f                               t                                      =                            σ                            (                                       W                                           x                                  f                                                 ∗                                       X                               t                                      +                                       W                                           h                                  f                                                 ∗                                       H                                           t                                  −                                  1                                                 +                                       W                                           c                                  f                                                 ∘                                       C                                           t                                  −                                  1                                                 +                                       b                               f                                      )                                  f_t = \sigma (W_{xf} * X_t + W_{hf} * H_{t-1} + W_{cf} \circ C_{t-1} + b_f)                     ft​=σ(Wxf​∗Xt​+Whf​∗Ht−1​+Wcf​∘Ct−1​+bf​)
                                                    C                               t                                      =                                       f                               t                                      ∘                                       C                                           t                                  −                                  1                                                 +                                       i                               t                                      ∘                            t                            a                            n                            h                            (                                       W                                           x                                  c                                                 ∗                                       X                               t                                      +                                       W                                           h                                  c                                                 ∗                                       H                                           t                                  −                                  1                                                 +                                       b                               c                                      )                                  C_t = f_t \circ C_{t-1} + i_t \circ tanh(W_{xc} * X_t + W_{hc} * H_{t-1} + b_c)                     Ct​=ft​∘Ct−1​+it​∘tanh(Wxc​∗Xt​+Whc​∗Ht−1​+bc​)
                                                    o                               t                                      =                            σ                            (                                       W                                           x                                  o                                                 ∗                                       X                               t                                      +                                       W                                           h                                  o                                                 ∗                                       H                                           t                                  −                                  1                                                 +                                       W                                           c                                  o                                                 ∘                                       C                               t                                      +                                       b                               o                                      )                                  o_t = \sigma (W_{xo} * X_t + W_{ho} * H_{t-1} + W_{co} \circ C_t + b_o)                     ot​=σ(Wxo​∗Xt​+Who​∗Ht−1​+Wco​∘Ct​+bo​)
                                                    H                               t                                      =                                       o                               t                                      ∘                            t                            a                            n                            h                            (                                       C                               t                                      )                                  H_t = o_t \circ tanh(C_t)                     Ht​=ot​∘tanh(Ct​)
这里的 全部                                   W                              W                  W都是是卷积权重,                                   b                              b                  b是偏置项,                                   σ                              \sigma                  σ 是 sigmoid 函数,                                   tanh                         ⁡                              \tanh                  tanh 是双曲正切函数。。

2.2.2 卷积操作的长处


  • 空间特征提取:卷积操作能够有效提取输入数据中的空间特征。对于图像数据,卷积操作可以捕捉局部特征,例如边缘、纹理等,这在时间序列数据中同样适用。
  • 参数共享:卷积操作通过使用相同的卷积核在差异位置盘算特征,从而淘汰了模子参数的数量,降低了盘算复杂度。
  • 平移稳定性:卷积网络对输入数据的平移具有稳定性,即相同的特征在差异位置都会被检测到,这对于时空序列数据来说是非常紧张的。
2.3 LSTM与ConvLSTM的对比分析

特性LSTMConvLSTM输入类型一维序列三维数据(时序的图像数据)处置惩罚方式全连接层卷积操作空间特征捕捉较弱较强应用场景自然语言处置惩罚、时间序列猜测图像序列猜测、视频分析 2.4 ConvLSTM的应用

ConvLSTM在多个领域中表现出色,特殊得当处置惩罚具有时空特征的数据。以下是一些紧张的应用场景:


  • 气象猜测:利用历史气象数据(如温度、湿度、降水等)来猜测未来的气候情况。
  • 视频分析:对视频中的动态场景进行建模,辨认和猜测视频中的活动。
  • 交通流量猜测:基于历史交通数据猜测未来的交通流量,帮助城市交通管理。
  • 医学影像分析:分析医学影像序列(如CT、MRI)中的变化,辅助疾病诊断。
3. PyTorch代码

以下是ConvLSTM的完备代码,可以直接拿来用:
  1. import torch.nn as nn
  2. import torch
  3. class ConvLSTMCell(nn.Module):
  4.     def __init__(self, input_dim, hidden_dim, kernel_size, bias):
  5.         """
  6.         初始化卷积 LSTM 单元。
  7.         参数:
  8.         ----------
  9.         input_dim: int
  10.             输入张量的通道数。
  11.         hidden_dim: int
  12.             隐藏状态的通道数。
  13.         kernel_size: (int, int)
  14.             卷积核的大小。
  15.         bias: bool
  16.             是否添加偏置项。
  17.         """
  18.         super(ConvLSTMCell, self).__init__()
  19.         self.input_dim = input_dim
  20.         self.hidden_dim = hidden_dim
  21.         self.kernel_size = kernel_size
  22.         # 计算填充大小以保持输入和输出尺寸一致
  23.         self.padding = kernel_size[0] // 2, kernel_size[1] // 2
  24.         self.bias = bias
  25.         # 定义卷积层,输入是输入维度加上隐藏维度,输出是4倍的隐藏维度(对应i, f, o, g)
  26.         self.conv = nn.Conv2d(in_channels=self.input_dim + self.hidden_dim,
  27.                               out_channels=4 * self.hidden_dim,
  28.                               kernel_size=self.kernel_size,
  29.                               padding=self.padding,
  30.                               bias=self.bias)
  31.     def forward(self, input_tensor, cur_state):
  32.         h_cur, c_cur = cur_state
  33.         # 沿着通道轴进行拼接
  34.         combined = torch.cat([input_tensor, h_cur], dim=1)
  35.         combined_conv = self.conv(combined)
  36.         # 将输出分割成四个部分,分别对应输入门、遗忘门、输出门和候选单元状态
  37.         cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, self.hidden_dim, dim=1)
  38.         i = torch.sigmoid(cc_i)
  39.         f = torch.sigmoid(cc_f)
  40.         o = torch.sigmoid(cc_o)
  41.         g = torch.tanh(cc_g)
  42.         # 更新单元状态
  43.         c_next = f * c_cur + i * g
  44.         # 更新隐藏状态
  45.         h_next = o * torch.tanh(c_next)
  46.         return h_next, c_next
  47.     def init_hidden(self, batch_size, image_size):
  48.         height, width = image_size
  49.         # 初始化隐藏状态和单元状态为零
  50.         return (torch.zeros(batch_size, self.hidden_dim, height, width, device=self.conv.weight.device),
  51.                 torch.zeros(batch_size, self.hidden_dim, height, width, device=self.conv.weight.device))
  52. class ConvLSTM(nn.Module):
  53.     """
  54.     卷积 LSTM 层。
  55.     参数:
  56.     ----------
  57.     input_dim: 输入通道数
  58.     hidden_dim: 隐藏通道数
  59.     kernel_size: 卷积核大小
  60.     num_layers: LSTM 层的数量
  61.     batch_first: 批次是否在第一维
  62.     bias: 卷积中是否有偏置项
  63.     return_all_layers: 是否返回所有层的计算结果
  64.     输入:
  65.     ------
  66.     一个形状为 B, T, C, H, W 或者 T, B, C, H, W 的张量
  67.     输出:
  68.     ------
  69.     元组包含两个列表(长度为 num_layers 或者长度为 1 如果 return_all_layers 为 False):
  70.     0 - layer_output_list 是长度为 T 的每个输出的列表
  71.     1 - last_state_list 是最后的状态列表,其中每个元素是一个 (h, c) 对应隐藏状态和记忆状态
  72.     示例:
  73.     >>> x = torch.rand((32, 10, 64, 128, 128))
  74.     >>> convlstm = ConvLSTM(64, 16, 3, 1, True, True, False)
  75.     >>> _, last_states = convlstm(x)
  76.     >>> h = last_states[0][0]  # 0 表示层索引,0 表示 h 索引
  77.     """
  78.     def __init__(self, input_dim, hidden_dim, kernel_size, num_layers,
  79.                  batch_first=False, bias=True, return_all_layers=False):
  80.         super(ConvLSTM, self).__init__()
  81.         # 检查 kernel_size 的一致性
  82.         self._check_kernel_size_consistency(kernel_size)
  83.         # 确保 kernel_size 和 hidden_dim 的长度与层数一致
  84.         kernel_size = self._extend_for_multilayer(kernel_size, num_layers)
  85.         hidden_dim = self._extend_for_multilayer(hidden_dim, num_layers)
  86.         if not len(kernel_size) == len(hidden_dim) == num_layers:
  87.             raise ValueError('不一致的列表长度。')
  88.         self.input_dim = input_dim
  89.         self.hidden_dim = hidden_dim
  90.         self.kernel_size = kernel_size
  91.         self.num_layers = num_layers
  92.         self.batch_first = batch_first
  93.         self.bias = bias
  94.         self.return_all_layers = return_all_layers
  95.         # 创建 ConvLSTMCell 列表
  96.         cell_list = []
  97.         for i in range(0, self.num_layers):
  98.             cur_input_dim = self.input_dim if i == 0 else self.hidden_dim[i - 1]
  99.             cell_list.append(ConvLSTMCell(input_dim=cur_input_dim,
  100.                                           hidden_dim=self.hidden_dim[i],
  101.                                           kernel_size=self.kernel_size[i],
  102.                                           bias=self.bias))
  103.         self.cell_list = nn.ModuleList(cell_list)
  104.     def forward(self, input_tensor, hidden_state=None):
  105.         """
  106.         前向传播函数。
  107.         参数:
  108.         ----------
  109.         input_tensor: 输入张量,形状为 (t, b, c, h, w) 或者 (b, t, c, h, w)
  110.         hidden_state: 初始隐藏状态,默认为 None
  111.         返回:
  112.         -------
  113.         last_state_list, layer_output
  114.         """
  115.         if not self.batch_first:
  116.             # 改变输入张量的顺序,如果 batch_first 为 False
  117.             input_tensor = input_tensor.permute(1, 0, 2, 3, 4)
  118.         b, _, _, h, w = input_tensor.size()
  119.         # 实现状态化的 ConvLSTM
  120.         if hidden_state is not None:
  121.             raise NotImplementedError()
  122.         else:
  123.             # 初始化隐藏状态
  124.             hidden_state = self._init_hidden(batch_size=b,
  125.                                              image_size=(h, w))
  126.         layer_output_list = []
  127.         last_state_list = []
  128.         seq_len = input_tensor.size(1)
  129.         cur_layer_input = input_tensor
  130.         for layer_idx in range(self.num_layers):
  131.             h, c = hidden_state[layer_idx]
  132.             output_inner = []
  133.             for t in range(seq_len):
  134.                 # 在每个时间步上更新状态
  135.                 h, c = self.cell_list[layer_idx](input_tensor=cur_layer_input[:, t, :, :, :],
  136.                                                  cur_state=[h, c])
  137.                 output_inner.append(h)
  138.             # 将输出堆叠起来
  139.             layer_output = torch.stack(output_inner, dim=1)
  140.             cur_layer_input = layer_output
  141.             layer_output_list.append(layer_output)
  142.             last_state_list.append([h, c])
  143.         if not self.return_all_layers:
  144.             # 如果不需要返回所有层,则只返回最后一层的输出和状态
  145.             layer_output_list = layer_output_list[-1:]
  146.             last_state_list = last_state_list[-1:]
  147.         return layer_output_list, last_state_list
  148.     def _init_hidden(self, batch_size, image_size):
  149.         init_states = []
  150.         for i in range(self.num_layers):
  151.             # 初始化每一层的隐藏状态
  152.             init_states.append(self.cell_list[i].init_hidden(batch_size, image_size))
  153.         return init_states
  154.     @staticmethod
  155.     def _check_kernel_size_consistency(kernel_size):
  156.         if not (isinstance(kernel_size, tuple) or
  157.                 (isinstance(kernel_size, list) and all([isinstance(elem, tuple) for elem in kernel_size]))):
  158.             raise ValueError('`kernel_size` 必须是 tuple 或者 list of tuples')
  159.     @staticmethod
  160.     def _extend_for_multilayer(param, num_layers):
  161.         if not isinstance(param, list):
  162.             param = [param] * num_layers
  163.         return param
复制代码
参考文献

  1. [1]Shi, X., Chen, Z., Wang, H., Yeung, D. Y., Wong, W. K., & Woo, W. (2015). Convolutional LSTM Network: A Machine Learning [2]Approach for Precipitation Nowcasting. Advances in Neural Information Processing Systems, 28.
  2. [2]Hochreiter, S., & Schmidhuber, J. (1997). Long Short-Term Memory. Neural Computation, 9(8), 1735-1780.
  3. [3]Goodfellow, I., Bengio, Y., & Courville, A. (2016). Deep Learning. MIT Press.
复制代码
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。

本帖子中包含更多资源

您需要 登录 才可以下载或查看,没有账号?立即注册

x
回复

使用道具 举报

0 个回复

倒序浏览

快速回复

您需要登录后才可以回帖 登录 or 立即注册

本版积分规则

九天猎人

金牌会员
这个人很懒什么都没写!

标签云

快速回复 返回顶部 返回列表