3DUnet实现3D医学影像的有效分割

宁睿  论坛元老 | 2024-8-23 14:21:20 | 显示全部楼层 | 阅读模式
打印 上一主题 下一主题

主题 1082|帖子 1082|积分 3256

最近涉及到了3D医学影像的分割,网络上相关的实现比力少,因此举行实现记录。

  
1.设置代码环境

这里介绍一个很好的开源项目,git为: https://github.com/ellisdg/3DUnetCNN.git。
安装环境为:
  1. nibabel>=4.0.1
  2. numpy>=1.23.0
  3. #torch>=1.12.0
  4. monai>=1.2.0
  5. scipy>=1.9.0
  6. pandas>=1.4.3
  7. nilearn>=0.9.1
  8. pillow>=9.3.0
复制代码
这里以Conda为例,很慢的话,可以-i 清湖镜像源:
  1. conda create -n 3DUnet python=3.8
复制代码
  1. conda activate 3DUnet
复制代码
  1. git clone https://github.com/ellisdg/3DUnetCNN.git
  2. cd 3DUnetCNN
  3. pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
复制代码
2.设置数据集以及模型文件

这里以 examples/brats2020/brats2020_config.json的json设置文件为例。

在同级文件夹下创建我们的任务的设置文件。随后,对json文件中比力重要的参数举行说明

json文件中,in_channels表现模型的输入通道,out_channels表现模型的输出通道数。

dataset字典中,desired_shape就是颠末monai数据预处理处罚库处理处罚后的图片的WxHxC同一到128x128x128。labels就是数据集有几类,假如只有1类,那就只有0(背景)、1(前景之分)。该项目会将label转换成one hot编码。

接下来,比力重要的参数就是training_filenames,其中分为image和label,也就是图像以及其对应的标签。

bratsvalidation_filenames就是测试集,只存图片image。

由于这是参考的 examples/brats2020/brats2020_config.json的大脑分割json设置文件,这里的数据集路径要更换为我们的,这里这里附上更改我们自己的数据集的脚本代码。
  1. import json
  2. import os.path
  3. '''将自己的数据集进行划分并添加到json配置文件'''
  4. #原json文件路径
  5. filename = r'D:\jiedan\3DUnetCNN\examples\tooth_me\tooth_me_config.json'
  6. #自己数据集的图片路径
  7. my_data_dir = r'D:\jiedan\tooth_segmentation\image'
  8. #自己数据集的label路径
  9. my_data_label_dir = r'D:\jiedan\tooth_segmentation\label_32_pre'
  10. #进行数据集划分
  11. my_data_file = os.listdir(my_data_dir)
  12. train_num, val_num = int(len(my_data_file) * 0.8), int(len(my_data_file) - len(my_data_file) * 0.8)
  13. train_data_file = my_data_file[:train_num]
  14. val_data_file = my_data_file[train_num:]
  15. with open(filename, 'r') as opened_file:
  16.     data = json.load(opened_file)
  17.     #这里因为读取的所参考的examples/brats2020/brats2020_config.json
  18.     #该数据集的图片数远远大于我们自己的数据集,所以只要截取到和我们的数据集一致的长度就行
  19.     train_file = data["training_filenames"][:train_num]
  20.     val_file = data["bratsvalidation_filenames"][:val_num]
  21.     for index, file in enumerate(train_file[:train_num]):
  22.         file["image"] = os.path.join(my_data_dir, train_data_file[index])
  23.         file["label"] = os.path.join(my_data_label_dir, train_data_file[index].replace('.nii.gz', '.nii'))
  24.     for index_v, j in enumerate(val_file[:val_num]):
  25.         images_val = j['image']
  26.         j['image'] = os.path.join(my_data_dir, val_data_file[index_v])
  27. #进行数据集的路径字典更新
  28. data["training_filenames"] = train_file
  29. data["bratsvalidation_filenames"] = val_file
  30. with open(filename, 'w') as opened_file:
  31.     json.dump(data, opened_file, indent=4)  # 使用indent参数格式化保存的JSON数据,以便更易于阅读
复制代码
3.练习

