马上注册,结交更多好友,享用更多功能,让你轻松玩转社区。
您需要 登录 才可以下载或查看,没有账号?立即注册
×
在 PyTorch 中,contiguous() 是一个用于 张量内存结构优化 的函数。它的作用是在必要时返回一个内存结构为连续(contiguous)的张量,常用于 transpose、permute 等操作后。
一、为什么必要 contiguous()
PyTorch 的张量是以 行优先(row-major)顺序 存储的。当你对张量使用 transpose()、permute() 等操作时,固然张量的维度看起来改变了,但底层的内存并没有重新排列,只是修改了索引方式。
一些 PyTorch 函数(如 .view())要求输入张量必须是 连续的内存块,否则就会报错。
二、函数定义与用法
- Tensor.contiguous(memory_format=torch.contiguous_format) → Tensor
复制代码 返回值:
返回一个与当前张量具有相同数据但在内存中 连续排列 的副本。如果当前张量已经是连续的,就直接返回自身。
三、典型使用场景
1. view() 前必要 .contiguous()
- x = torch.randn(2, 3, 4)
- y = x.permute(1, 0, 2) # 改变维度顺序
- z = y.contiguous().view(3, 8) # 安全 reshape
复制代码 如果不加 .contiguous():
- z = y.view(3, 8) # ⚠️ 报错:RuntimeError: view size is not compatible with input tensor's size and stride
复制代码 2. 使用 transpose() 后必要 .contiguous() 加入后续操作
- a = torch.randn(10, 20)
- b = a.transpose(0, 1) # Not contiguous now
- b = b.contiguous() # 重新在内存中复制数据为连续块
复制代码 四、查看是否是连续的
五、底层原理简要
PyTorch 张量有 .stride() 属性定义每一维的跳步。连续的张量满意:
- x.stride()[i] = product(x.shape[i+1:])
复制代码 一旦 .transpose() / .permute() 修改了维度顺序,这个规则就被破坏,因此 .contiguous() 会重新分配内存来确保是连续的。
六、contiguous()项目演示
下面是一个完备的 PyTorch 小项目,演示 .contiguous() 的须要性与作用。将看到在对张量进行 permute() 后,使用 .view() reshape 会失败,只有 .contiguous() 可以解决问题。
项目内容:张量维度变更与 .contiguous() 对比演示
项目结构:
- contiguous_demo/
- ├── main.py
- └── requirements.txt
复制代码 requirements.txt
main.py
- import torch
- def describe_tensor(tensor, name):
- print(f"{name}: shape={tensor.shape}, strides={tensor.stride()}, is_contiguous={tensor.is_contiguous()}")
- def main():
- print("=== 创建张量 ===")
- x = torch.randn(2, 3, 4) # 原始张量 shape [2, 3, 4]
- describe_tensor(x, "x")
- print("\n=== 进行 permute 操作(交换维度) ===")
- y = x.permute(1, 0, 2) # shape: [3, 2, 4]
- describe_tensor(y, "y (after permute)")
- print("\n尝试 view reshape 到 [3, 8](不使用 contiguous)")
- try:
- z = y.view(3, 8) # ⚠️ 报错:因为 y 的内存不是连续的
- except RuntimeError as e:
- print(f"RuntimeError: {e}")
- print("\n=== 使用 .contiguous() 后 reshape ===")
- y_contig = y.contiguous()
- describe_tensor(y_contig, "y_contig (after .contiguous())")
- z = y_contig.view(3, 8)
- describe_tensor(z, "z (reshaped)")
- print("\n✅ reshape 成功,结果如下:")
- print(z)
- if __name__ == "__main__":
- main()
复制代码 运行方法
- pip install -r requirements.txt
复制代码 运行效果概览
将看到:
- 原始张量是连续的;
- permute() 后变成非连续;
- 使用 .view() 报错;
- .contiguous() 修复内存后成功 reshape。
小总结要点
操作是否连续能否 .view()原始张量✅ 是✅ 是permute() 后❌ 否❌ 报错.contiguous() 后✅ 是✅ 是 总结记忆:
操作是否影响连续性?是否必要 .contiguous()view()❗ 必要连续✅ 是permute() / transpose()破坏连续性✅ 是reshape()自动处理❌ 不必要(内部处理)
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。
|