六、分布式嵌入

打印 上一主题 下一主题

主题 1760|帖子 1760|积分 5280

马上注册,结交更多好友,享用更多功能,让你轻松玩转社区。

您需要 登录 才可以下载或查看,没有账号?立即注册

x
六、分布式嵌入



  

前言



  • 我们已经利用了TorchRec的主模块:EmbeddedBagCollection。我们在上一节研究了它是如何工作的,以及数据在TorchRec中是如何表示的。然而,我们还没有探索TorchRec的重要部分之一,即分布式嵌入

一、先要设置torch.distributed环境



  • EmbeddingBagCollectionSharder 依赖于 PyTorch 的分布式通讯库(torch.distributed)来管理跨进程/GPU 的分片和通讯。
   首先初始化分布式环境
  1. import torch.distributed as dist
  2. # 初始化进程组
  3. dist.init_process_group(
  4.     backend="nccl",          # GPU 推荐 NCCL 后端, CPU就是 gloo
  5.     init_method="env://",    # 从环境变量读取节点信息
  6.     rank=rank,               # 当前进程的全局唯一标识(从 0 开始)
  7.     world_size=world_size,   # 总进程数(总 GPU 数)
  8. )
  9. pg = dist.GroupMember.WORLD
复制代码
  设置环境变量(多节点训练时必须)
  1. import torch.distributed as dist
  2. # 初始化进程组
  3. # 在每个节点上设置以下环境变量
  4. export MASTER_ADDR="主节点IP"   # 如 "192.168.1.1"
  5. export MASTER_PORT="66666"     # 任意未占用端口
  6. export WORLD_SIZE=4            # 总 GPU 数
  7. export RANK=0                  # 当前节点的全局 rank
复制代码
二、Distributed Embeddings



  • 先回顾一下我们上一节的EmbeddingBagCollection module
   代码演示:
  1. print(ebc)
  2. """
  3. EmbeddingBagCollection(
  4.   (embedding_bags): ModuleDict(
  5.     (product_table): EmbeddingBag(4096, 64, mode='sum')
  6.     (user_table): EmbeddingBag(4096, 64, mode='sum')
  7.   )
  8. )
  9. """
复制代码
2.1 EmbeddingBagCollectionSharder



  • 策略制定者 ,决定如何分片。
  • 决定如何将 EmbeddingBagCollection 的嵌入表(Embedding Tables)分布到多个 GPU/节点。
    核心功能 :根据设置(如 ShardingType)天生分片计划(Sharding Plan
   代码演示:
  1. from torchrec.distributed.embedding_types import ShardingType
  2. from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
  3. # 定义分片器:指定分片策略(如按表分片)
  4. sharder = EmbeddingBagCollectionSharder(
  5.     sharding_type=ShardingType.TABLE_WISE.value,  # 每个表分配到一个 GPU
  6.     kernel_type=EmbeddingComputeKernel.FUSED.value,  # 使用 fused 优化
  7. )
复制代码


  • 关键参数

    • sharding_type:分片策略,如:

      • TABLE_WISE:整个表放在一个 GPU。
      • ROW_WISE:按行分片到多个 GPU。
      • COLUMN_WISE:按列分片(适用于超大表)。

    • kernel_type:盘算内核范例(如 FUSED 优化显存)

2.2 ShardedEmbeddingBagCollection



  • 策略执行者 ,实际管理分片后的嵌入表
  • 根据 EmbeddingBagCollectionSharder 天生的分片计划,实际管理分布在多设备上的嵌入表。
  • 核心功能 :在分布式环境中执行前向传播、梯度聚合和参数更新
   代码演示:
  1. from torchrec.distributed.embeddingbag import ShardedEmbeddingBagCollection
  2. # 根据分片器生成分片后的模块
  3. sharded_ebc = ShardedEmbeddingBagCollection(
  4.     module=ebc,        # 原始 EmbeddingBagCollection
  5.     sharder=sharder,   # 分片策略
  6.     device=device,     # 目标设备(如 GPU:0)
  7. )
复制代码
三、Planner



  • 它可以帮助我们确定最佳的分片设置。
  • Planner可以大概根据嵌入表的数量和GPU的数量来确定最佳设置。事实证明,这很难手动完成,工程师必须考虑大量因素来确保最佳的分片计划。
  • TorchRec在提供的这个Planner,可以帮助我们:

    • 评估硬件的内存限制
    • 将基于存储器获取的盘算估计为嵌入查找
    • 解决数据特定因素
    • 考虑其他硬件细节,如带宽,以天生最佳分片计划

   演示代码:
  1. from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
  2. # 初始化Planner
  3. planner = EmbeddingShardingPlanner(
  4.     topology=Topology(  # 硬件拓扑信息
  5.         world_size=4,  # 总 GPU 数
  6.         compute_device="cuda",
  7.         local_world_size=2,  # 单机 GPU 数
  8.         batch_size=1024,  
  9.     ),
  10.     constraints={  # 可选约束(如强制某些表使用特定策略)
  11.         "user_id": ParameterConstraints(sharding_types=[ShardingType.TABLE_WISE]),
  12.     },
  13. )
  14. # 生成分片计划
  15. plan = planner.collective_plan(ebc, [sharder], pg)
  16. # 分片后的模型
  17. from torchrec.distributed.embeddingbag import ShardedEmbeddingBagCollection
  18. sharded_ebc = ShardedEmbeddingBagCollection(
  19.     module=ebc,
  20.     sharder=sharder,
  21.     device=torch.device("cuda:0"),
  22.     plan=plan,  # 应用自动生成的分片计划
  23. )
复制代码

总结



  • TorchRec中的分布式嵌入以及训练设置。

免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。
回复

使用道具 举报

0 个回复

倒序浏览

快速回复

您需要登录后才可以回帖 登录 or 立即注册

本版积分规则

天津储鑫盛钢材现货供应商

论坛元老
这个人很懒什么都没写!
快速回复 返回顶部 返回列表