忿忿的泥巴坨 发表于 2025-4-13 08:00:54

20250412_代码笔记_CVRProblemDef

前言

该笔记分析代码的功能是生成随机VRP问题的数据,包含仓库坐标、节点坐标和节点需求。
对该代码举行改进
20250412-代码改进-拟蒙特卡洛
一、get_random_problems 函数分析

depot_xy = torch.rand(size=(batch_size, 1, 2))


[*]生成仓库坐标:

[*]生成形状为(batch_size, 1, 2) 的随机张量,表现每个批次中仓库的二维坐标(范围 [0,1))。

node_xy = torch.rand(size=(batch_size, problem_size, 2))


[*]生成节点坐标:

[*]生成形状为 (batch_size, problem_size, 2) 的随机张量,表现每个批次中全部节点的二维坐标。

if problem_size == 20:
    demand_scaler = 30
elif problem_size == 50:
    demand_scaler = 40
elif problem_size == 100:
    demand_scaler = 50
node_demand = torch.randint(1, 10, size=(batch_size, problem_size)) / demand_scaler


[*]生成节点需求:

[*]根据 problem_size 选择缩放因子 demand_scaler。
[*]生成 1~9 的整数需求,并缩放到 等区间,确保需求值为浮点数。

二、augment_xy_data_by_8_fold 函数分析

功能:通过8种几何变更对坐标数据举行增强,扩凑数据集。
x = xy_data[:, :, ]# 提取x坐标
y = xy_data[:, :, ]# 提取y坐标


[*]拆分坐标:

[*]从输入数据 xy_data(形状 (batch, N, 2))分离出x和y分量。

dat1 = torch.cat((x, y), dim=2)          # 原始坐标
dat2 = torch.cat((1 - x, y), dim=2)      # x轴镜像
dat3 = torch.cat((x, 1 - y), dim=2)      # y轴镜像
dat4 = torch.cat((1 - x, 1 - y), dim=2)# x+y轴镜像
dat5 = torch.cat((y, x), dim=2)          # 转置坐标
dat6 = torch.cat((1 - y, x), dim=2)      # 转置后x轴镜像
dat7 = torch.cat((y, 1 - x), dim=2)      # 转置后y轴镜像
dat8 = torch.cat((1 - y, 1 - x), dim=2)# 转置后x+y轴镜像


[*]生成8种变更:

[*]对坐标举行镜像翻转和转置操纵,生成8种变体。

aug_xy_data = torch.cat((dat1, dat2, ..., dat8), dim=0)


[*]归并增强数据:
[*]将8种变更后的数据沿批次维度拼接,终极形状为 (8*batch, N, 2)。
代码

import torchimport numpy as npdef get_random_problems(batch_size, problem_size):    depot_xy = torch.rand(size=(batch_size, 1, 2))
    # shape: (batch, 1, 2)    node_xy = torch.rand(size=(batch_size, problem_size, 2))
    # shape: (batch, problem, 2)    if problem_size == 20:      demand_scaler = 30    elif problem_size == 50:      demand_scaler = 40    elif problem_size == 100:      demand_scaler = 50    else:      raise NotImplementedError    node_demand = torch.randint(1, 10, size=(batch_size, problem_size)) / float(demand_scaler)    # shape: (batch, problem)    return depot_xy, node_xy, node_demanddef augment_xy_data_by_8_fold(xy_data):    # xy_data.shape: (batch, N, 2)    x = xy_data[:, :, ]    y = xy_data[:, :, ]    # x,y shape: (batch, N, 1)    dat1 = torch.cat((x, y), dim=2)    dat2 = torch.cat((1 - x, y), dim=2)    dat3 = torch.cat((x, 1 - y), dim=2)    dat4 = torch.cat((1 - x, 1 - y), dim=2)    dat5 = torch.cat((y, x), dim=2)    dat6 = torch.cat((1 - y, x), dim=2)    dat7 = torch.cat((y, 1 - x), dim=2)    dat8 = torch.cat((1 - y, 1 - x), dim=2)    aug_xy_data = torch.cat((dat1, dat2, dat3, dat4, dat5, dat6, dat7, dat8), dim=0)    # shape: (8*batch, N, 2)    return aug_xy_data
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。
页: [1]
查看完整版本: 20250412_代码笔记_CVRProblemDef