IT评测·应用市场-qidao123.com技术社区

标题: 六、分布式嵌入 [打印本页]

作者: 天津储鑫盛钢材现货供应商    时间: 2025-4-18 15:28
标题: 六、分布式嵌入
六、分布式嵌入



  

前言



一、先要设置torch.distributed环境


   首先初始化分布式环境
  1. import torch.distributed as dist
  2. # 初始化进程组
  3. dist.init_process_group(
  4.     backend="nccl",          # GPU 推荐 NCCL 后端, CPU就是 gloo
  5.     init_method="env://",    # 从环境变量读取节点信息
  6.     rank=rank,               # 当前进程的全局唯一标识(从 0 开始)
  7.     world_size=world_size,   # 总进程数(总 GPU 数)
  8. )
  9. pg = dist.GroupMember.WORLD
复制代码
  设置环境变量(多节点训练时必须)
  1. import torch.distributed as dist
  2. # 初始化进程组
  3. # 在每个节点上设置以下环境变量
  4. export MASTER_ADDR="主节点IP"   # 如 "192.168.1.1"
  5. export MASTER_PORT="66666"     # 任意未占用端口
  6. export WORLD_SIZE=4            # 总 GPU 数
  7. export RANK=0                  # 当前节点的全局 rank
复制代码
二、Distributed Embeddings


   代码演示:
  1. print(ebc)
  2. """
  3. EmbeddingBagCollection(
  4.   (embedding_bags): ModuleDict(
  5.     (product_table): EmbeddingBag(4096, 64, mode='sum')
  6.     (user_table): EmbeddingBag(4096, 64, mode='sum')
  7.   )
  8. )
  9. """
复制代码
2.1 EmbeddingBagCollectionSharder


   代码演示:
  1. from torchrec.distributed.embedding_types import ShardingType
  2. from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
  3. # 定义分片器:指定分片策略(如按表分片)
  4. sharder = EmbeddingBagCollectionSharder(
  5.     sharding_type=ShardingType.TABLE_WISE.value,  # 每个表分配到一个 GPU
  6.     kernel_type=EmbeddingComputeKernel.FUSED.value,  # 使用 fused 优化
  7. )
复制代码

2.2 ShardedEmbeddingBagCollection


   代码演示:
  1. from torchrec.distributed.embeddingbag import ShardedEmbeddingBagCollection
  2. # 根据分片器生成分片后的模块
  3. sharded_ebc = ShardedEmbeddingBagCollection(
  4.     module=ebc,        # 原始 EmbeddingBagCollection
  5.     sharder=sharder,   # 分片策略
  6.     device=device,     # 目标设备(如 GPU:0)
  7. )
复制代码
三、Planner


   演示代码:
  1. from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
  2. # 初始化Planner
  3. planner = EmbeddingShardingPlanner(
  4.     topology=Topology(  # 硬件拓扑信息
  5.         world_size=4,  # 总 GPU 数
  6.         compute_device="cuda",
  7.         local_world_size=2,  # 单机 GPU 数
  8.         batch_size=1024,  
  9.     ),
  10.     constraints={  # 可选约束(如强制某些表使用特定策略)
  11.         "user_id": ParameterConstraints(sharding_types=[ShardingType.TABLE_WISE]),
  12.     },
  13. )
  14. # 生成分片计划
  15. plan = planner.collective_plan(ebc, [sharder], pg)
  16. # 分片后的模型
  17. from torchrec.distributed.embeddingbag import ShardedEmbeddingBagCollection
  18. sharded_ebc = ShardedEmbeddingBagCollection(
  19.     module=ebc,
  20.     sharder=sharder,
  21.     device=torch.device("cuda:0"),
  22.     plan=plan,  # 应用自动生成的分片计划
  23. )
复制代码

总结



免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。




欢迎光临 IT评测·应用市场-qidao123.com技术社区 (https://dis.qidao123.com/) Powered by Discuz! X3.4