【论文10】复现代码tips

打印 上一主题 下一主题

主题 917|帖子 917|积分 2751

一、预备工作

1.创建一个虚拟环境

  1. conda create --name drgcnn38 python=3.8.18
复制代码

2.激活虚拟环境

  1. conda activate drgcnn38
复制代码

注意事项

   在Pycharm中终端(terminal)显示PS而不是虚拟环境base
  问题如下所示
  

  解决方法:shell路径改成cmd.exe
  

  重启终端显示虚拟环境
  

  3.安装torch

  1. conda install pytorch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 cpuonly -c pytorch
复制代码
安装一系列包
注意事项

   Pycharm远程毗连Linux服务器实现代码同步
  1.工具-->部署-->设置
  

  2.选择SFTP远程毗连,路径填与服务器要同步的路径地点
  

  二、代码学习

各部门的作用



  • eye_pre_process:视网膜眼底图像预处理惩罚模块。
  • Encoder:编码器练习模块。
  • modules:包罗模型结构、丧失函数和学习率低落计谋。
  • utils:包罗一些常用函数和评估指标。
  • BFFN:双眼特征融合网络练习模块。
  • CAM:种别注意力模块。
eye_pre_process

copy.py

  1. # 创建一个ArgumentParser对象,用于处理命令行参数  
  2. parser = argparse.ArgumentParser()  
  3.   
  4. # 添加一个命令行参数 '--image-folder',类型为字符串,默认值为 'D:/cv_paper/lesson/Dataset/ceshi'  
  5. # 这个参数用于指定输入图像的文件夹路径  
  6. parser.add_argument('--image-folder', type=str, default=r'D:/cv_paper/lesson/Dataset/ceshi')  
  7.   
  8. # 添加一个命令行参数 '--output-folder',类型为字符串,默认值为 'D:\cv_paper\lesson/Dataset/ceshi_output'  
  9. # 注意:这里路径中的反斜杠在不同的操作系统中可能需要特别注意,Python字符串中推荐使用原始字符串(r前缀)来避免转义字符的问题  
  10. # 这个参数用于指定输出结果的文件夹路径  
  11. parser.add_argument('--output-folder', type=str, default=r'D:\cv_paper\lesson/Dataset/ceshi_output')  
  12.   
  13. # 添加一个命令行参数 '--crop-size',类型为整数,默认值为512  
  14. # 这个参数用于指定图像裁剪的大小  
  15. parser.add_argument('--crop-size', type=int, default=512, help='crop size of image')  
  16.   
  17. # 添加一个命令行参数 '-n' 或 '--num-processes',类型为整数,默认值为8  
  18. # 这个参数用于指定处理任务时要使用的进程数  
  19. # '-n' 是 '--num-processes' 的简写形式,帮助信息说明了该参数的作用  
  20. parser.add_argument('-n', '--num-processes', type=int, default=8, help='number of processes to use')
