Huawei's open-source AI framework MindSpore, application case: ICNet for real-time semantic segmentation ...


ICNet for Real-Time Semantic Segmentation

ICNet is widely used for real-time semantic segmentation. It segments image data efficiently, providing strong support for both research and practical applications. Its real-time capability is a major advantage in latency-sensitive scenarios such as video processing and autonomous driving, where ICNet can segment images quickly and accurately and supply key information for downstream decisions.
    
If you are interested in MindSpore, you can follow the MindSpore community.
I. Environment Setup

1. Open the ModelArts website

The cloud platform helps users quickly create and deploy models and manage the full AI workflow. Use the platform below to get started with MindSpore: obtain the install command and install MindSpore 2.0.0-alpha. You can reach the ModelArts website from the MindSpore tutorials.

Choose CodeLab below to try it out immediately.

Wait for the environment setup to finish.
2. Try a Notebook instance in CodeLab

Choose ModelArts Upload Files to upload the Git repository; the address is GitHub - yfjcode/ICNet: mindspore icnet model.


Select the kernel environment.

Switch to the GPU environment, choosing the first (limited-time free) option.

Go to the MindSpore website and click Install at the top.

Copy the install command.

Back in the Notebook, add the command before the first code cell:
  
conda update -n base -c defaults conda
Install the MindSpore 2.0 GPU version:

conda install mindspore=2.0.0a0 -c mindspore -c conda-forge
Install mindvision:

pip install mindvision
Install the download package:

pip install download

II. Application Walkthrough


1. Model Preparation

According to the original author's notes:

Environment setup and data loading: this case was originally implemented with the CPU version of MindSpore, with model training done on CPU.

Data used: the Cityscapes dataset (Cityscape Dataset Website).

To download the dataset, first register on the Cityscapes website, preferably with an .edu e-mail address; after waiting a few days you can download. Here we download two files: gtFine_trainvaltest.zip and leftImg8bit_trainvaltest.zip (11GB).

After downloading, unzip the archives; the directory structure is shown below.

Since we run on CPU and the dataset is over 1 GB, training on all of it can easily crash the session, so we use only some images from a single city.

First, the data must be processed to generate the corresponding .mindrecord and .mindrecord.db files.

Note that before generating these two files, we must create a folder named cityscapes_mindrecord at the same level as the cityscapes folder, and keep cityscapes_mindrecord empty.

The dataset-building code is below. Again, keep the cityscapes_mindrecord folder empty; an error usually means the folder already contains files. The folder path is /home/ma-user/work/ICNet/data/cityscapes_mindrecord, so any files under /data/cityscapes_mindrecord need to be deleted.

After deleting those files, update the paths: replace /home/ma-user/work/ICNet with ./, then run the code cell directly.
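Creating and emptying the target folder can also be done programmatically. A small sketch (the relative path below assumes the substitution described above):

```python
import os
import shutil

# Assumed relative path after replacing /home/ma-user/work/ICNet with ./
mindrecord_dir = "./data/cityscapes_mindrecord"

# Create the folder if missing, then clear it: stale .mindrecord files
# from an earlier run are the usual cause of the conversion error.
os.makedirs(mindrecord_dir, exist_ok=True)
for name in os.listdir(mindrecord_dir):
    path = os.path.join(mindrecord_dir, name)
    if os.path.isfile(path):
        os.remove(path)
    else:
        shutil.rmtree(path)
```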
"""Prepare Cityscapes dataset"""
import os
import random
import argparse
import numpy as np
from PIL import Image
from PIL import ImageOps
from PIL import ImageFilter
import mindspore.dataset as de
from mindspore.mindrecord import FileWriter
import mindspore.dataset.vision as transforms
import mindspore.dataset.transforms as tc


def _get_city_pairs(folder, split='train'):
    """Return two path arrays of dataset images and masks"""
    def get_path_pairs(image_folder, masks_folder):
        image_paths = []
        masks_paths = []
        for root, _, files in os.walk(image_folder):
            for filename in files:
                if filename.endswith('.png'):
                    imgpath = os.path.join(root, filename)
                    foldername = os.path.basename(os.path.dirname(imgpath))
                    maskname = filename.replace('leftImg8bit', 'gtFine_labelIds')
                    maskpath = os.path.join(masks_folder, foldername, maskname)
                    if os.path.isfile(imgpath) and os.path.isfile(maskpath):
                        image_paths.append(imgpath)
                        masks_paths.append(maskpath)
                    else:
                        print('cannot find the mask or image:', imgpath, maskpath)
        print('Found {} images in the folder {}'.format(len(image_paths), image_folder))
        return image_paths, masks_paths

    if split in ('train', 'val'):
        # "./Cityscapes/leftImg8bit/train" or "./Cityscapes/leftImg8bit/val"
        img_folder = os.path.join(folder, 'leftImg8bit/' + split)
        # "./Cityscapes/gtFine/train" or "./Cityscapes/gtFine/val"
        mask_folder = os.path.join(folder, 'gtFine/' + split)
        # img_paths and mask_paths correspond one-to-one
        img_paths, mask_paths = get_path_pairs(img_folder, mask_folder)
        return img_paths, mask_paths


def _sync_transform(img, mask):
    """img and mask augmentation"""
    base_size = 1024
    crop_size = 960
    # random mirror
    if random.random() < 0.5:
        img = img.transpose(Image.FLIP_LEFT_RIGHT)
        mask = mask.transpose(Image.FLIP_LEFT_RIGHT)
    # random scale (short edge)
    short_size = random.randint(int(base_size * 0.5), int(base_size * 2.0))
    w, h = img.size
    if h > w:
        ow = short_size
        oh = int(1.0 * h * ow / w)
    else:
        oh = short_size
        ow = int(1.0 * w * oh / h)
    img = img.resize((ow, oh), Image.BILINEAR)
    mask = mask.resize((ow, oh), Image.NEAREST)
    # pad if smaller than the crop
    if short_size < crop_size:
        padh = crop_size - oh if oh < crop_size else 0
        padw = crop_size - ow if ow < crop_size else 0
        img = ImageOps.expand(img, border=(0, 0, padw, padh), fill=0)
        mask = ImageOps.expand(mask, border=(0, 0, padw, padh), fill=0)
    # random crop to crop_size
    w, h = img.size
    x1 = random.randint(0, w - crop_size)
    y1 = random.randint(0, h - crop_size)
    img = img.crop((x1, y1, x1 + crop_size, y1 + crop_size))
    mask = mask.crop((x1, y1, x1 + crop_size, y1 + crop_size))
    # gaussian blur as in PSPNet
    if random.random() < 0.5:
        img = img.filter(ImageFilter.GaussianBlur(radius=random.random()))
    # final transform
    output = _img_mask_transform(img, mask)
    return output


def _class_to_index(mask):
    """Map raw Cityscapes label ids to training class indices"""
    # Reference:
    # https://github.com/mcordts/cityscapesScripts/blob/master/cityscapesscripts/helpers/labels.py
    _key = np.array([-1, -1, -1, -1, -1, -1,
                     -1, -1, 0, 1, -1, -1,
                     2, 3, 4, -1, -1, -1,
                     5, -1, 6, 7, 8, 9,
                     10, 11, 12, 13, 14, 15,
                     -1, -1, 16, 17, 18])
    # [-1, ..., 33]
    _mapping = np.array(range(-1, len(_key) - 1)).astype('int32')
    # assert every pixel value is a known label id
    values = np.unique(mask)
    for value in values:
        assert value in _mapping
    # index of each pixel value in _mapping
    index = np.digitize(mask.ravel(), _mapping, right=True)
    # look up the training class for each pixel via _key
    return _key[index].reshape(mask.shape)


def _img_transform(img):
    return np.array(img)


def _mask_transform(mask):
    target = _class_to_index(np.array(mask).astype('int32'))
    return np.array(target).astype('int32')


def _img_mask_transform(img, mask):
    """img and mask transform"""
    input_transform = tc.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225), is_hwc=False)])
    img = _img_transform(img)
    mask = _mask_transform(mask)
    img = input_transform(img)
    img = np.array(img).astype(np.float32)
    mask = np.array(mask).astype(np.float32)
    return (img, mask)


def data_to_mindrecord_img(prefix='cityscapes-2975.mindrecord', file_num=1,
                           root='./', split='train', mindrecord_dir="./"):
    """Convert the dataset to MindRecord"""
    mindrecord_path = os.path.join(mindrecord_dir, prefix)
    writer = FileWriter(mindrecord_path, file_num)
    img_paths, mask_paths = _get_city_pairs(root, split)
    cityscapes_json = {
        "images": {"type": "int32", "shape": [1024, 2048, 3]},
        "mask": {"type": "int32", "shape": [1024, 2048]},
    }
    writer.add_schema(cityscapes_json, "cityscapes_json")
    images_files_num = len(img_paths)
    for index in range(images_files_num):
        img = Image.open(img_paths[index]).convert('RGB')
        img = np.array(img, dtype=np.int32)
        mask = Image.open(mask_paths[index])
        mask = np.array(mask, dtype=np.int32)
        row = {"images": img, "mask": mask}
        if (index + 1) % 10 == 0:
            print("writing {}/{} into mindrecord".format(index + 1, images_files_num))
        writer.write_raw_data([row])
    writer.commit()


def get_Image_crop_nor(img, mask):
    image = np.uint8(img)
    mask = np.uint8(mask)
    image = Image.fromarray(image)
    mask = Image.fromarray(mask)
    output = _sync_transform(image, mask)
    return output


def create_icnet_dataset(mindrecord_file, batch_size=16, device_num=1, rank_id=0):
    """create dataset for training"""
    ds = de.MindDataset(mindrecord_file, columns_list=["images", "mask"],
                        num_shards=device_num, shard_id=rank_id, shuffle=True)
    ds = ds.map(operations=get_Image_crop_nor, input_columns=["images", "mask"], output_columns=["image", "masks"])
    ds = ds.batch(batch_size=batch_size, drop_remainder=False)
    return ds


dataset_path = "./data/cityscapes/"
mindrecord_path = "./data/cityscapes_mindrecord/"
data_to_mindrecord_img(root=dataset_path, mindrecord_dir=mindrecord_path)

# if __name__ == '__main__':
#     parser = argparse.ArgumentParser(description="dataset_to_mindrecord")
#     parser.add_argument("--dataset_path", type=str, default="/home/ma-user/work/ICNet/data/cityscapes/", help="dataset path")
#     parser.add_argument("--mindrecord_path", type=str, default="/home/ma-user/work/ICNet/data/cityscapes_mindrecord/",
#                         help="mindrecord_path")
#     args_opt = parser.parse_args()
#     data_to_mindrecord_img(root=args_opt.dataset_path, mindrecord_dir=args_opt.mindrecord_path)
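The `_class_to_index` mapping above can be checked on a toy mask. This is an illustrative sketch using the same `_key`/`_mapping` tables; the example pixel values are raw Cityscapes ids (7 = road, 8 = sidewalk, 26 = car, 33 = bicycle):

```python
import numpy as np

# Same lookup tables as _class_to_index above: raw label ids -1..33 are
# mapped to 19 training classes, with -1 meaning "ignore".
_key = np.array([-1, -1, -1, -1, -1, -1,
                 -1, -1, 0, 1, -1, -1,
                 2, 3, 4, -1, -1, -1,
                 5, -1, 6, 7, 8, 9,
                 10, 11, 12, 13, 14, 15,
                 -1, -1, 16, 17, 18])
_mapping = np.arange(-1, len(_key) - 1).astype('int32')

mask = np.array([[7, 8], [26, 33]])  # road, sidewalk, car, bicycle
index = np.digitize(mask.ravel(), _mapping, right=True)
train_ids = _key[index].reshape(mask.shape)
print(train_ids)  # [[ 0  1]
                  #  [13 18]]
```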
You can see that the corresponding dataset files have been generated. Next, create the data used shortly.

Note: adjust the paths.
prefix = 'cityscapes-2975.mindrecord'
train_mindrecord_dir = "/home/ma-user/work/ICNet/data/cityscapes_mindrecord"
train_train_batch_size_percard = 4
device_num = 1
rank_id = 0
mindrecord_dir = train_mindrecord_dir
mindrecord_file = os.path.join(mindrecord_dir, prefix)
print("mindrecord_file", mindrecord_file)
dataset = create_icnet_dataset(mindrecord_file, batch_size=train_train_batch_size_percard,
                               device_num=device_num, rank_id=rank_id)
print(dataset)

2. Model Construction

Create the parameters needed for training (shown here for reference only, not run; the concrete values used are set further below):

1. Model
model:
  name: "icnet"
  backbone: "resnet50v1"
  base_size: 1024   # during augmentation the shorter side is resized within [base_size*0.5, base_size*2.0]
  crop_size: 960    # after augmentation, crop to this size for training

2. Optimizer
optimizer:
  init_lr: 0.02
  momentum: 0.9
  weight_decay: 0.0001

3. Training
train:
  train_batch_size_percard: 4
  valid_batch_size: 1
  cityscapes_root: "/data/cityscapes/"
  epochs: 10
  val_epoch: 1            # run validation every val_epoch epochs
  ckpt_dir: "./ckpt/"     # checkpoints and the training log are saved here
  mindrecord_dir: '/home/ma-user/work/ICNet/data/cityscapes_mindrecord'
  pretrained_model_path: '/home/ma-user/work/ICNet/root/cacheckpt/resnet50-icnet-150_2.ckpt'
  save_checkpoint_epochs: 5
  keep_checkpoint_max: 10

4. Valid
test:
  ckpt_path: ""           # set the pretrained model path correctly
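For notebook use, the same settings can also be kept in a plain Python dict. This is a hypothetical convenience for interactive work, not the repository's own config loader (which reads a YAML file):

```python
# Hypothetical in-notebook version of the configuration above.
cfg = {
    "model": {"name": "icnet", "backbone": "resnet50v1",
              "base_size": 1024, "crop_size": 960},
    "optimizer": {"init_lr": 0.02, "momentum": 0.9, "weight_decay": 0.0001},
    "train": {"train_batch_size_percard": 4, "valid_batch_size": 1,
              "epochs": 10, "val_epoch": 1,
              "save_checkpoint_epochs": 5, "keep_checkpoint_max": 10},
    "test": {"ckpt_path": ""},
}
print(cfg["train"]["train_batch_size_percard"])  # 4
```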
Note: adjust the paths.
train_epochs = 10
train_data_size = dataset.get_dataset_size()
print("data_size", train_data_size)
epoch = train_epochs
project_path = "/home/ma-user/work/ICNet/"
train_pretrained_model_path = "/home/ma-user/work/ICNet/root/cacheckpt/resnet50-icnet-150_2.ckpt"

import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops
from src.loss import ICNetLoss
from src.models.resnet50_v1 import get_resnet50v1b

__all__ = ['ICNetdc']


class ICNetdc(nn.Cell):
    """Image Cascade Network"""
    def __init__(self, nclass=19, pretrained_path="", istraining=True, norm_layer=nn.SyncBatchNorm):
        super(ICNetdc, self).__init__()
        self.conv_sub1 = nn.SequentialCell(
            _ConvBNReLU(3, 32, 3, 2, norm_layer=norm_layer),
            _ConvBNReLU(32, 32, 3, 2, norm_layer=norm_layer),
            _ConvBNReLU(32, 64, 3, 2, norm_layer=norm_layer)
        )
        self.istraining = istraining
        self.ppm = PyramidPoolingModule()
        self.backbone = SegBaseModel(root=pretrained_path, istraining=istraining)
        self.head = _ICHead(nclass, norm_layer=norm_layer)
        self.loss = ICNetLoss()
        self.resize_bilinear = nn.ResizeBilinear()
        self.__setattr__('exclusive', ['conv_sub1', 'head'])

    def construct(self, x, y):
        """ICNet_construct"""
        if x.shape[0] != 1:
            x = x.squeeze()
        # sub 1
        x_sub1 = self.conv_sub1(x)
        h, w = x.shape[2:]
        # sub 2
        x_sub2 = self.resize_bilinear(x, size=(h / 2, w / 2))
        _, x_sub2, _, _ = self.backbone(x_sub2)
        # sub 4
        _, _, _, x_sub4 = self.backbone(x)
        # add PyramidPoolingModule
        x_sub4 = self.ppm(x_sub4)
        output = self.head(x_sub1, x_sub2, x_sub4)
        if self.istraining:
            outputs = self.loss(output, y)
        else:
            outputs = output
        return outputs


class PyramidPoolingModule(nn.Cell):
    """PPM"""
    def __init__(self, pyramids=None):
        super(PyramidPoolingModule, self).__init__()
        self.avgpool = ops.ReduceMean(keep_dims=True)
        self.pool2 = nn.AvgPool2d(kernel_size=15, stride=15)
        self.pool3 = nn.AvgPool2d(kernel_size=10, stride=10)
        self.pool6 = nn.AvgPool2d(kernel_size=5, stride=5)
        self.resize_bilinear = nn.ResizeBilinear()

    def construct(self, x):
        """ppm_construct"""
        feat = x
        height, width = x.shape[2:]
        x1 = self.avgpool(x, (2, 3))
        x1 = self.resize_bilinear(x1, size=(height, width), align_corners=True)
        feat = feat + x1
        x2 = self.pool2(x)
        x2 = self.resize_bilinear(x2, size=(height, width), align_corners=True)
        feat = feat + x2
        x3 = self.pool3(x)
        x3 = self.resize_bilinear(x3, size=(height, width), align_corners=True)
        feat = feat + x3
        x6 = self.pool6(x)
        x6 = self.resize_bilinear(x6, size=(height, width), align_corners=True)
        feat = feat + x6
        return feat


class _ICHead(nn.Cell):
    """Head"""
    def __init__(self, nclass, norm_layer=nn.SyncBatchNorm, **kwargs):
        super(_ICHead, self).__init__()
        self.cff_12 = CascadeFeatureFusion12(128, 64, 128, nclass, norm_layer, **kwargs)
        self.cff_24 = CascadeFeatureFusion24(2048, 512, 128, nclass, norm_layer, **kwargs)
        self.conv_cls = nn.Conv2d(128, nclass, 1, has_bias=False)
        self.outputs = list()
        self.resize_bilinear = nn.ResizeBilinear()

    def construct(self, x_sub1, x_sub2, x_sub4):
        """Head_construct"""
        outputs = self.outputs
        x_cff_24, x_24_cls = self.cff_24(x_sub4, x_sub2)
        x_cff_12, x_12_cls = self.cff_12(x_cff_24, x_sub1)
        h1, w1 = x_cff_12.shape[2:]
        up_x2 = self.resize_bilinear(x_cff_12, size=(h1 * 2, w1 * 2),
                                     align_corners=True)
        up_x2 = self.conv_cls(up_x2)
        h2, w2 = up_x2.shape[2:]
        up_x8 = self.resize_bilinear(up_x2, size=(h2 * 4, w2 * 4),
                                     align_corners=True)  # scale_factor=4
        outputs.append(up_x8)
        outputs.append(up_x2)
        outputs.append(x_12_cls)
        outputs.append(x_24_cls)
        return outputs


class _ConvBNReLU(nn.Cell):
    """ConvBNReLU"""
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=2, padding=1, dilation=1,
                 groups=1, norm_layer=nn.SyncBatchNorm, bias=False, **kwargs):
        super(_ConvBNReLU, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, pad_mode='pad', padding=padding,
                              dilation=dilation,
                              group=1, has_bias=False)
        self.bn = norm_layer(out_channels, momentum=0.1)
        self.relu = nn.ReLU()

    def construct(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x


class CascadeFeatureFusion12(nn.Cell):
    """CFF Unit"""
    def __init__(self, low_channels, high_channels, out_channels, nclass, norm_layer=nn.SyncBatchNorm, **kwargs):
        super(CascadeFeatureFusion12, self).__init__()
        self.conv_low = nn.SequentialCell(
            nn.Conv2d(low_channels, out_channels, 3, pad_mode='pad', padding=2, dilation=2, has_bias=False),
            norm_layer(out_channels, momentum=0.1)
        )
        self.conv_high = nn.SequentialCell(
            nn.Conv2d(high_channels, out_channels, kernel_size=1, has_bias=False),
            norm_layer(out_channels, momentum=0.1)
        )
        self.conv_low_cls = nn.Conv2d(in_channels=out_channels, out_channels=nclass, kernel_size=1, has_bias=False)
        self.resize_bilinear = nn.ResizeBilinear()
        self.scalar_cast = ops.ScalarCast()
        self.relu = ms.nn.ReLU()

    def construct(self, x_low, x_high):
        """cff_construct"""
        h, w = x_high.shape[2:]
        x_low = self.resize_bilinear(x_low, size=(h, w), align_corners=True)
        x_low = self.conv_low(x_low)
        x_high = self.conv_high(x_high)
        x = x_low + x_high
        x = self.relu(x)
        x_low_cls = self.conv_low_cls(x_low)
        return x, x_low_cls


class CascadeFeatureFusion24(nn.Cell):
    """CFF Unit"""
    def __init__(self, low_channels, high_channels, out_channels, nclass, norm_layer=nn.SyncBatchNorm, **kwargs):
        super(CascadeFeatureFusion24, self).__init__()
        self.conv_low = nn.SequentialCell(
            nn.Conv2d(low_channels, out_channels, 3, pad_mode='pad', padding=2, dilation=2, has_bias=False),
            norm_layer(out_channels, momentum=0.1)
        )
        self.conv_high = nn.SequentialCell(
            nn.Conv2d(high_channels, out_channels, kernel_size=1, has_bias=False),
            norm_layer(out_channels, momentum=0.1)
        )
        self.conv_low_cls = nn.Conv2d(in_channels=out_channels, out_channels=nclass, kernel_size=1, has_bias=False)
        self.resize_bilinear = nn.ResizeBilinear()
        self.relu = ms.nn.ReLU()

    def construct(self, x_low, x_high):
        """cff_construct"""
        h, w = x_high.shape[2:]
        x_low = self.resize_bilinear(x_low, size=(h, w), align_corners=True)
        x_low = self.conv_low(x_low)
        x_high = self.conv_high(x_high)
        x = x_low + x_high
        x = self.relu(x)
        x_low_cls = self.conv_low_cls(x_low)
        return x, x_low_cls


class SegBaseModel(nn.Cell):
    """Base Model for Semantic Segmentation"""
    def __init__(self, nclass=19, backbone='resnet50', root="", istraining=False):
        super(SegBaseModel, self).__init__()
        self.nclass = nclass
        if backbone == 'resnet50':
            self.pretrained = get_resnet50v1b(ckpt_root=root, istraining=istraining)

    def construct(self, x):
        """forwarding pre-trained network"""
        x = self.pretrained.conv1(x)
        x = self.pretrained.bn1(x)
        x = self.pretrained.relu(x)
        x = self.pretrained.maxpool(x)
        c1 = self.pretrained.layer1(x)
        c2 = self.pretrained.layer2(c1)
        c3 = self.pretrained.layer3(c2)
        c4 = self.pretrained.layer4(c3)
        return c1, c2, c3, c4
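As a rough sanity check on the cascade (an illustrative back-of-envelope, matching `construct()` above): for a 960x960 training crop, the sub1 branch passes through three stride-2 convolutions, while the sub2 branch sees the image downsampled by half before entering the backbone:

```python
crop = 960                       # training crop size from the config
sub1_res = crop // 2 // 2 // 2   # after the three stride-2 convs in conv_sub1
sub2_input = crop // 2           # input resolution of the half-scale branch
print(sub1_res, sub2_input)      # 120 480
```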

def poly_lr(base_lr, decay_steps, total_steps, end_lr=0.0001, power=0.9):
    for i in range(total_steps):
        step_ = min(i, decay_steps)
        yield (base_lr - end_lr) * ((1.0 - step_ / decay_steps) ** power) + end_lr
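To see the shape of this schedule, the generator can be exercised directly (redefined here so the snippet stands alone; the logic matches `poly_lr` above):

```python
def poly_lr(base_lr, decay_steps, total_steps, end_lr=0.0001, power=0.9):
    # Polynomial decay from base_lr toward end_lr over decay_steps.
    for i in range(total_steps):
        step_ = min(i, decay_steps)
        yield (base_lr - end_lr) * ((1.0 - step_ / decay_steps) ** power) + end_lr

lrs = list(poly_lr(0.02, 100, 100, end_lr=0.0, power=0.9))
print(lrs[0])             # 0.02 at the first step
print(round(lrs[-1], 6))  # small value near 0 at the last step
```

With `power=0.9` the decay is close to linear early on and flattens near the end, which is the usual "poly" schedule used for semantic segmentation.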


optimizer_init_lr = 0.02
optimizer_weight_decay = 0.0001
optimizer_momentum = 0.9
train_save_checkpoint_epochs = 5
train_keep_checkpoint_max = 10
rank_id = 0
device_id = 0
device_num = 1

# from src.lr_scheduler import poly_lr
import os
import sys
import logging
import argparse
import mindspore.nn as nn
from mindspore import Model
from mindspore import context
from mindspore import set_seed
from mindspore.context import ParallelMode
from mindspore.communication import init
from mindspore.train.callback import CheckpointConfig
from mindspore.train.callback import ModelCheckpoint
from mindspore.train.callback import LossMonitor
from mindspore.train.callback import TimeMonitor

iters_per_epoch = train_data_size
total_train_steps = iters_per_epoch * epoch
base_lr = optimizer_init_lr
iter_lr = poly_lr(base_lr, total_train_steps, total_train_steps, end_lr=0.0, power=0.9)
network = ICNetdc(pretrained_path=train_pretrained_model_path, norm_layer=nn.BatchNorm2d)
optim = nn.SGD(params=network.trainable_params(), learning_rate=iter_lr, momentum=optimizer_momentum,
               weight_decay=optimizer_weight_decay)
model = Model(network, optimizer=optim, metrics=None)
config_ck_train = CheckpointConfig(save_checkpoint_steps=iters_per_epoch * train_save_checkpoint_epochs,
                                   keep_checkpoint_max=train_keep_checkpoint_max)
ckpoint_cb_train = ModelCheckpoint(prefix='ICNet', directory=project_path + 'ckpt' + str(device_id),
                                   config=config_ck_train)
time_cb_train = TimeMonitor(data_size=dataset.get_dataset_size())
loss_cb_train = LossMonitor()
print("train begins------------------------------")
model.train(epoch=epoch, train_dataset=dataset, callbacks=[ckpoint_cb_train, loss_cb_train, time_cb_train],
            dataset_sink_mode=True)
3. Model Evaluation

import os
import time
import sys
import argparse
import yaml
import numpy as np
from PIL import Image
import mindspore.ops as ops
from mindspore import load_param_into_net
from mindspore import load_checkpoint
from mindspore import Tensor
import mindspore.dataset.vision as vision
from src.models import ICNet
from src.metric import SegmentationMetric
from src.logger import SetupLogger


class Evaluator:
    """evaluate"""
    def __init__(self):
        # get validation images and targets
        self.image_paths, self.mask_paths = _get_city_pairs(dataset_path, "val")
        # create network
        self.model = ICNet(nclass=19, pretrained_path=train_pretrained_model_path, istraining=False)
        # load checkpoint
        checkpoint_path = "/home/ma-user/work/ICNet/ckpt0/ICNet-10_1.ckpt"
        ckpt_file_name = checkpoint_path
        param_dict = load_checkpoint(ckpt_file_name)
        load_param_into_net(self.model, param_dict)
        # evaluation metric
        self.metric = SegmentationMetric(19)

    def eval(self):
        """evaluate"""
        self.metric.reset()
        model = self.model
        model = model.set_train(False)
        logger.info("Start validation, Total sample: {:d}".format(len(self.image_paths)))
        list_time = []
        for i in range(len(self.image_paths)):
            image = Image.open(self.image_paths[i]).convert('RGB')  # image shape: (W,H,3)
            mask = Image.open(self.mask_paths[i])  # mask shape: (W,H)
            image = self._img_transform(image)  # image shape: (3,H,W) [0,1]
            mask = self._mask_transform(mask)  # mask shape: (H,W)
            image = Tensor(image)
            expand_dims = ops.ExpandDims()
            image = expand_dims(image, 0)
            start_time = time.time()
            output = model(image)
            end_time = time.time()
            step_time = end_time - start_time
            output = output.asnumpy()
            mask = np.expand_dims(mask.asnumpy(), axis=0)
            self.metric.update(output, mask)
            list_time.append(step_time)
        mIoU, pixAcc = self.metric.get()
        average_time = sum(list_time) / len(list_time)
        print("avgmiou", mIoU)
        print("avg_pixacc", pixAcc)
        print("avgtime", average_time)

    def _img_transform(self, image):
        """img_transform"""
        to_tensor = vision.ToTensor()
        normalize = vision.Normalize([.485, .456, .406], [.229, .224, .225], is_hwc=False)
        image = to_tensor(image)
        image = normalize(image)
        return image

    def _mask_transform(self, mask):
        mask = self._class_to_index(np.array(mask).astype('int32'))
        return Tensor(np.array(mask).astype('int32'))

    def _class_to_index(self, mask):
        """assert the value"""
        values = np.unique(mask)
        self._key = np.array([-1, -1, -1, -1, -1, -1,
                              -1, -1, 0, 1, -1, -1,
                              2, 3, 4, -1, -1, -1,
                              5, -1, 6, 7, 8, 9,
                              10, 11, 12, 13, 14, 15,
                              -1, -1, 16, 17, 18])
        self._mapping = np.array(range(-1, len(self._key) - 1)).astype('int32')
        for value in values:
            assert value in self._mapping
        # index of each pixel value in _mapping
        index = np.digitize(mask.ravel(), self._mapping, right=True)
        # look up the training class for each pixel via _key
        return self._key[index].reshape(mask.shape)


def _get_city_pairs(folder, split='train'):
    """get dataset img_mask_path_pairs"""
    def get_path_pairs(image_folder, mask_folder):
        img_paths = []
        mask_paths = []
        for root, _, files in os.walk(image_folder):
            for filename in files:
                if filename.endswith('.png'):
                    imgpath = os.path.join(root, filename)
                    foldername = os.path.basename(os.path.dirname(imgpath))
                    maskname = filename.replace('leftImg8bit', 'gtFine_labelIds')
                    maskpath = os.path.join(mask_folder, foldername, maskname)
                    if os.path.isfile(imgpath) and os.path.isfile(maskpath):
                        img_paths.append(imgpath)
                        mask_paths.append(maskpath)
                    else:
                        print('cannot find the mask or image:', imgpath, maskpath)
        print('Found {} images in the folder {}'.format(len(img_paths), image_folder))
        return img_paths, mask_paths

    if split in ('train', 'val', 'test'):
        # "./Cityscapes/leftImg8bit/train" or "./Cityscapes/leftImg8bit/val"
        img_folder = os.path.join(folder, 'leftImg8bit/' + split)
        # "./Cityscapes/gtFine/train" or "./Cityscapes/gtFine/val"
        mask_folder = os.path.join(folder, 'gtFine/' + split)
        img_paths, mask_paths = get_path_pairs(img_folder, mask_folder)
        return img_paths, mask_paths
train_ckpt_dir = "./ckpt/"
model_name = "icnet"
model_backbone = "resnet50v1"
checkpoint_path = "./ckpt0/ICNet-10_1.ckpt"
logger = SetupLogger(name="semantic_segmentation",
                     save_dir=train_ckpt_dir,
                     distributed_rank=0,
                     filename='{}_{}_evaluate_log.txt'.format(model_name, model_backbone))
evaluator = Evaluator()
evaluator.eval()
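The project's `SegmentationMetric` lives in `src/metric.py`. As an illustration only (not the repository's implementation), here is a minimal NumPy sketch of how mIoU and pixel accuracy can be derived from a confusion matrix over valid pixels:

```python
import numpy as np

def miou_pixacc(pred, target, nclass):
    # Confusion matrix over valid pixels (train id >= 0; -1 means "ignore").
    valid = target >= 0
    hist = np.bincount(nclass * target[valid] + pred[valid],
                       minlength=nclass * nclass).reshape(nclass, nclass)
    tp = np.diag(hist)
    # IoU per class = TP / (TP + FP + FN); mean over classes gives mIoU.
    iou = tp / (hist.sum(axis=0) + hist.sum(axis=1) - tp + 1e-10)
    pixacc = tp.sum() / (hist.sum() + 1e-10)
    return iou.mean(), pixacc

pred = np.array([0, 0, 1, 1])
target = np.array([0, 1, 1, 1])
miou, pixacc = miou_pixacc(pred, target, nclass=2)
print(round(pixacc, 2))  # 0.75
```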
Finally, the semantic segmentation results are obtained for the images at the given path.





