马上注册,结交更多好友,享用更多功能,让你轻松玩转社区。
您需要 登录 才可以下载或查看,没有账号?立即注册
x
前言
该笔记分析代码的功能是生成随机VRP问题的数据,包含仓库坐标、节点坐标和节点需求。
对该代码举行改进
20250412-代码改进-拟蒙特卡洛
一、get_random_problems 函数分析
- depot_xy = torch.rand(size=(batch_size, 1, 2))
复制代码
- 生成仓库坐标:
- 生成形状为(batch_size, 1, 2) 的随机张量,表现每个批次中仓库的二维坐标(范围 [0,1))。
- node_xy = torch.rand(size=(batch_size, problem_size, 2))
复制代码
- 生成节点坐标:
- 生成形状为 (batch_size, problem_size, 2) 的随机张量,表现每个批次中全部节点的二维坐标。
- if problem_size == 20:
- demand_scaler = 30
- elif problem_size == 50:
- demand_scaler = 40
- elif problem_size == 100:
- demand_scaler = 50
- node_demand = torch.randint(1, 10, size=(batch_size, problem_size)) / demand_scaler
复制代码
- 生成节点需求:
- 根据 problem_size 选择缩放因子 demand_scaler。
- 生成 1~9 的整数需求,并缩放到 [1/50, 9/50] 等区间,确保需求值为浮点数。
二、augment_xy_data_by_8_fold 函数分析
功能:通过8种几何变更对坐标数据举行增强,扩凑数据集。
- x = xy_data[:, :, [0]] # 提取x坐标
- y = xy_data[:, :, [1]] # 提取y坐标
复制代码
- 拆分坐标:
- 从输入数据 xy_data(形状 (batch, N, 2))分离出x和y分量。
- dat1 = torch.cat((x, y), dim=2) # 原始坐标
- dat2 = torch.cat((1 - x, y), dim=2) # x轴镜像
- dat3 = torch.cat((x, 1 - y), dim=2) # y轴镜像
- dat4 = torch.cat((1 - x, 1 - y), dim=2) # x+y轴镜像
- dat5 = torch.cat((y, x), dim=2) # 转置坐标
- dat6 = torch.cat((1 - y, x), dim=2) # 转置后x轴镜像
- dat7 = torch.cat((y, 1 - x), dim=2) # 转置后y轴镜像
- dat8 = torch.cat((1 - y, 1 - x), dim=2) # 转置后x+y轴镜像
复制代码
- aug_xy_data = torch.cat((dat1, dat2, ..., dat8), dim=0)
复制代码
- 归并增强数据:
- 将8种变更后的数据沿批次维度拼接,终极形状为 (8*batch, N, 2)。
代码
- import torchimport numpy as npdef get_random_problems(batch_size, problem_size): depot_xy = torch.rand(size=(batch_size, 1, 2))
- # shape: (batch, 1, 2) node_xy = torch.rand(size=(batch_size, problem_size, 2))
- # shape: (batch, problem, 2) if problem_size == 20: demand_scaler = 30 elif problem_size == 50: demand_scaler = 40 elif problem_size == 100: demand_scaler = 50 else: raise NotImplementedError node_demand = torch.randint(1, 10, size=(batch_size, problem_size)) / float(demand_scaler) # shape: (batch, problem) return depot_xy, node_xy, node_demanddef augment_xy_data_by_8_fold(xy_data): # xy_data.shape: (batch, N, 2) x = xy_data[:, :, [0]] y = xy_data[:, :, [1]] # x,y shape: (batch, N, 1) dat1 = torch.cat((x, y), dim=2) dat2 = torch.cat((1 - x, y), dim=2) dat3 = torch.cat((x, 1 - y), dim=2) dat4 = torch.cat((1 - x, 1 - y), dim=2) dat5 = torch.cat((y, x), dim=2) dat6 = torch.cat((1 - y, x), dim=2) dat7 = torch.cat((y, 1 - x), dim=2) dat8 = torch.cat((1 - y, 1 - x), dim=2) aug_xy_data = torch.cat((dat1, dat2, dat3, dat4, dat5, dat6, dat7, dat8), dim=0) # shape: (8*batch, N, 2) return aug_xy_data
复制代码 免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。 |