复制代码
  1. # 转换一个包含多个任务的列表,每个任务由文件名、目标路径和裁剪大小组成  
  2. # 对于jobs列表中的每个任务(索引为j),它首先检查是否已经处理了100个任务(作为进度指示),然后调用convert函数来执行实际的图像转换。
  3. def convert_list(i, jobs):  
  4.     for j, job in enumerate(jobs):  
  5.         # 每处理100个任务打印一次进度  
  6.         if j % 100 == 0:  
  7.             print(f'worker{i} has finished {j} tasks.')  
  8.         # 解包任务元组并调用convert函数  
  9.         convert(*job)  
  10.   
  11. # 转换单个图像文件,包括模糊处理、裁剪和保存  
  12. def convert(fname, tgt_path, crop_size):  
  13.     img = Image.open(fname)  # 打开图像文件  
  14.   
  15.     blurred = img.filter(ImageFilter.BLUR)  # 应用模糊滤镜  
  16.     ba = np.array(blurred)  # 将图像转换为NumPy数组  
  17.     h, w, _ = ba.shape  # 获取图像的高度、宽度和通道数  
  18.   
  19.     # 尝试根据图像的亮度分布来识别前景区域  
  20.     if w > 1.2 * h:  
  21.         # 计算左右两侧的最大亮度值  
  22.         left_max = ba[:, :w // 32, :].max(axis=(0, 1)).astype(int)  
  23.         right_max = ba[:, -w // 32:, :].max(axis=(0, 1)).astype(int)  
  24.         max_bg = np.maximum(left_max, right_max)  
  25.   
  26.         foreground = (ba > max_bg + 10).astype(np.uint8)  # 识别前景区域  
  27.         bbox = Image.fromarray(foreground).getbbox()  # 获取前景区域的最小边界框  
  28.   
  29.         # 如果边界框太小或不存在,则打印消息并可能设置为None  
  30.         if bbox is None:  
  31.             print(f'No bounding box found for {fname} (???)')  
  32.         else:  
  33.             left, upper, right, lower = bbox  
  34.             if right - left < 0.8 * h or lower - upper < 0.8 * h:  
  35.                 print(f'Bounding box too small for {fname}')  
  36.                 bbox = None  
  37.     else:  
  38.         bbox = None  # 如果图像已经是合适的宽高比,则不尝试识别前景  
  39.   
  40.     # 如果未找到有效的边界框,则使用正方形边界框  
  41.     if bbox is None:  
  42.         bbox = square_bbox(img)  
  43.   
  44.     # 使用边界框裁剪图像,并调整大小  
  45.     cropped = img.crop(bbox)  
  46.     cropped = cropped.resize([crop_size, crop_size], Image.ANTIALIAS)  # 注意:ANTIALIAS可能是个拼写错误,应该是ANTIALIASIS  
  47.     save(cropped, tgt_path)  # 保存图像  
  48.   
  49. # 返回一个正方形裁剪框的边界  
  50. def square_bbox(img):  
  51.     w, h = img.size  
  52.     left = max((w - h) // 2, 0)  
  53.     upper = 0  
  54.     right = min(w - (w - h) // 2, w)  
  55.     lower = h  
  56.     return (left, upper, right, lower)  
  57.   
  58. # 保存PIL图像到文件  
  59. def save(img, fname):  
  60.     img.save(fname, quality=100, subsampling=0)  # 注意:subsampling参数可能不是所有格式都支持  
  61.   
  62. # 假设的main函数,用于组织整个流程(注意:这里只是一个示例)  
  63. def main():  
  64.     # 示例任务列表,每个任务是一个(文件名, 目标路径, 裁剪大小)元组  
  65.     jobs = [  
  66.         ('input1.jpg', 'output1_resized.jpg', 256),  
  67.         ('input2.jpg', 'output2_resized.jpg', 256),  
  68.         # ... 更多任务  
  69.     ]  
  70.       
  71.     # 假设有一个工作者ID为1  
  72.     convert_list(1, jobs)  
  73.   
  74. if __name__ == "__main__":  
  75.     main()  
复制代码
Encoder

main.py

  1. # 定义主函数入口  
  2. def main():  
  3.     # 解析配置参数  
  4.     args = parse_configuration()  
  5.     # 加载配置文件  
  6.     cfg = load_config(args.config)  
  7.     # 获取配置中保存的路径  
  8.     save_path = cfg.config_base.config_save_path  
  9.     # 如果保存路径不存在,则创建该路径  
  10.     if not os.path.exists(save_path):  
  11.         os.makedirs(save_path)  
  12.     # 将配置文件复制到保存路径  
  13.     copy_config(args.config, cfg.config_base.config_save_path)  
  14.     # 执行工作函数  
  15.     worker(cfg)  
  16.   
  17. # 定义工作函数,负责训练、验证和测试模型  
  18. def worker(cfg):  
  19.     # 根据配置生成模型  
  20.     model = generate_model(cfg)  
  21.     # 计算模型总参数数量  
  22.     total_param = 0  
  23.     for param in model.parameters():  
  24.         total_param += param.numel()  
  25.     print("Parameter: %.2fM" % (total_param / 1e6))  # 打印模型参数数量(单位:百万)  
  26.     # 根据配置生成训练、验证和测试数据集  
  27.     train_dataset, test_dataset, val_dataset = generate_dataset(cfg)  
  28.     # 初始化性能评估器  
  29.     estimator = PerformanceEvaluator(cfg.config_train.config_criterion, cfg.config_data.config_num_classes)  
  30.     # 执行训练过程  
  31.     train(  
  32.         cfg=cfg,  
  33.         model=model,  
  34.         train_dataset=train_dataset,  
  35.         val_dataset=val_dataset,  
  36.         estimator=estimator,  
  37.     )  
  38.   
  39.     # 测试最佳验证模型性能  
  40.     print('This is the performance of the best validation model:')  
  41.     checkpoint = os.path.join(cfg.config_base.config_save_path, 'best_validation_weights.pt')  
  42.     cfg.config_train.config_checkpoint = checkpoint  # 设置检查点路径为最佳验证模型  
  43.     model = generate_model(cfg)  # 重新生成模型以加载权重  
  44.     evaluate(cfg, model, test_dataset, estimator)  # 评估模型性能  
  45.   
  46.     # 测试最终模型性能  
  47.     print('This is the performance of the final model:')  
  48.     checkpoint = os.path.join(cfg.config_base.config_save_path, 'final_weights.pt')  
  49.     cfg.config_train.config_checkpoint = checkpoint  # 设置检查点路径为最终模型  
  50.     model = generate_model(cfg)  # 重新生成模型以加载权重  
  51.     evaluate(cfg, model, test_dataset, estimator)  # 评估模型性能  
  52.   
  53. # 如果此脚本作为主程序运行,则调用main函数  
  54. if __name__ == '__main__':  
  55.     main()
复制代码
Encoder_predict.py

举行模型的练习,详细来说,它定义了一个练习循环&#x

免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。
回复

使用道具 举报

0 个回复

倒序浏览

快速回复

您需要登录后才可以回帖 登录 or 立即注册

本版积分规则

道家人

金牌会员
这个人很懒什么都没写!

标签云

快速回复 返回顶部 返回列表