20250412_代码笔记_CVRProblemDef

打印 上一主题 下一主题

主题 1910|帖子 1910|积分 5730

马上注册,结交更多好友,享用更多功能,让你轻松玩转社区。

您需要 登录 才可以下载或查看,没有账号?立即注册

x

前言

该笔记分析代码的功能是生成随机VRP问题的数据,包含仓库坐标、节点坐标和节点需求。
对该代码举行改进
20250412-代码改进-拟蒙特卡洛

一、get_random_problems 函数分析

  1. depot_xy = torch.rand(size=(batch_size, 1, 2))
复制代码


  • 生成仓库坐标:

    • 生成形状为(batch_size, 1, 2) 的随机张量,表现每个批次中仓库的二维坐标(范围 [0,1))。

  1. node_xy = torch.rand(size=(batch_size, problem_size, 2))
复制代码


  • 生成节点坐标:

    • 生成形状为 (batch_size, problem_size, 2) 的随机张量,表现每个批次中全部节点的二维坐标。

  1. if problem_size == 20:
  2.     demand_scaler = 30
  3. elif problem_size == 50:
  4.     demand_scaler = 40
  5. elif problem_size == 100:
  6.     demand_scaler = 50
  7. node_demand = torch.randint(1, 10, size=(batch_size, problem_size)) / demand_scaler
复制代码


  • 生成节点需求:

    • 根据 problem_size 选择缩放因子 demand_scaler。
    • 生成 1~9 的整数需求,并缩放到 [1/50, 9/50] 等区间,确保需求值为浮点数。

二、augment_xy_data_by_8_fold 函数分析

功能:通过8种几何变更对坐标数据举行增强,扩凑数据集。
  1. x = xy_data[:, :, [0]]  # 提取x坐标
  2. y = xy_data[:, :, [1]]  # 提取y坐标
复制代码


  • 拆分坐标:

    • 从输入数据 xy_data(形状 (batch, N, 2))分离出x和y分量。

  1. dat1 = torch.cat((x, y), dim=2)          # 原始坐标
  2. dat2 = torch.cat((1 - x, y), dim=2)      # x轴镜像
  3. dat3 = torch.cat((x, 1 - y), dim=2)      # y轴镜像
  4. dat4 = torch.cat((1 - x, 1 - y), dim=2)  # x+y轴镜像
  5. dat5 = torch.cat((y, x), dim=2)          # 转置坐标
  6. dat6 = torch.cat((1 - y, x), dim=2)      # 转置后x轴镜像
  7. dat7 = torch.cat((y, 1 - x), dim=2)      # 转置后y轴镜像
  8. dat8 = torch.cat((1 - y, 1 - x), dim=2)  # 转置后x+y轴镜像
复制代码


  • 生成8种变更:

    • 对坐标举行镜像翻转和转置操纵,生成8种变体。

  1. aug_xy_data = torch.cat((dat1, dat2, ..., dat8), dim=0)
复制代码


  • 归并增强数据:
  • 将8种变更后的数据沿批次维度拼接,终极形状为 (8*batch, N, 2)。

代码

  1. import torchimport numpy as npdef get_random_problems(batch_size, problem_size):    depot_xy = torch.rand(size=(batch_size, 1, 2))
  2.     # shape: (batch, 1, 2)    node_xy = torch.rand(size=(batch_size, problem_size, 2))
  3.     # shape: (batch, problem, 2)    if problem_size == 20:        demand_scaler = 30    elif problem_size == 50:        demand_scaler = 40    elif problem_size == 100:        demand_scaler = 50    else:        raise NotImplementedError    node_demand = torch.randint(1, 10, size=(batch_size, problem_size)) / float(demand_scaler)    # shape: (batch, problem)    return depot_xy, node_xy, node_demanddef augment_xy_data_by_8_fold(xy_data):    # xy_data.shape: (batch, N, 2)    x = xy_data[:, :, [0]]    y = xy_data[:, :, [1]]    # x,y shape: (batch, N, 1)    dat1 = torch.cat((x, y), dim=2)    dat2 = torch.cat((1 - x, y), dim=2)    dat3 = torch.cat((x, 1 - y), dim=2)    dat4 = torch.cat((1 - x, 1 - y), dim=2)    dat5 = torch.cat((y, x), dim=2)    dat6 = torch.cat((1 - y, x), dim=2)    dat7 = torch.cat((y, 1 - x), dim=2)    dat8 = torch.cat((1 - y, 1 - x), dim=2)    aug_xy_data = torch.cat((dat1, dat2, dat3, dat4, dat5, dat6, dat7, dat8), dim=0)    # shape: (8*batch, N, 2)    return aug_xy_data
复制代码
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。
回复

使用道具 举报

0 个回复

倒序浏览

快速回复

您需要登录后才可以回帖 登录 or 立即注册

本版积分规则

忿忿的泥巴坨

论坛元老
这个人很懒什么都没写!
快速回复 返回顶部 返回列表