下面是 练习的脚本。
  1. python unet3d/scripts/train.py --config_filename
  2. ./examples/tooth_me/tooth_me_config.json
复制代码
<config_filename>指向我们刚才处理处罚好的我们自己的数据集以及模型的json文件。
4.预测

下面是 预测的脚本。
  1. python unet3d/scripts/train.py --config_filename
  2. ./examples/tooth_me/tooth_me_config.json
复制代码
<config_filename>指向我们刚才处理处罚好的我们自己的数据集以及模型的json文件。
由于该git项目预测仅仅只是通过利用练习好的权重初始化的模型来输出预测图像,格式与输入图像划一,为nii.gz。

这个预测脚本 predict.py是没有衡量指标的计算的,比如Dice分数。
我们参考monai的官方文档的class monai.metrics.CumulativeIterationMetric类。
下面为官方文档利用说明:
  1. dice_metric = DiceMetric(include_background=True, reduction="mean")
  2. for val_data in val_loader:
  3.     val_outputs = model(val_data["img"])
  4.     val_outputs = [postprocessing_transform(i) for i in decollate_batch(val_outputs)]
  5.     # compute metric for current iteration
  6.     dice_metric(y_pred=val_outputs, y=val_data["seg"])  # callable to add metric to the buffer
  7. # aggregate the final mean dice result
  8. metric = dice_metric.aggregate().item()
  9. # reset the status for next computation round
  10. dice_metric.reset()
复制代码
我们,首先定位到unet3d/scripts/predict.py,定位到 unet3d/predict/volumetric.py文件的volumetric_predictions函数。
  1. def volumetric_predictions(model, dataloader, prediction_dir, activation=None, resample=False,
  2.                            interpolation="trilinear", inferer=None):
  3.     output_filenames = list()
  4.     writer = NibabelWriter()
  5.     # 使用DiceMetric实例化metric对象
  6.     dice_metric = DiceMetric(include_background=True, reduction="mean")
  7.     ......
  8.     with torch.no_grad():
  9.         for idx, item in enumerate(dataloader):
  10.             x = item["image"]
  11.             x = x.to(next(model.parameters()).device)  # Set the input to the same device as the model parameters
  12.             .....
  13.             predictions = model(x)
  14.             batch_size = x.shape[0]
  15.             for batch_idx in range(batch_size):
  16.                 _prediction = predictions[batch_idx]
  17.                 _x = x[batch_idx]
  18.                 if resample:
  19.                     _x = loader(os.path.abspath(_x.meta["filename_or_obj"]))
  20.                     #在这里加上读取label的代码并转移到对应的device上
  21.                     _label = loader(os.path.abspath(_x.meta["filename_or_obj"]).replace('image', 'label_32_pre').replace('nii.gz', 'nii'))
  22.                     _label = _label.to(next(model.parameters()).device)  # Set the input to the same device as the model parameters
  23.                     _prediction = resampler(_prediction, _x)
  24.                     #将模型预测的输出与加代码读取的label送进去
  25.                     # compute metric for current iteration
  26.                     dice_metric(y_pred=_prediction, y=_label)  # callable to add metric to the buffe
  27.                 writer.set_data_array(_prediction)
  28.                 writer.set_metadata(_x.meta, resample=False)
  29.                 out_filename = os.path.join(prediction_dir,
  30.                                             os.path.basename(_x.meta["filename_or_obj"]).split(".")[0] + ".nii.gz")
  31.                 writer.write(out_filename, verbose=True)
  32.                 output_filenames.append(out_filename)
  33.         #最后求平均得到最终的Dice分数
  34.         # aggregate the final mean dice result
  35.         metric = dice_metric.aggregate().item()
  36.     return output_filenames
复制代码
还有许多衡量的评价指标,可以参考monai的官方文档:
  1. https://docs.monai.io/en/stable/metrics.html
复制代码
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。

本帖子中包含更多资源

您需要 登录 才可以下载或查看,没有账号?立即注册

x
回复

使用道具 举报

0 个回复

倒序浏览

快速回复

您需要登录后才可以回帖 登录 or 立即注册

本版积分规则

宁睿

论坛元老
这个人很懒什么都没写!
快速回复 返回顶部 返回列表