IT评测·应用市场-qidao123.com技术社区

标题: 0底子跟德姆(dom)一起学AI 天然语言处理10-LSTM模型 [打印本页]

作者: 尚未崩坏    时间: 2025-1-2 18:07
标题: 0底子跟德姆(dom)一起学AI 天然语言处理10-LSTM模型
1 LSTM先容

LSTM(Long Short-Term Memory)也称长短时影象结构, 它是传统RNN的变体, 与经典RNN相比能够有效捕获长序列之间的语义关联, 缓解梯度消失或爆炸征象. 同时LSTM的结构更复杂, 它的焦点结构可以分为四个部门去解析:

2 LSTM的内部结构图

2.1 LSTM结构分析






















2.2 Bi-LSTM先容

Bi-LSTM即双向LSTM, 它没有改变LSTM自己任何的内部结构, 只是将LSTM应用两次且方向差别, 再将两次得到的LSTM效果进行拼接作为最终输出.


2.3 使用Pytorch构建LSTM模型


  1. # 定义LSTM的参数含义: (input_size, hidden_size, num_layers)
  2. # 定义输入张量的参数含义: (sequence_length, batch_size, input_size)
  3. # 定义隐藏层初始张量和细胞初始状态张量的参数含义:
  4. # (num_layers * num_directions, batch_size, hidden_size)
  5. >>> import torch.nn as nn
  6. >>> import torch
  7. >>> rnn = nn.LSTM(5, 6, 2)
  8. >>> input = torch.randn(1, 3, 5)
  9. >>> h0 = torch.randn(2, 3, 6)
  10. >>> c0 = torch.randn(2, 3, 6)
  11. >>> output, (hn, cn) = rnn(input, (h0, c0))
  12. >>> output
  13. tensor([[[ 0.0447, -0.0335,  0.1454,  0.0438,  0.0865,  0.0416],
  14.          [ 0.0105,  0.1923,  0.5507, -0.1742,  0.1569, -0.0548],
  15.          [-0.1186,  0.1835, -0.0022, -0.1388, -0.0877, -0.4007]]],
  16.        grad_fn=<StackBackward>)
  17. >>> hn
  18. tensor([[[ 0.4647, -0.2364,  0.0645, -0.3996, -0.0500, -0.0152],
  19.          [ 0.3852,  0.0704,  0.2103, -0.2524,  0.0243,  0.0477],
  20.          [ 0.2571,  0.0608,  0.2322,  0.1815, -0.0513, -0.0291]],
  21.         [[ 0.0447, -0.0335,  0.1454,  0.0438,  0.0865,  0.0416],
  22.          [ 0.0105,  0.1923,  0.5507, -0.1742,  0.1569, -0.0548],
  23.          [-0.1186,  0.1835, -0.0022, -0.1388, -0.0877, -0.4007]]],
  24.        grad_fn=<StackBackward>)
  25. >>> cn
  26. tensor([[[ 0.8083, -0.5500,  0.1009, -0.5806, -0.0668, -0.1161],
  27.          [ 0.7438,  0.0957,  0.5509, -0.7725,  0.0824,  0.0626],
  28.          [ 0.3131,  0.0920,  0.8359,  0.9187, -0.4826, -0.0717]],
  29.         [[ 0.1240, -0.0526,  0.3035,  0.1099,  0.5915,  0.0828],
  30.          [ 0.0203,  0.8367,  0.9832, -0.4454,  0.3917, -0.1983],
  31.          [-0.2976,  0.7764, -0.0074, -0.1965, -0.1343, -0.6683]]],
  32.        grad_fn=<StackBackward>)
复制代码
2.4 LSTM优缺点



免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。




欢迎光临 IT评测·应用市场-qidao123.com技术社区 (https://dis.qidao123.com/) Powered by Discuz! X3.4