基于迁徙学习实现肺炎X光片诊断分类
各人好,我是带我去滑雪!肺炎是举世范围内致死率较高的疾病之一,尤其是在老年人、免疫体系较弱的患者群体中,更容易引发严肃并发症。传统上,肺炎的诊断依赖于医生的临床履历以及影像学查抄,尤其是X光片,它在肺炎的早期筛查和诊断中扮演了至关重要的角色。然而,X光片的读取不仅必要专业的放射科医生,而且受到履历和疲劳等因素的影响,导致诊断结果的正确性存在一定的毛病。近年来,人工智能(AI)技能,尤其是深度学习在医学影像领域取得了明显盼望。通过深度学习模型,盘算机能够高效地从大量影像数据中学习到复杂的模式,并实现对疾病的主动辨认和分类,极大地进步了诊断的速度和正确性。迁徙学习作为深度学习的一种重要方法,能够通过在已有的、大规模的医学图像数据上预练习模型,并迁徙到肺炎X光片的分类任务上,减少对大量标注数据的需求,这对资源有限、标注困难的医学领域尤为重要。
https://i-blog.csdnimg.cn/direct/98cf6291270b4ef7bd2aafff387f1723.png
基于迁徙学习的肺炎X光片诊断分类研究,不仅可以缓解医生在实际工作中因繁重工作负担导致的诊断错误问题,还能够通过高效、正确的主动化诊断方法,在早期筛查中提供帮助,尤其是在偏远地域或医疗资源匮乏的环境中,为患者提供及时的诊疗建议,极大地促进了医疗资源的公道分配。此外,该研究的乐成实现还可以为其他疾病的X光片图像诊断提供鉴戒,推动人工智能技能在医学领域的广泛应用。下面开始代码实战。
目次
(1)导入相干模块
(2)构建数据集
(3)加载练习的网络
(4)调解模型
(5)设置测试集加载参数
(1)导入相干模块
import os
from PIL import Image
from glob import glob
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.models import resnet50, ResNet50_Weights
(2)构建数据集
class ChestXRayDataset(Dataset):
def __init__(
self,
dataset_dir,
transform=None) -> None:
self.dataset_dir = dataset_dir
self.transform = transform
# 获取文件夹下所有图片路径
self.dataset_images = glob(f"{self.dataset_dir}/**/*.jpeg", recursive=True)
# 获取数据集大小
def __len__(self):
return len(self.dataset_images)
# 读取图像,获取类别
def __getitem__(self, idx):
image_path = self.dataset_images
image_name = os.path.basename(image_path)
image = Image.open(image_path)
if "NORMAL" in image_name:
category = 0
else:
category = 1
if self.transform:
image = self.transform(image)
return image, category
(3)加载练习的网络
def prepare_model():
# 加载预训练的模型
resnet50_weight = ResNet50_Weights.DEFAULT
resnet50_mdl = resnet50(weights=resnet50_weight)
# 替换模型最后的全连接层
num_ftrs = resnet50_mdl.fc.in_features
resnet50_mdl.fc = nn.Linear(num_ftrs, 2)
return resnet50_mdl
def train_model():
# 确定使用CPU还是GPU
if torch.cuda.is_available():
device = "cuda:0"
else:
device = "cpu"
# 加载模型
model = prepare_model()
model = model.to(device)
model.train()
# 设置loss函数和optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
# 设置训练集数据加载相关变量
batch_size = 32
chest_xray = r"E:\工作\硕士\博客\博客99-深度学习医学特征提取\deeplea test\deeplea test\archive\chest_xray"
train_dataset_dir = os.path.join(chest_xray, "train")
train_transforms = transforms.Compose([
transforms.ToTensor(),
transforms.Lambda(lambda x: x.repeat(3, 1, 1) if x.size(0) == 1 else x),
transforms.Resize((224, 224)),
transforms.RandomHorizontalFlip(),
transforms.RandomVerticalFlip(),
transforms.Normalize(, )
])
train_dataset = ChestXRayDataset(train_dataset_dir, train_transforms)
train_dataloader = DataLoader(
train_dataset,
batch_size=batch_size,
shuffle=True) (4)调解模型
for epoch in range(5):
print_batch = 50
running_loss = 0
running_corrects = 0
for i, data in enumerate(train_dataloader):
inputs, labels = data
inputs = inputs.to(device)
labels = labels.to(device)
optimizer.zero_grad()
outputs = model(inputs)
_, preds = torch.max(outputs, 1)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += (loss.item() * batch_size)
running_corrects += torch.sum(preds == labels.data)
if i % print_batch == (print_batch - 1):# print every 100 mini-batches
accuracy = running_corrects / (print_batch * batch_size)
print(
f'Epoch: {epoch + 1}, Batch: {i + 1:5d} Running Loss: {running_loss / 50:.3f} Accuracy: {accuracy:.3f}')
running_loss = 0.0
running_corrects = 0
checkpoint_name = f"epoch_{epoch}.pth"
torch.save(model.state_dict(), checkpoint_name)
def test_model():
if torch.cuda.is_available():
device = "cuda:0"
else:
device = "cpu"
# 加载模型
checkpoint_name = "epoch_4.pth"
model = prepare_model()
model.load_state_dict(torch.load(checkpoint_name, map_location=device))
model = model.to(device)
model.eval()
(5)设置测试集加载参数
batch_size = 32
chest_xray = r"E:\工作\硕士\博客\博客99-深度学习医学特征提取\deeplea test\deeplea test\archive\chest_xray"
test_dataset_dir = os.path.join(chest_xray, "test")
test_transforms = transforms.Compose([
transforms.ToTensor(),
transforms.Lambda(lambda x: x.repeat(3, 1, 1) if x.size(0) == 1 else x),
transforms.Resize((224, 224)),
transforms.Normalize(, )
])
test_dataset = ChestXRayDataset(test_dataset_dir, test_transforms)
test_dataloader = DataLoader(
test_dataset,
batch_size=batch_size,
shuffle=False)
# 在测试集测试模型
with torch.no_grad():
preds_list = []
labels_list = []
for i, data in enumerate(test_dataloader):
inputs, labels = data
inputs = inputs.to(device)
labels = labels.to(device)
outputs = model(inputs)
_, preds = torch.max(outputs, 1)
preds_list.append(preds)
labels_list.append(labels)
preds = torch.cat(preds_list)
labels = torch.cat(labels_list)
# 计算评价指标
corrects_num = torch.sum(preds == labels.data)
accuracy = corrects_num / labels.shape
# 输出评价指标
print(f"Accuracy on test dataset: {accuracy:.2%}")
if __name__ == "__main__":
train_model()
test_model()
输出结果:
https://i-blog.csdnimg.cn/direct/b249cbb0ea3d481586f58e2d1b701a21.png
更多优质内容一连发布中,请移步主页查看。
若有问题可邮箱接洽:1736732074@qq.com
博主的WeChat:TCB1736732074
点赞+关注,下次不迷路!
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。
页:
[1]