20250412_代码笔记_CVRProblemDef
前言该笔记分析代码的功能是生成随机VRP问题的数据,包含仓库坐标、节点坐标和节点需求。
对该代码举行改进
20250412-代码改进-拟蒙特卡洛
一、get_random_problems 函数分析
depot_xy = torch.rand(size=(batch_size, 1, 2))
[*]生成仓库坐标:
[*]生成形状为(batch_size, 1, 2) 的随机张量,表现每个批次中仓库的二维坐标(范围 [0,1))。
node_xy = torch.rand(size=(batch_size, problem_size, 2))
[*]生成节点坐标:
[*]生成形状为 (batch_size, problem_size, 2) 的随机张量,表现每个批次中全部节点的二维坐标。
if problem_size == 20:
demand_scaler = 30
elif problem_size == 50:
demand_scaler = 40
elif problem_size == 100:
demand_scaler = 50
node_demand = torch.randint(1, 10, size=(batch_size, problem_size)) / demand_scaler
[*]生成节点需求:
[*]根据 problem_size 选择缩放因子 demand_scaler。
[*]生成 1~9 的整数需求,并缩放到 等区间,确保需求值为浮点数。
二、augment_xy_data_by_8_fold 函数分析
功能:通过8种几何变更对坐标数据举行增强,扩凑数据集。
x = xy_data[:, :, ]# 提取x坐标
y = xy_data[:, :, ]# 提取y坐标
[*]拆分坐标:
[*]从输入数据 xy_data(形状 (batch, N, 2))分离出x和y分量。
dat1 = torch.cat((x, y), dim=2) # 原始坐标
dat2 = torch.cat((1 - x, y), dim=2) # x轴镜像
dat3 = torch.cat((x, 1 - y), dim=2) # y轴镜像
dat4 = torch.cat((1 - x, 1 - y), dim=2)# x+y轴镜像
dat5 = torch.cat((y, x), dim=2) # 转置坐标
dat6 = torch.cat((1 - y, x), dim=2) # 转置后x轴镜像
dat7 = torch.cat((y, 1 - x), dim=2) # 转置后y轴镜像
dat8 = torch.cat((1 - y, 1 - x), dim=2)# 转置后x+y轴镜像
[*]生成8种变更:
[*]对坐标举行镜像翻转和转置操纵,生成8种变体。
aug_xy_data = torch.cat((dat1, dat2, ..., dat8), dim=0)
[*]归并增强数据:
[*]将8种变更后的数据沿批次维度拼接,终极形状为 (8*batch, N, 2)。
代码
import torchimport numpy as npdef get_random_problems(batch_size, problem_size): depot_xy = torch.rand(size=(batch_size, 1, 2))
# shape: (batch, 1, 2) node_xy = torch.rand(size=(batch_size, problem_size, 2))
# shape: (batch, problem, 2) if problem_size == 20: demand_scaler = 30 elif problem_size == 50: demand_scaler = 40 elif problem_size == 100: demand_scaler = 50 else: raise NotImplementedError node_demand = torch.randint(1, 10, size=(batch_size, problem_size)) / float(demand_scaler) # shape: (batch, problem) return depot_xy, node_xy, node_demanddef augment_xy_data_by_8_fold(xy_data): # xy_data.shape: (batch, N, 2) x = xy_data[:, :, ] y = xy_data[:, :, ] # x,y shape: (batch, N, 1) dat1 = torch.cat((x, y), dim=2) dat2 = torch.cat((1 - x, y), dim=2) dat3 = torch.cat((x, 1 - y), dim=2) dat4 = torch.cat((1 - x, 1 - y), dim=2) dat5 = torch.cat((y, x), dim=2) dat6 = torch.cat((1 - y, x), dim=2) dat7 = torch.cat((y, 1 - x), dim=2) dat8 = torch.cat((1 - y, 1 - x), dim=2) aug_xy_data = torch.cat((dat1, dat2, dat3, dat4, dat5, dat6, dat7, dat8), dim=0) # shape: (8*batch, N, 2) return aug_xy_data
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。
页:
[1]