MindSpore 25-Day Study Check-in


FCN for Image Semantic Segmentation

FCN (Fully Convolutional Network) discards the traditional fully connected layers and builds an end-to-end, pixel-level prediction network out of convolution, pooling, and similar operations alone.
Network characteristics:

  • A fully convolutional (fully conv) network with no fully connected (fc) layers, so it accepts inputs of arbitrary size.
  • Deconvolution (deconv) layers that enlarge the feature maps, producing fine-grained outputs.
  • Skip connections that fuse results from layers at different depths, balancing robustness and accuracy.
Semantic Segmentation

Image semantic segmentation is a key part of image understanding in image processing and computer vision, and an important branch of AI. It is widely applied in face recognition, object detection, medical imaging, satellite image analysis, autonomous-driving perception, and other fields.
In short: object detection draws a box around the target, while semantic segmentation cuts the target out of the image pixel by pixel.
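To make the distinction concrete, here is a minimal NumPy sketch (toy data, not part of the tutorial) showing what a segmentation output looks like: a per-class score map reduced to a per-pixel label mask by argmax over the class axis.

```python
import numpy as np

# Toy score map: 3 classes over a 4x4 image, laid out (C, H, W).
scores = np.zeros((3, 4, 4))
scores[0] = 1.0              # background scores everywhere
scores[1, 1:3, 1:3] = 2.0    # class 1 wins in the center 2x2 patch

# Per-pixel argmax over the class axis yields the segmentation mask.
mask = scores.argmax(axis=0)
print(mask)  # center 2x2 pixels are 1, the rest 0
```

This is exactly the reduction applied to the FCN output later in the inference step (`argmax(axis=1)` there, because a batch axis comes first).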
1. Data Preprocessing

First, download the dataset:
from download import download

url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/dataset_fcn8s.tar"
download(url, "./dataset", kind="tar", replace=True)
Data preprocessing:
import numpy as np
import cv2
import mindspore.dataset as ds

class SegDataset:
    def __init__(self,
                 image_mean,
                 image_std,
                 data_file='',
                 batch_size=32,
                 crop_size=512,
                 max_scale=2.0,
                 min_scale=0.5,
                 ignore_label=255,
                 num_classes=21,
                 num_readers=2,
                 num_parallel_calls=4):
        # Dataset parameters: file path, batch size, crop size, scale range,
        # ignore label, number of classes, reader threads, and parallel-call count.
        self.data_file = data_file
        self.batch_size = batch_size
        self.crop_size = crop_size
        self.image_mean = np.array(image_mean, dtype=np.float32)
        self.image_std = np.array(image_std, dtype=np.float32)
        self.max_scale = max_scale
        self.min_scale = min_scale
        self.ignore_label = ignore_label
        self.num_classes = num_classes
        self.num_readers = num_readers
        self.num_parallel_calls = num_parallel_calls
        assert max_scale > min_scale  # the original bare comparison here had no effect

    def preprocess_dataset(self, image, label):
        # Randomly rescale image and label, normalize the image (subtract mean,
        # divide by std), pad/crop to crop_size, randomly flip, and transpose
        # the image to CHW for the network.
        image_out = cv2.imdecode(np.frombuffer(image, dtype=np.uint8), cv2.IMREAD_COLOR)
        label_out = cv2.imdecode(np.frombuffer(label, dtype=np.uint8), cv2.IMREAD_GRAYSCALE)
        sc = np.random.uniform(self.min_scale, self.max_scale)
        new_h, new_w = int(sc * image_out.shape[0]), int(sc * image_out.shape[1])
        image_out = cv2.resize(image_out, (new_w, new_h), interpolation=cv2.INTER_CUBIC)
        label_out = cv2.resize(label_out, (new_w, new_h), interpolation=cv2.INTER_NEAREST)
        image_out = (image_out - self.image_mean) / self.image_std
        out_h, out_w = max(new_h, self.crop_size), max(new_w, self.crop_size)
        pad_h, pad_w = out_h - new_h, out_w - new_w
        if pad_h > 0 or pad_w > 0:
            image_out = cv2.copyMakeBorder(image_out, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, value=0)
            label_out = cv2.copyMakeBorder(label_out, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, value=self.ignore_label)
        offset_h = np.random.randint(0, out_h - self.crop_size + 1)
        offset_w = np.random.randint(0, out_w - self.crop_size + 1)
        image_out = image_out[offset_h: offset_h + self.crop_size, offset_w: offset_w + self.crop_size, :]
        label_out = label_out[offset_h: offset_h + self.crop_size, offset_w: offset_w + self.crop_size]
        if np.random.uniform(0.0, 1.0) > 0.5:
            image_out = image_out[:, ::-1, :]
            label_out = label_out[:, ::-1]
        image_out = image_out.transpose((2, 0, 1))
        image_out = image_out.copy()
        label_out = label_out.copy()
        label_out = label_out.astype("int32")
        return image_out, label_out

    def get_dataset(self):
        # Read the "data" and "label" columns with MindDataset, map
        # preprocess_dataset over them, then shuffle and batch.
        ds.config.set_numa_enable(True)
        dataset = ds.MindDataset(self.data_file, columns_list=["data", "label"],
                                 shuffle=True, num_parallel_workers=self.num_readers)
        transforms_list = self.preprocess_dataset
        dataset = dataset.map(operations=transforms_list, input_columns=["data", "label"],
                              output_columns=["data", "label"],
                              num_parallel_workers=self.num_parallel_calls)
        dataset = dataset.shuffle(buffer_size=self.batch_size * 10)
        dataset = dataset.batch(self.batch_size, drop_remainder=True)
        return dataset

# Dataset creation parameters
IMAGE_MEAN = [103.53, 116.28, 123.675]
IMAGE_STD = [57.375, 57.120, 58.395]
DATA_FILE = "dataset/dataset_fcn8s/mindname.mindrecord"
# Model training parameters
train_batch_size = 4
crop_size = 512
min_scale = 0.5
max_scale = 2.0
ignore_label = 255
num_classes = 21
# Instantiate the dataset
dataset = SegDataset(image_mean=IMAGE_MEAN,
                     image_std=IMAGE_STD,
                     data_file=DATA_FILE,
                     batch_size=train_batch_size,
                     crop_size=crop_size,
                     max_scale=max_scale,
                     min_scale=min_scale,
                     ignore_label=ignore_label,
                     num_classes=num_classes,
                     num_readers=2,
                     num_parallel_calls=4)
dataset = dataset.get_dataset()
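The pad-and-crop logic above can be exercised in isolation with plain NumPy (a simplified sketch: `np.pad` stands in for `cv2.copyMakeBorder`, and the toy image is deliberately smaller than the crop size so the padding branch fires):

```python
import numpy as np

crop_size, ignore_label = 8, 255
rng = np.random.default_rng(0)

# A tiny "image" and "label" after random rescaling (both smaller than crop_size).
image = np.ones((5, 6, 3), dtype=np.float32)
label = np.ones((5, 6), dtype=np.uint8)

# Pad bottom/right up to crop_size; labels are padded with ignore_label
# so the padded pixels never contribute to the loss.
pad_h = max(crop_size - image.shape[0], 0)
pad_w = max(crop_size - image.shape[1], 0)
image = np.pad(image, ((0, pad_h), (0, pad_w), (0, 0)), constant_values=0)
label = np.pad(label, ((0, pad_h), (0, pad_w)), constant_values=ignore_label)

# Random crop (here the padded size equals crop_size, so both offsets are 0).
oh = int(rng.integers(0, image.shape[0] - crop_size + 1))
ow = int(rng.integers(0, image.shape[1] - crop_size + 1))
image = image[oh:oh + crop_size, ow:ow + crop_size, :]
label = label[oh:oh + crop_size, ow:ow + crop_size]

print(image.shape, label.shape)       # (8, 8, 3) (8, 8)
print((label == ignore_label).sum())  # 34 padded pixels carry ignore_label
```

The 5x6 = 30 original label pixels survive; the other 64 - 30 = 34 pixels of the 8x8 crop are padding marked with `ignore_label`, matching `ignore_index=255` in the loss used later.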
2. Building the Network

import mindspore.nn as nn

class FCN8s(nn.Cell):
    def __init__(self, n_class):
        super().__init__()
        self.n_class = n_class
        # Encoder: VGG-16-style convolution blocks, each followed by 2x max pooling
        self.conv1 = nn.SequentialCell(
            nn.Conv2d(in_channels=3, out_channels=64,
                      kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=64,
                      kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(64),
            nn.ReLU()
        )
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.SequentialCell(
            nn.Conv2d(in_channels=64, out_channels=128,
                      kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(in_channels=128, out_channels=128,
                      kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(128),
            nn.ReLU()
        )
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv3 = nn.SequentialCell(
            nn.Conv2d(in_channels=128, out_channels=256,
                      kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(in_channels=256, out_channels=256,
                      kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(in_channels=256, out_channels=256,
                      kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(256),
            nn.ReLU()
        )
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv4 = nn.SequentialCell(
            nn.Conv2d(in_channels=256, out_channels=512,
                      kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(in_channels=512, out_channels=512,
                      kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(in_channels=512, out_channels=512,
                      kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(512),
            nn.ReLU()
        )
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv5 = nn.SequentialCell(
            nn.Conv2d(in_channels=512, out_channels=512,
                      kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(in_channels=512, out_channels=512,
                      kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(in_channels=512, out_channels=512,
                      kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(512),
            nn.ReLU()
        )
        self.pool5 = nn.MaxPool2d(kernel_size=2, stride=2)
        # conv6/conv7 replace VGG's fully connected layers with convolutions
        self.conv6 = nn.SequentialCell(
            nn.Conv2d(in_channels=512, out_channels=4096,
                      kernel_size=7, weight_init='xavier_uniform'),
            nn.BatchNorm2d(4096),
            nn.ReLU(),
        )
        self.conv7 = nn.SequentialCell(
            nn.Conv2d(in_channels=4096, out_channels=4096,
                      kernel_size=1, weight_init='xavier_uniform'),
            nn.BatchNorm2d(4096),
            nn.ReLU(),
        )
        # Score layers (1x1 convs) and transposed convolutions for the skip fusion
        self.score_fr = nn.Conv2d(in_channels=4096, out_channels=self.n_class,
                                  kernel_size=1, weight_init='xavier_uniform')
        self.upscore2 = nn.Conv2dTranspose(in_channels=self.n_class, out_channels=self.n_class,
                                           kernel_size=4, stride=2, weight_init='xavier_uniform')
        self.score_pool4 = nn.Conv2d(in_channels=512, out_channels=self.n_class,
                                     kernel_size=1, weight_init='xavier_uniform')
        self.upscore_pool4 = nn.Conv2dTranspose(in_channels=self.n_class, out_channels=self.n_class,
                                                kernel_size=4, stride=2, weight_init='xavier_uniform')
        self.score_pool3 = nn.Conv2d(in_channels=256, out_channels=self.n_class,
                                     kernel_size=1, weight_init='xavier_uniform')
        self.upscore8 = nn.Conv2dTranspose(in_channels=self.n_class, out_channels=self.n_class,
                                           kernel_size=16, stride=8, weight_init='xavier_uniform')

    def construct(self, x):
        x1 = self.conv1(x)
        p1 = self.pool1(x1)
        x2 = self.conv2(p1)
        p2 = self.pool2(x2)
        x3 = self.conv3(p2)
        p3 = self.pool3(x3)
        x4 = self.conv4(p3)
        p4 = self.pool4(x4)
        x5 = self.conv5(p4)
        p5 = self.pool5(x5)
        x6 = self.conv6(p5)
        x7 = self.conv7(x6)
        sf = self.score_fr(x7)
        u2 = self.upscore2(sf)        # 2x upsample of the stride-32 scores
        s4 = self.score_pool4(p4)
        f4 = s4 + u2                  # fuse with pool4 (stride 16)
        u4 = self.upscore_pool4(f4)   # 2x upsample
        s3 = self.score_pool3(p3)
        f3 = s3 + u4                  # fuse with pool3 (stride 8)
        out = self.upscore8(f3)       # 8x upsample back to input resolution
        return out
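A quick sanity check on the spatial sizes (pure Python arithmetic; it assumes MindSpore's default `pad_mode='same'`, under which the 3x3/7x7 convolutions preserve the size, each 2x2 pooling halves it, and a transposed convolution with stride s multiplies it by s):

```python
# Spatial sizes through FCN-8s for a 512x512 input.
size = 512
pools = []
for _ in range(5):        # pool1..pool5 each halve the resolution
    size //= 2
    pools.append(size)
print(pools)              # [256, 128, 64, 32, 16]

u2 = pools[4] * 2         # upscore2:       16 -> 32, matches pool4's output
u4 = u2 * 2               # upscore_pool4:  32 -> 64, matches pool3's output
out = u4 * 8              # upscore8:       64 -> 512, back to input size
print(u2, u4, out)
```

This is why the element-wise additions `s4 + u2` and `s3 + u4` in `construct` are shape-compatible, and why the final output aligns pixel-for-pixel with the input.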
3. Training Preparation

Download the pretrained VGG-16 model:
from download import download
from mindspore import load_checkpoint, load_param_into_net

url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/fcn8s_vgg16_pretrain.ckpt"
download(url, "fcn8s_vgg16_pretrain.ckpt", replace=True)

def load_vgg16():
    # Loads the pretrained VGG-16 weights into the global `net` defined below.
    ckpt_vgg16 = "fcn8s_vgg16_pretrain.ckpt"
    param_vgg = load_checkpoint(ckpt_vgg16)
    load_param_into_net(net, param_vgg)
Define the evaluation metrics:
import numpy as np
import mindspore as ms
import mindspore.nn as nn
import mindspore.train as train

# Note: each update() below hard-codes the training batch size (4) and crop size (512).
class PixelAccuracy(train.Metric):
    """Overall pixel accuracy: correctly classified pixels / all pixels."""
    def __init__(self, num_class=21):
        super(PixelAccuracy, self).__init__()
        self.num_class = num_class

    def _generate_matrix(self, gt_image, pre_image):
        mask = (gt_image >= 0) & (gt_image < self.num_class)
        label = self.num_class * gt_image[mask].astype('int') + pre_image[mask]
        count = np.bincount(label, minlength=self.num_class**2)
        confusion_matrix = count.reshape(self.num_class, self.num_class)
        return confusion_matrix

    def clear(self):
        self.confusion_matrix = np.zeros((self.num_class,) * 2)

    def update(self, *inputs):
        y_pred = inputs[0].asnumpy().argmax(axis=1)
        y = inputs[1].asnumpy().reshape(4, 512, 512)
        self.confusion_matrix += self._generate_matrix(y, y_pred)

    def eval(self):
        pixel_accuracy = np.diag(self.confusion_matrix).sum() / self.confusion_matrix.sum()
        return pixel_accuracy

class PixelAccuracyClass(train.Metric):
    """Mean per-class pixel accuracy."""
    def __init__(self, num_class=21):
        super(PixelAccuracyClass, self).__init__()
        self.num_class = num_class

    def _generate_matrix(self, gt_image, pre_image):
        mask = (gt_image >= 0) & (gt_image < self.num_class)
        label = self.num_class * gt_image[mask].astype('int') + pre_image[mask]
        count = np.bincount(label, minlength=self.num_class**2)
        confusion_matrix = count.reshape(self.num_class, self.num_class)
        return confusion_matrix

    def update(self, *inputs):
        y_pred = inputs[0].asnumpy().argmax(axis=1)
        y = inputs[1].asnumpy().reshape(4, 512, 512)
        self.confusion_matrix += self._generate_matrix(y, y_pred)

    def clear(self):
        self.confusion_matrix = np.zeros((self.num_class,) * 2)

    def eval(self):
        mean_pixel_accuracy = np.diag(self.confusion_matrix) / self.confusion_matrix.sum(axis=1)
        mean_pixel_accuracy = np.nanmean(mean_pixel_accuracy)
        return mean_pixel_accuracy

class MeanIntersectionOverUnion(train.Metric):
    """Mean IoU: intersection / union, averaged over classes."""
    def __init__(self, num_class=21):
        super(MeanIntersectionOverUnion, self).__init__()
        self.num_class = num_class

    def _generate_matrix(self, gt_image, pre_image):
        mask = (gt_image >= 0) & (gt_image < self.num_class)
        label = self.num_class * gt_image[mask].astype('int') + pre_image[mask]
        count = np.bincount(label, minlength=self.num_class**2)
        confusion_matrix = count.reshape(self.num_class, self.num_class)
        return confusion_matrix

    def update(self, *inputs):
        y_pred = inputs[0].asnumpy().argmax(axis=1)
        y = inputs[1].asnumpy().reshape(4, 512, 512)
        self.confusion_matrix += self._generate_matrix(y, y_pred)

    def clear(self):
        self.confusion_matrix = np.zeros((self.num_class,) * 2)

    def eval(self):
        mean_iou = np.diag(self.confusion_matrix) / (
            np.sum(self.confusion_matrix, axis=1) + np.sum(self.confusion_matrix, axis=0) -
            np.diag(self.confusion_matrix))
        mean_iou = np.nanmean(mean_iou)
        return mean_iou

class FrequencyWeightedIntersectionOverUnion(train.Metric):
    """IoU weighted by each class's pixel frequency."""
    def __init__(self, num_class=21):
        super(FrequencyWeightedIntersectionOverUnion, self).__init__()
        self.num_class = num_class

    def _generate_matrix(self, gt_image, pre_image):
        mask = (gt_image >= 0) & (gt_image < self.num_class)
        label = self.num_class * gt_image[mask].astype('int') + pre_image[mask]
        count = np.bincount(label, minlength=self.num_class**2)
        confusion_matrix = count.reshape(self.num_class, self.num_class)
        return confusion_matrix

    def update(self, *inputs):
        y_pred = inputs[0].asnumpy().argmax(axis=1)
        y = inputs[1].asnumpy().reshape(4, 512, 512)
        self.confusion_matrix += self._generate_matrix(y, y_pred)

    def clear(self):
        self.confusion_matrix = np.zeros((self.num_class,) * 2)

    def eval(self):
        freq = np.sum(self.confusion_matrix, axis=1) / np.sum(self.confusion_matrix)
        iu = np.diag(self.confusion_matrix) / (
            np.sum(self.confusion_matrix, axis=1) + np.sum(self.confusion_matrix, axis=0) -
            np.diag(self.confusion_matrix))
        frequency_weighted_iou = (freq[freq > 0] * iu[freq > 0]).sum()
        return frequency_weighted_iou
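The bincount-based confusion matrix shared by all four metrics is easy to verify on a tiny example (toy labels, not tutorial data):

```python
import numpy as np

num_class = 3
gt = np.array([0, 0, 1, 1, 2, 2])    # ground-truth labels, flattened
pred = np.array([0, 1, 1, 1, 2, 0])  # predicted labels

# Same trick as _generate_matrix: encode each (gt, pred) pair as a single
# integer gt * num_class + pred, count occurrences, reshape to a matrix
# whose rows are ground truth and columns are predictions.
mask = (gt >= 0) & (gt < num_class)
label = num_class * gt[mask].astype(int) + pred[mask]
cm = np.bincount(label, minlength=num_class**2).reshape(num_class, num_class)
print(cm)

# Pixel accuracy: diagonal (correct) over all pixels -> 4/6.
pixel_acc = np.diag(cm).sum() / cm.sum()
# Per-class IoU: diag / (row sum + col sum - diag) -> [1/3, 2/3, 1/2].
iou = np.diag(cm) / (cm.sum(axis=1) + cm.sum(axis=0) - np.diag(cm))
print(pixel_acc, np.nanmean(iou))  # 0.666..., 0.5
```

Running each metric's `eval()` is just one of these reductions over the accumulated matrix, which is why all four classes share the same `_generate_matrix`.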
Model training:
import mindspore
from mindspore import Tensor
import mindspore.nn as nn
from mindspore.train import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor, Model

device_target = "Ascend"
mindspore.set_context(mode=mindspore.PYNATIVE_MODE, device_target=device_target)
train_batch_size = 4
num_classes = 21
# Build the model
net = FCN8s(n_class=21)
# Load the pretrained VGG-16 parameters
load_vgg16()
# Compute the learning-rate schedule
min_lr = 0.0005
base_lr = 0.05
train_epochs = 1
iters_per_epoch = dataset.get_dataset_size()
total_step = iters_per_epoch * train_epochs
lr_scheduler = mindspore.nn.cosine_decay_lr(min_lr,
                                            base_lr,
                                            total_step,
                                            iters_per_epoch,
                                            decay_epoch=2)
lr = Tensor(lr_scheduler[-1])
# Loss function (pixels labeled 255 are ignored)
loss = nn.CrossEntropyLoss(ignore_index=255)
# Optimizer
optimizer = nn.Momentum(params=net.trainable_params(), learning_rate=lr, momentum=0.9, weight_decay=0.0001)
# Dynamic loss scaling for mixed precision
scale_factor = 4
scale_window = 3000
loss_scale_manager = mindspore.amp.DynamicLossScaleManager(scale_factor, scale_window)  # was `ms.amp...`, but no `ms` alias was imported here
# Initialize the Model wrapper
if device_target == "Ascend":
    model = Model(net, loss_fn=loss, optimizer=optimizer, loss_scale_manager=loss_scale_manager, metrics={"pixel accuracy": PixelAccuracy(), "mean pixel accuracy": PixelAccuracyClass(), "mean IoU": MeanIntersectionOverUnion(), "frequency weighted IoU": FrequencyWeightedIntersectionOverUnion()})
else:
    model = Model(net, loss_fn=loss, optimizer=optimizer, metrics={"pixel accuracy": PixelAccuracy(), "mean pixel accuracy": PixelAccuracyClass(), "mean IoU": MeanIntersectionOverUnion(), "frequency weighted IoU": FrequencyWeightedIntersectionOverUnion()})
# Checkpoint-saving configuration
time_callback = TimeMonitor(data_size=iters_per_epoch)
loss_callback = LossMonitor()
callbacks = [time_callback, loss_callback]
keep_checkpoint_max = 5
config_ckpt = CheckpointConfig(save_checkpoint_steps=10,
                               keep_checkpoint_max=keep_checkpoint_max)
ckpt_callback = ModelCheckpoint(prefix="FCN8s",
                                directory="./ckpt",
                                config=config_ckpt)
callbacks.append(ckpt_callback)
model.train(train_epochs, dataset, callbacks=callbacks)
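For intuition, the schedule produced by `cosine_decay_lr` can be replicated in a few lines of pure Python (a sketch of the documented formula with toy step counts; the decay progresses per epoch, and the MindSpore call above remains authoritative):

```python
import math

# Mirror the arguments used above, but with a toy step count.
min_lr, base_lr = 0.0005, 0.05
step_per_epoch, decay_epoch = 4, 2
total_step = step_per_epoch * decay_epoch

lrs = []
for i in range(total_step):
    # Cosine decay from base_lr at epoch 0 down toward min_lr at decay_epoch.
    cur_epoch = min(i // step_per_epoch, decay_epoch)
    lr = min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * cur_epoch / decay_epoch))
    lrs.append(lr)
print(lrs[0], lrs[-1])  # starts at base_lr; the second epoch sits halfway down
```

Note the training cell actually passes only the final value (`lr_scheduler[-1]`) to the optimizer as a constant learning rate, rather than the whole per-step list.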
4. Model Evaluation

IMAGE_MEAN = [103.53, 116.28, 123.675]
IMAGE_STD = [57.375, 57.120, 58.395]
DATA_FILE = "dataset/dataset_fcn8s/mindname.mindrecord"
# Download the fully trained weights
url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/FCN8s.ckpt"
download(url, "FCN8s.ckpt", replace=True)
net = FCN8s(n_class=num_classes)
ckpt_file = "FCN8s.ckpt"
param_dict = load_checkpoint(ckpt_file)
load_param_into_net(net, param_dict)
if device_target == "Ascend":
    model = Model(net, loss_fn=loss, optimizer=optimizer, loss_scale_manager=loss_scale_manager, metrics={"pixel accuracy": PixelAccuracy(), "mean pixel accuracy": PixelAccuracyClass(), "mean IoU": MeanIntersectionOverUnion(), "frequency weighted IoU": FrequencyWeightedIntersectionOverUnion()})
else:
    model = Model(net, loss_fn=loss, optimizer=optimizer, metrics={"pixel accuracy": PixelAccuracy(), "mean pixel accuracy": PixelAccuracyClass(), "mean IoU": MeanIntersectionOverUnion(), "frequency weighted IoU": FrequencyWeightedIntersectionOverUnion()})
# Instantiate the evaluation dataset
dataset = SegDataset(image_mean=IMAGE_MEAN,
                     image_std=IMAGE_STD,
                     data_file=DATA_FILE,
                     batch_size=train_batch_size,
                     crop_size=crop_size,
                     max_scale=max_scale,
                     min_scale=min_scale,
                     ignore_label=ignore_label,
                     num_classes=num_classes,
                     num_readers=2,
                     num_parallel_calls=4)
dataset_eval = dataset.get_dataset()
model.eval(dataset_eval)
5. Model Inference

import cv2
import matplotlib.pyplot as plt

net = FCN8s(n_class=num_classes)
# Load the trained checkpoint
ckpt_file = "FCN8s.ckpt"
param_dict = load_checkpoint(ckpt_file)
load_param_into_net(net, param_dict)
eval_batch_size = 4
img_lst = []
mask_lst = []
res_lst = []
# Show results: input images on top, predicted masks below
plt.figure(figsize=(8, 5))
show_data = next(dataset_eval.create_dict_iterator())
show_images = show_data["data"].asnumpy()
mask_images = show_data["label"].reshape([4, 512, 512])
show_images = np.clip(show_images, 0, 1)
for i in range(eval_batch_size):
    img_lst.append(show_images[i])
    mask_lst.append(mask_images[i])
res = net(show_data["data"]).asnumpy().argmax(axis=1)
for i in range(eval_batch_size):
    plt.subplot(2, 4, i + 1)
    plt.imshow(img_lst[i].transpose(1, 2, 0))
    plt.axis("off")
    plt.subplots_adjust(wspace=0.05, hspace=0.02)
    plt.subplot(2, 4, i + 5)
    plt.imshow(res[i])
    plt.axis("off")
    plt.subplots_adjust(wspace=0.05, hspace=0.02)
plt.show()
Author: 忿忿的泥巴坨