IT评测·应用市场-qidao123.com技术社区
标题:
六、分布式嵌入
[打印本页]
作者:
天津储鑫盛钢材现货供应商
时间:
2025-4-18 15:28
标题:
六、分布式嵌入
六、分布式嵌入
前言
我们已经利用了
TorchRec
的主模块:
EmbeddedBagCollection
。我们在上一节研究了它是如何工作的,以及数据在
TorchRec
中是如何表示的。然而,我们还没有探索
TorchRec
的重要部分之一,即
分布式嵌入
一、先要设置torch.distributed环境
EmbeddingBagCollectionSharder
依赖于
PyTorch
的分布式通讯库(
torch.distributed
)来管理跨进程/GPU 的分片和通讯。
首先初始化分布式环境
import torch.distributed as dist
# 初始化进程组
dist.init_process_group(
backend="nccl", # GPU 推荐 NCCL 后端, CPU就是 gloo
init_method="env://", # 从环境变量读取节点信息
rank=rank, # 当前进程的全局唯一标识(从 0 开始)
world_size=world_size, # 总进程数(总 GPU 数)
)
pg = dist.GroupMember.WORLD
复制代码
设置环境变量(多节点训练时必须)
import torch.distributed as dist
# 初始化进程组
# 在每个节点上设置以下环境变量
export MASTER_ADDR="主节点IP" # 如 "192.168.1.1"
export MASTER_PORT="66666" # 任意未占用端口
export WORLD_SIZE=4 # 总 GPU 数
export RANK=0 # 当前节点的全局 rank
复制代码
二、Distributed Embeddings
先回顾一下我们上一节的
EmbeddingBagCollection module
代码演示:
print(ebc)
"""
EmbeddingBagCollection(
(embedding_bags): ModuleDict(
(product_table): EmbeddingBag(4096, 64, mode='sum')
(user_table): EmbeddingBag(4096, 64, mode='sum')
)
)
"""
复制代码
2.1 EmbeddingBagCollectionSharder
策略制定者 ,决定如何分片。
决定如何将
EmbeddingBagCollection
的嵌入表(
Embedding Tables
)分布到多个 GPU/节点。
核心功能 :根据设置(如
ShardingType
)天生分片计划(
Sharding Plan
)
代码演示:
from torchrec.distributed.embedding_types import ShardingType
from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
# 定义分片器:指定分片策略(如按表分片)
sharder = EmbeddingBagCollectionSharder(
sharding_type=ShardingType.TABLE_WISE.value, # 每个表分配到一个 GPU
kernel_type=EmbeddingComputeKernel.FUSED.value, # 使用 fused 优化
)
复制代码
关键参数
sharding_type
:分片策略,如:
TABLE_WISE:整个表放在一个 GPU。
ROW_WISE:按行分片到多个 GPU。
COLUMN_WISE:按列分片(适用于超大表)。
kernel_type
:盘算内核范例(如 FUSED 优化显存)
2.2 ShardedEmbeddingBagCollection
策略执行者 ,实际管理分片后的嵌入表
根据
EmbeddingBagCollectionSharder
天生的分片计划,实际管理分布在多设备上的嵌入表。
核心功能 :在分布式环境中执行前向传播、梯度聚合和参数更新
代码演示:
from torchrec.distributed.embeddingbag import ShardedEmbeddingBagCollection
# 根据分片器生成分片后的模块
sharded_ebc = ShardedEmbeddingBagCollection(
module=ebc, # 原始 EmbeddingBagCollection
sharder=sharder, # 分片策略
device=device, # 目标设备(如 GPU:0)
)
复制代码
三、Planner
它可以帮助我们确定最佳的分片设置。
Planner
可以大概根据嵌入表的数量和GPU的数量来确定最佳设置。事实证明,这很难手动完成,工程师必须考虑大量因素来确保最佳的分片计划。
TorchRec
在提供的这个
Planner
,可以帮助我们:
评估硬件的内存限制
将基于存储器获取的盘算估计为嵌入查找
解决数据特定因素
考虑其他硬件细节,如带宽,以天生最佳分片计划
演示代码:
from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
# 初始化Planner
planner = EmbeddingShardingPlanner(
topology=Topology( # 硬件拓扑信息
world_size=4, # 总 GPU 数
compute_device="cuda",
local_world_size=2, # 单机 GPU 数
batch_size=1024,
),
constraints={ # 可选约束(如强制某些表使用特定策略)
"user_id": ParameterConstraints(sharding_types=[ShardingType.TABLE_WISE]),
},
)
# 生成分片计划
plan = planner.collective_plan(ebc, [sharder], pg)
# 分片后的模型
from torchrec.distributed.embeddingbag import ShardedEmbeddingBagCollection
sharded_ebc = ShardedEmbeddingBagCollection(
module=ebc,
sharder=sharder,
device=torch.device("cuda:0"),
plan=plan, # 应用自动生成的分片计划
)
复制代码
总结
TorchRec中的分布式嵌入以及训练设置。
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。
欢迎光临 IT评测·应用市场-qidao123.com技术社区 (https://dis.qidao123.com/)
Powered by Discuz! X3.4