道家人 发表于 2024-9-24 13:31:52

【论文10】复现代码tips

一、预备工作

1.创建一个虚拟环境

conda create --name drgcnn38 python=3.8.18 https://i-blog.csdnimg.cn/direct/49661e28a2d54133856ab27117e99c63.png
2.激活虚拟环境

conda activate drgcnn38 https://i-blog.csdnimg.cn/direct/0c5e8fba719544389adbcb8e055b49de.png
注意事项

   在Pycharm中终端(terminal)显示PS而不是虚拟环境base
问题如下所示
https://i-blog.csdnimg.cn/direct/d7d9514ded8b47e2a99aadc39320525e.png
解决方法:shell路径改成cmd.exe
https://i-blog.csdnimg.cn/direct/a6df910e894540458e7dd46f11075a59.png
重启终端显示虚拟环境
https://i-blog.csdnimg.cn/direct/af228c7ecb324ae3b005f129fbb7c53c.png
3.安装torch

conda install pytorch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 cpuonly -c pytorch 安装一系列包
注意事项

   Pycharm远程毗连Linux服务器实现代码同步
1.工具-->部署-->设置
https://i-blog.csdnimg.cn/direct/5fc31afbec814d548d7c6cf687bccb5b.png
2.选择SFTP远程毗连,路径填与服务器要同步的路径地点
https://i-blog.csdnimg.cn/direct/bcbb5379c9c64a5b847d9bf927604b80.png
二、代码学习

各部门的作用



[*]eye_pre_process:视网膜眼底图像预处理惩罚模块。
[*]Encoder:编码器练习模块。
[*]modules:包罗模型结构、丧失函数和学习率低落计谋。
[*]utils:包罗一些常用函数和评估指标。
[*]BFFN:双眼特征融合网络练习模块。
[*]CAM:种别注意力模块。
eye_pre_process

copy.py

# 创建一个ArgumentParser对象,用于处理命令行参数
parser = argparse.ArgumentParser()

# 添加一个命令行参数 '--image-folder',类型为字符串,默认值为 'D:/cv_paper/lesson/Dataset/ceshi'
# 这个参数用于指定输入图像的文件夹路径
parser.add_argument('--image-folder', type=str, default=r'D:/cv_paper/lesson/Dataset/ceshi')

# 添加一个命令行参数 '--output-folder',类型为字符串,默认值为 'D:\cv_paper\lesson/Dataset/ceshi_output'
# 注意:这里路径中的反斜杠在不同的操作系统中可能需要特别注意,Python字符串中推荐使用原始字符串(r前缀)来避免转义字符的问题
# 这个参数用于指定输出结果的文件夹路径
parser.add_argument('--output-folder', type=str, default=r'D:\cv_paper\lesson/Dataset/ceshi_output')

# 添加一个命令行参数 '--crop-size',类型为整数,默认值为512
# 这个参数用于指定图像裁剪的大小
parser.add_argument('--crop-size', type=int, default=512, help='crop size of image')

# 添加一个命令行参数 '-n' 或 '--num-processes',类型为整数,默认值为8
# 这个参数用于指定处理任务时要使用的进程数
# '-n' 是 '--num-processes' 的简写形式,帮助信息说明了该参数的作用
parser.add_argument('-n', '--num-processes', type=int, default=8, help='number of processes to use') # 转换一个包含多个任务的列表,每个任务由文件名、目标路径和裁剪大小组成
# 对于jobs列表中的每个任务(索引为j),它首先检查是否已经处理了100个任务(作为进度指示),然后调用convert函数来执行实际的图像转换。
def convert_list(i, jobs):
    for j, job in enumerate(jobs):
      # 每处理100个任务打印一次进度
      if j % 100 == 0:
            print(f'worker{i} has finished {j} tasks.')
      # 解包任务元组并调用convert函数
      convert(*job)

