马上注册,结交更多好友,享用更多功能,让你轻松玩转社区。
您需要 登录 才可以下载或查看,没有账号?立即注册
x
六、分布式嵌入
前言
- 我们已经利用了TorchRec的主模块:EmbeddedBagCollection。我们在上一节研究了它是如何工作的,以及数据在TorchRec中是如何表示的。然而,我们还没有探索TorchRec的重要部分之一,即分布式嵌入
一、先要设置torch.distributed环境
- EmbeddingBagCollectionSharder 依赖于 PyTorch 的分布式通讯库(torch.distributed)来管理跨进程/GPU 的分片和通讯。
首先初始化分布式环境
- import torch.distributed as dist
- # 初始化进程组
- dist.init_process_group(
- backend="nccl", # GPU 推荐 NCCL 后端, CPU就是 gloo
- init_method="env://", # 从环境变量读取节点信息
- rank=rank, # 当前进程的全局唯一标识(从 0 开始)
- world_size=world_size, # 总进程数(总 GPU 数)
- )
- pg = dist.GroupMember.WORLD
复制代码 设置环境变量(多节点训练时必须)
- import torch.distributed as dist
- # 初始化进程组
- # 在每个节点上设置以下环境变量
- export MASTER_ADDR="主节点IP" # 如 "192.168.1.1"
- export MASTER_PORT="66666" # 任意未占用端口
- export WORLD_SIZE=4 # 总 GPU 数
- export RANK=0 # 当前节点的全局 rank
复制代码 二、Distributed Embeddings
- 先回顾一下我们上一节的EmbeddingBagCollection module
代码演示:
- print(ebc)
- """
- EmbeddingBagCollection(
- (embedding_bags): ModuleDict(
- (product_table): EmbeddingBag(4096, 64, mode='sum')
- (user_table): EmbeddingBag(4096, 64, mode='sum')
- )
- )
- """
复制代码 2.1 EmbeddingBagCollectionSharder
- 策略制定者 ,决定如何分片。
- 决定如何将 EmbeddingBagCollection 的嵌入表(Embedding Tables)分布到多个 GPU/节点。
核心功能 :根据设置(如 ShardingType)天生分片计划(Sharding Plan)
代码演示:
- from torchrec.distributed.embedding_types import ShardingType
- from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
- # 定义分片器:指定分片策略(如按表分片)
- sharder = EmbeddingBagCollectionSharder(
- sharding_type=ShardingType.TABLE_WISE.value, # 每个表分配到一个 GPU
- kernel_type=EmbeddingComputeKernel.FUSED.value, # 使用 fused 优化
- )
复制代码
- 关键参数
- sharding_type:分片策略,如:
- TABLE_WISE:整个表放在一个 GPU。
- ROW_WISE:按行分片到多个 GPU。
- COLUMN_WISE:按列分片(适用于超大表)。
- kernel_type:盘算内核范例(如 FUSED 优化显存)
2.2 ShardedEmbeddingBagCollection
- 策略执行者 ,实际管理分片后的嵌入表
- 根据 EmbeddingBagCollectionSharder 天生的分片计划,实际管理分布在多设备上的嵌入表。
- 核心功能 :在分布式环境中执行前向传播、梯度聚合和参数更新
代码演示:
- from torchrec.distributed.embeddingbag import ShardedEmbeddingBagCollection
- # 根据分片器生成分片后的模块
- sharded_ebc = ShardedEmbeddingBagCollection(
- module=ebc, # 原始 EmbeddingBagCollection
- sharder=sharder, # 分片策略
- device=device, # 目标设备(如 GPU:0)
- )
复制代码 三、Planner
- 它可以帮助我们确定最佳的分片设置。
- Planner可以大概根据嵌入表的数量和GPU的数量来确定最佳设置。事实证明,这很难手动完成,工程师必须考虑大量因素来确保最佳的分片计划。
- TorchRec在提供的这个Planner,可以帮助我们:
- 评估硬件的内存限制
- 将基于存储器获取的盘算估计为嵌入查找
- 解决数据特定因素
- 考虑其他硬件细节,如带宽,以天生最佳分片计划
演示代码:
- from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
- # 初始化Planner
- planner = EmbeddingShardingPlanner(
- topology=Topology( # 硬件拓扑信息
- world_size=4, # 总 GPU 数
- compute_device="cuda",
- local_world_size=2, # 单机 GPU 数
- batch_size=1024,
- ),
- constraints={ # 可选约束(如强制某些表使用特定策略)
- "user_id": ParameterConstraints(sharding_types=[ShardingType.TABLE_WISE]),
- },
- )
- # 生成分片计划
- plan = planner.collective_plan(ebc, [sharder], pg)
- # 分片后的模型
- from torchrec.distributed.embeddingbag import ShardedEmbeddingBagCollection
- sharded_ebc = ShardedEmbeddingBagCollection(
- module=ebc,
- sharder=sharder,
- device=torch.device("cuda:0"),
- plan=plan, # 应用自动生成的分片计划
- )
复制代码 总结
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。 |