PyTorch 中contiguous函数使用详解和代码演示

[复制链接]
发表于 2025-6-8 07:08:33 | 显示全部楼层 |阅读模式

马上注册,结交更多好友,享用更多功能,让你轻松玩转社区。

您需要 登录 才可以下载或查看,没有账号?立即注册

×
在 PyTorch 中,contiguous() 是一个用于 张量内存结构优化 的函数。它的作用是在必要时返回一个内存结构为连续(contiguous)的张量,常用于 transpose、permute 等操作后。

一、为什么必要 contiguous()

PyTorch 的张量是以 行优先(row-major)顺序 存储的。当你对张量使用 transpose()、permute() 等操作时,固然张量的维度看起来改变了,但底层的内存并没有重新排列,只是修改了索引方式。
   一些 PyTorch 函数(如 .view())要求输入张量必须是 连续的内存块,否则就会报错。
  
二、函数定义与用法

  1. Tensor.contiguous(memory_format=torch.contiguous_format) → Tensor
复制代码
返回值:

返回一个与当前张量具有相同数据但在内存中 连续排列 的副本。如果当前张量已经是连续的,就直接返回自身。

三、典型使用场景

1. view() 前必要 .contiguous()

  1. x = torch.randn(2, 3, 4)
  2. y = x.permute(1, 0, 2)  # 改变维度顺序
  3. z = y.contiguous().view(3, 8)  # 安全 reshape
复制代码
如果不加 .contiguous():
  1. z = y.view(3, 8)  # ⚠️ 报错:RuntimeError: view size is not compatible with input tensor's size and stride
复制代码

2. 使用 transpose() 后必要 .contiguous() 加入后续操作

  1. a = torch.randn(10, 20)
  2. b = a.transpose(0, 1)  # Not contiguous now
  3. b = b.contiguous()     # 重新在内存中复制数据为连续块
复制代码

四、查看是否是连续的

  1. x.is_contiguous()
复制代码

五、底层原理简要

PyTorch 张量有 .stride() 属性定义每一维的跳步。连续的张量满意:
  1. x.stride()[i] = product(x.shape[i+1:])
复制代码
一旦 .transpose() / .permute() 修改了维度顺序,这个规则就被破坏,因此 .contiguous() 会重新分配内存来确保是连续的。

六、contiguous()项目演示

下面是一个完备的 PyTorch 小项目,演示 .contiguous() 的须要性与作用。将看到在对张量进行 permute() 后,使用 .view() reshape 会失败,只有 .contiguous() 可以解决问题。
项目内容:张量维度变更与 .contiguous() 对比演示
项目结构:

  1. contiguous_demo/
  2. ├── main.py
  3. └── requirements.txt
复制代码

requirements.txt

  1. torch>=2.0
复制代码

main.py

  1. import torch
  2. def describe_tensor(tensor, name):
  3.     print(f"{name}: shape={tensor.shape}, strides={tensor.stride()}, is_contiguous={tensor.is_contiguous()}")
  4. def main():
  5.     print("=== 创建张量 ===")
  6.     x = torch.randn(2, 3, 4)  # 原始张量 shape [2, 3, 4]
  7.     describe_tensor(x, "x")
  8.     print("\n=== 进行 permute 操作(交换维度) ===")
  9.     y = x.permute(1, 0, 2)  # shape: [3, 2, 4]
  10.     describe_tensor(y, "y (after permute)")
  11.     print("\n尝试 view reshape 到 [3, 8](不使用 contiguous)")
  12.     try:
  13.         z = y.view(3, 8)  # ⚠️ 报错:因为 y 的内存不是连续的
  14.     except RuntimeError as e:
  15.         print(f"RuntimeError: {e}")
  16.     print("\n=== 使用 .contiguous() 后 reshape ===")
  17.     y_contig = y.contiguous()
  18.     describe_tensor(y_contig, "y_contig (after .contiguous())")
  19.     z = y_contig.view(3, 8)
  20.     describe_tensor(z, "z (reshaped)")
  21.     print("\n✅ reshape 成功,结果如下:")
  22.     print(z)
  23. if __name__ == "__main__":
  24.     main()
复制代码

运行方法


  • 安装依赖:
  1. pip install -r requirements.txt
复制代码

  • 运行步调:
  1. python main.py
复制代码

运行效果概览

将看到:


  • 原始张量是连续的;
  • permute() 后变成非连续;
  • 使用 .view() 报错;
  • .contiguous() 修复内存后成功 reshape。

小总结要点

操作是否连续能否 .view()原始张量✅ 是✅ 是permute() 后❌ 否❌ 报错.contiguous() 后✅ 是✅ 是
总结记忆:

操作是否影响连续性?是否必要 .contiguous()view()❗ 必要连续✅ 是permute() / transpose()破坏连续性✅ 是reshape()自动处理❌ 不必要(内部处理)
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。
继续阅读请点击广告
回复

使用道具 举报

快速回复 返回顶部 返回列表