# 转换单个图像文件,包括模糊处理、裁剪和保存
def convert(fname, tgt_path, crop_size):
    img = Image.open(fname)# 打开图像文件

    blurred = img.filter(ImageFilter.BLUR)# 应用模糊滤镜
    ba = np.array(blurred)# 将图像转换为NumPy数组
    h, w, _ = ba.shape# 获取图像的高度、宽度和通道数

    # 尝试根据图像的亮度分布来识别前景区域
    if w > 1.2 * h:
      # 计算左右两侧的最大亮度值
      left_max = ba[:, :w // 32, :].max(axis=(0, 1)).astype(int)
      right_max = ba[:, -w // 32:, :].max(axis=(0, 1)).astype(int)
      max_bg = np.maximum(left_max, right_max)

      foreground = (ba > max_bg + 10).astype(np.uint8)# 识别前景区域
      bbox = Image.fromarray(foreground).getbbox()# 获取前景区域的最小边界框

      # 如果边界框太小或不存在,则打印消息并可能设置为None
      if bbox is None:
            print(f'No bounding box found for {fname} (???)')
      else:
            left, upper, right, lower = bbox
            if right - left < 0.8 * h or lower - upper < 0.8 * h:
                print(f'Bounding box too small for {fname}')
                bbox = None
    else:
      bbox = None# 如果图像已经是合适的宽高比,则不尝试识别前景

    # 如果未找到有效的边界框,则使用正方形边界框
    if bbox is None:
      bbox = square_bbox(img)

    # 使用边界框裁剪图像,并调整大小
    cropped = img.crop(bbox)
    cropped = cropped.resize(, Image.ANTIALIAS)# 注意:ANTIALIAS可能是个拼写错误,应该是ANTIALIASIS
    save(cropped, tgt_path)# 保存图像

# 返回一个正方形裁剪框的边界
def square_bbox(img):
    w, h = img.size
    left = max((w - h) // 2, 0)
    upper = 0
    right = min(w - (w - h) // 2, w)
    lower = h
    return (left, upper, right, lower)

# 保存PIL图像到文件
def save(img, fname):
    img.save(fname, quality=100, subsampling=0)# 注意:subsampling参数可能不是所有格式都支持

# 假设的main函数,用于组织整个流程(注意:这里只是一个示例)
def main():
    # 示例任务列表,每个任务是一个(文件名, 目标路径, 裁剪大小)元组
    jobs = [
      ('input1.jpg', 'output1_resized.jpg', 256),
      ('input2.jpg', 'output2_resized.jpg', 256),
      # ... 更多任务
    ]
      
    # 假设有一个工作者ID为1
    convert_list(1, jobs)

if __name__ == "__main__":
    main() Encoder

main.py

# 定义主函数入口
def main():
    # 解析配置参数
    args = parse_configuration()
    # 加载配置文件
    cfg = load_config(args.config)
    # 获取配置中保存的路径
    save_path = cfg.config_base.config_save_path
    # 如果保存路径不存在,则创建该路径
    if not os.path.exists(save_path):
      os.makedirs(save_path)
    # 将配置文件复制到保存路径
    copy_config(args.config, cfg.config_base.config_save_path)
    # 执行工作函数
    worker(cfg)

# 定义工作函数,负责训练、验证和测试模型
def worker(cfg):
    # 根据配置生成模型
    model = generate_model(cfg)
    # 计算模型总参数数量
    total_param = 0
    for param in model.parameters():
      total_param += param.numel()
    print("Parameter: %.2fM" % (total_param / 1e6))# 打印模型参数数量(单位:百万)
    # 根据配置生成训练、验证和测试数据集
    train_dataset, test_dataset, val_dataset = generate_dataset(cfg)
    # 初始化性能评估器
    estimator = PerformanceEvaluator(cfg.config_train.config_criterion, cfg.config_data.config_num_classes)
    # 执行训练过程
    train(
      cfg=cfg,
      model=model,
      train_dataset=train_dataset,
      val_dataset=val_dataset,
      estimator=estimator,
    )

    # 测试最佳验证模型性能
    print('This is the performance of the best validation model:')
    checkpoint = os.path.join(cfg.config_base.config_save_path, 'best_validation_weights.pt')
    cfg.config_train.config_checkpoint = checkpoint# 设置检查点路径为最佳验证模型
    model = generate_model(cfg)# 重新生成模型以加载权重
    evaluate(cfg, model, test_dataset, estimator)# 评估模型性能

    # 测试最终模型性能
    print('This is the performance of the final model:')
    checkpoint = os.path.join(cfg.config_base.config_save_path, 'final_weights.pt')
    cfg.config_train.config_checkpoint = checkpoint# 设置检查点路径为最终模型
    model = generate_model(cfg)# 重新生成模型以加载权重
    evaluate(cfg, model, test_dataset, estimator)# 评估模型性能

# 如果此脚本作为主程序运行,则调用main函数
if __name__ == '__main__':
    main() Encoder_predict.py

举行模型的练习,详细来说,它定义了一个练习循环&#x

免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。
页: [1]
查看完整版本: 【论文10】复现代码tips