[LUT Technical Series] SPFLUT Code Walkthrough

Contents
Overview
1. Training
2. Compression and LUT conversion
3. Fine-tuning
4. Testing



This article is a code walkthrough of the SPFLUT technique; for an interpretation of the paper itself, see the earlier SPFLUT article.
Overview

The core of SPFLUT is its diagonal-first compression strategy. The overall pipeline is divided into four stages: training, conversion (which includes compression), fine-tuning, and testing. The code is organized along the same lines.

The pipeline is essentially the same as MuLUT's, except that before the second (conversion) step there is an extra step that compresses the LUTs, implemented in 2_compress_lut_from_net.py. The third (fine-tuning) step likewise contains code dedicated to fine-tuning the compressed LUTs.


1. Training

The implementation of the SPF_LUT_net model can be found in sr/model.py:
    class SPF_LUT_net(nn.Module):
        def __init__(self, nf=32, scale=4, modes=['s', 'd', 'y'], stages=2):
            super(SPF_LUT_net, self).__init__()
            self.upscale = scale
            self.modes = modes

            self.convblock1 = ConvBlock(1, 2, scale=None, output_quant=False, modes=modes, nf=nf)
            self.convblock2 = ConvBlock(1, 2, scale=None, output_quant=False, modes=modes, nf=nf)
            self.convblock3 = ConvBlock(1, 2, scale=None, output_quant=False, modes=modes, nf=nf)
            self.convblock4 = ConvBlock(1, 1, scale=None, output_quant=False, modes=modes, nf=nf)
            self.ChannelConv = MuLUTcUnit(in_c=4, out_c=4, mode='1x1', nf=nf)
            self.upblock = ConvBlock(4, 1, scale=scale, output_quant=False, modes=modes, nf=nf)

        def forward(self, x, phase='train'):
            B, C, H, W = x.size()
            x = x.reshape((B * C, 1, H, W))

            refine_list = []

            # block1
            x = self.convblock1(x)
            avg_factor, bias, norm = len(self.modes) * 4, 127, 255.0
            x = round_func(torch.clamp((x / avg_factor) + bias, 0, 255)) / norm
            refine_list.append(x[:, 0:1, :, :])
            x = x[:, 1:, :, :]

            # block2
            x = self.convblock2(x)
            avg_factor, bias, norm = len(self.modes) * 4, 127, 255.0
            x = round_func(torch.clamp((x / avg_factor) + bias, 0, 255)) / norm
            refine_list.append(x[:, 0:1, :, :])
            x = x[:, 1:, :, :]

            # block3
            x = self.convblock3(x)
            avg_factor, bias, norm = len(self.modes) * 4, 127, 255.0
            x = round_func(torch.clamp((x / avg_factor) + bias, 0, 255)) / norm
            refine_list.append(x[:, 0:1, :, :])
            x = x[:, 1:, :, :]

            # block4
            x = self.convblock4(x)
            avg_factor, bias, norm = len(self.modes) * 4, 127, 255.0
            x = round_func(torch.clamp((x / avg_factor) + bias, 0, 255)) / norm
            refine_list.append(x)

            x = torch.cat(refine_list, dim=1)
            x = round_func(torch.tanh(self.ChannelConv(x)) * 127.0)
            x = round_func(torch.clamp(x + 127, 0, 255)) / 255.0

            x = self.upblock(x)
            avg_factor, bias, norm = len(self.modes), 0, 1
            x = round_func((x / avg_factor) + bias)
            if phase == 'train':
                x = x / 255.0
            x = x.reshape((B, C, self.upscale * H, self.upscale * W))
            return x
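The forward pass above relies on round_func, which is defined elsewhere in the repository and is not reproduced here. In MuLUT-style code this is usually a straight-through rounding operator: round in the forward pass, identity gradient in the backward pass. A minimal sketch under that assumption:

    import torch

    def round_func(x):
        # Straight-through estimator: forward value is torch.round(x),
        # the gradient flows through as if no rounding happened.
        # (Check sr/model.py for the exact definition used by SPFLUT.)
        return (torch.round(x) - x).detach() + x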
From the code above you can see that the model is built mainly from two submodules: ConvBlock and MuLUTcUnit. The implementation of ConvBlock is as follows:
    class ConvBlock(nn.Module):
        def __init__(self, in_c, out_c, scale=None, output_quant=False, modes=['s', 'd', 'y'], nf=64):
            super(ConvBlock, self).__init__()
            self.in_c = in_c
            self.out_c = out_c
            self.modes = modes
            self.module_dict = dict()
            self.upscale = scale
            self.output_quant = output_quant

            scale_factor = 1 if scale is None else scale ** 2
            for c in range(in_c):
                for mode in modes:
                    self.module_dict['DepthwiseBlock{}_{}'.format(c, mode)] = MuLUTConv(
                        '{}x{}'.format(mode.upper(), 'N'), nf=nf, out_c=out_c * scale_factor, stride=1)
            self.module_dict = nn.ModuleDict(self.module_dict)

            if scale is None:
                self.pixel_shuffle = identity
            else:
                self.pixel_shuffle = nn.PixelShuffle(scale)

        def forward(self, x):
            modes = self.modes
            x_out = 0
            for c in range(self.in_c):
                x_c = x[:, c:c + 1, :, :]
                pred = 0
                for mode in modes:
                    pad = mode_pad_dict[mode]
                    sub_module = self.module_dict['DepthwiseBlock{}_{}'.format(c, mode)]
                    for r in [0, 1, 2, 3]:
                        pred += round_func(torch.tanh(torch.rot90(self.pixel_shuffle(
                            sub_module(F.pad(torch.rot90(x_c, r, [2, 3]), (0, pad, 0, pad), mode='replicate'))),
                            (4 - r) % 4, [2, 3])) * 127)
                x_out += pred

            if self.output_quant:
                avg_factor = len(modes) * 4 * self.in_c
                x = round_func(torch.clamp(x_out / avg_factor, -1, 1) * 127) / 127
            else:
                x = x_out / self.in_c
            return x
ConvBlock itself is built from MuLUTConv (located in common/network.py), a module we already covered in the MuLUT code walkthrough: it combines three kernel types, S, D and Y, to form a module with a 3x3 receptive field. Rotation and clamp operations are also applied here to keep each layer's output from overflowing.
MuLUTcUnit is the channel-wise MuLUT module, also located in common/network.py. Since it operates only across channels, its kernel size is 1; its job is to model correlations between feature channels.
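The exact MuLUTcUnit definition is not reproduced in this post; conceptually it is a small stack of 1x1 convolutions that mixes the four refine channels. A rough, purely illustrative stand-in (not the author's actual layer layout) might look like this:

    import torch.nn as nn

    class TinyChannelUnit(nn.Module):
        # Illustrative stand-in for MuLUTcUnit(in_c=4, out_c=4, mode='1x1'):
        # only 1x1 convolutions, so each output pixel depends on the four
        # input channels at the same spatial location.
        def __init__(self, in_c=4, out_c=4, nf=32):
            super().__init__()
            self.body = nn.Sequential(
                nn.Conv2d(in_c, nf, kernel_size=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(nf, out_c, kernel_size=1),
            )

        def forward(self, x):
            return self.body(x)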
The overall structure is fairly clear, especially if you are already familiar with the MuLUT submodules. If anything is unclear, you can instantiate a model and step through the tensor shapes yourself, as sketched below.
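A minimal shape check along those lines (assuming SPF_LUT_net and its dependencies are importable from sr/model.py when run from the sr/ directory; no GPU is required just to trace shapes):

    import torch
    from model import SPF_LUT_net   # assumed import path, per sr/model.py mentioned above

    net = SPF_LUT_net(nf=32, scale=4, modes=['s', 'd', 'y'], stages=2)
    x = torch.rand(1, 1, 48, 48)           # B=1, C=1, H=W=48, values in [0, 1]
    with torch.no_grad():
        y = net(x, phase='train')
    print(y.shape)                         # expected: (1, 1, 192, 192) for scale=4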

2. Compression and LUT conversion

This step is implemented in 2_compress_lut_from_net.py; the overall flow is:
    def compress_SPFLUT(opt):
        def save_SPFLUT_DFC(x, lut_path, module):
            # Split the input into batches to avoid running out of GPU memory
            B = x.size(0) // 100
            outputs = []

            # Extract input-output pairs
            with torch.no_grad():
                model_G.eval()
                for b in range(100):
                    if b == 99:
                        batch_input = x[b * B:]
                    else:
                        batch_input = x[b * B:(b + 1) * B]

                    batch_output = module(batch_input)

                    results = torch.round(torch.tanh(batch_output) * 127).cpu().data.numpy().astype(np.int8)
                    outputs += [results]

            results = np.concatenate(outputs, 0)
            results = results.reshape(x.size(0), -1)
            np.save(lut_path, results)
            print("Resulting LUT size: ", results.shape, "Saved to", lut_path)

        modes = [i for i in opt.modes]
        stages = opt.stages

        model = getattr(Model, 'SPF_LUT_net')
        model_G = model(nf=opt.nf, scale=opt.scale, modes=modes, stages=stages).cuda()

        lm = torch.load(os.path.join(opt.expDir, 'Model_{:06d}.pth'.format(opt.loadIter)))
        model_G.load_state_dict(lm, strict=True)

        input_tensor = get_input_tensor(opt)
        for mode in modes:
            if opt.cd == 'xyzt':
                input_tensor_c1 = compress_lut_xyzt(opt, input_tensor)
            elif opt.cd == 'xyz':
                input_tensor_c1 = compress_lut_xyz(opt, input_tensor)
            elif opt.cd == 'xy':
                input_tensor_c1 = compress_lut(opt, input_tensor)
            else:
                raise ValueError
            input_tensor_c2 = compress_lut_larger_interval(opt, input_tensor)

            if mode != 's':
                input_tensor_c1 = get_mode_input_tensor(input_tensor_c1, mode)
                input_tensor_c2 = get_mode_input_tensor(input_tensor_c2, mode)

            # conv1
            module = model_G.convblock1.module_dict['DepthwiseBlock{}_{}'.format(0, mode)]
            lut_path = os.path.join(opt.expDir, '{}_s{}c0_{}_compress1.npy'.format(opt.lutName, 1, mode))
            save_SPFLUT_DFC(input_tensor_c1, lut_path, module)
            lut_path = os.path.join(opt.expDir, '{}_s{}c0_{}_compress2.npy'.format(opt.lutName, 1, mode))
            save_SPFLUT_DFC(input_tensor_c2, lut_path, module)

            # conv2
            module = model_G.convblock2.module_dict['DepthwiseBlock{}_{}'.format(0, mode)]
            lut_path = os.path.join(opt.expDir, '{}_s{}c0_{}_compress1.npy'.format(opt.lutName, 2, mode))
            save_SPFLUT_DFC(input_tensor_c1, lut_path, module)
            lut_path = os.path.join(opt.expDir, '{}_s{}c0_{}_compress2.npy'.format(opt.lutName, 2, mode))
            save_SPFLUT_DFC(input_tensor_c2, lut_path, module)

            # conv3
            module = model_G.convblock3.module_dict['DepthwiseBlock{}_{}'.format(0, mode)]
            lut_path = os.path.join(opt.expDir, '{}_s{}c0_{}_compress1.npy'.format(opt.lutName, 3, mode))
            save_SPFLUT_DFC(input_tensor_c1, lut_path, module)
            lut_path = os.path.join(opt.expDir, '{}_s{}c0_{}_compress2.npy'.format(opt.lutName, 3, mode))
            save_SPFLUT_DFC(input_tensor_c2, lut_path, module)

            # conv4
            module = model_G.convblock4.module_dict['DepthwiseBlock{}_{}'.format(0, mode)]
            lut_path = os.path.join(opt.expDir, '{}_s{}c0_{}_compress1.npy'.format(opt.lutName, 4, mode))
            save_SPFLUT_DFC(input_tensor_c1, lut_path, module)
            lut_path = os.path.join(opt.expDir, '{}_s{}c0_{}_compress2.npy'.format(opt.lutName, 4, mode))
            save_SPFLUT_DFC(input_tensor_c2, lut_path, module)

            # conv6
            for c in range(4):
                module = model_G.upblock.module_dict['DepthwiseBlock{}_{}'.format(c, mode)]
                lut_path = os.path.join(opt.expDir, '{}_s{}c{}_{}_compress1.npy'.format(opt.lutName, 6, c, mode))
                save_SPFLUT_DFC(input_tensor_c1, lut_path, module)
                lut_path = os.path.join(opt.expDir, '{}_s{}c{}_{}_compress2.npy'.format(opt.lutName, 6, c, mode))
                save_SPFLUT_DFC(input_tensor_c2, lut_path, module)

        # conv5
        input_tensor = input_tensor.reshape((-1, 4, 1, 1))
        module = model_G.ChannelConv
        lut_path = os.path.join(opt.expDir, '{}_s{}_channel.npy'.format(opt.lutName, 5))
        save_SPFLUT_DFC(input_tensor, lut_path, module)
The details worth noting here are the three diagonal-compression functions, compress_lut_xyzt, compress_lut_xyz and compress_lut, which correspond to 4D, 3D and 2D compression respectively, and the non-diagonal function compress_lut_larger_interval. Notice also that the channel convolution (conv5) is not compressed at all: a channel-wise convolution does not satisfy the diagonal prior, so diagonal-first compression cannot be applied to it. The full set of files this step produces is sketched below.
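As a quick sanity check of the naming scheme, the LUT files written to opt.expDir by the code above can be enumerated like this (assuming opt.lutName is 'LUT' and modes ['s', 'd', 'y']; the real prefix comes from the experiment options):

    modes = ['s', 'd', 'y']
    files = []
    for mode in modes:
        for s in [1, 2, 3, 4]:                      # convblock1-4, input channel 0 only
            for tag in ['compress1', 'compress2']:  # diagonal LUT / coarser-sampled LUT
                files.append('LUT_s{}c0_{}_{}.npy'.format(s, mode, tag))
        for c in range(4):                          # upblock (stage 6), four input channels
            for tag in ['compress1', 'compress2']:
                files.append('LUT_s6c{}_{}_{}.npy'.format(c, mode, tag))
    files.append('LUT_s5_channel.npy')              # channel conv, stored uncompressed
    # compress_lut* additionally writes a ref2index_*.npy index map (see below)
    print(len(files))                               # 3 * (4 + 4) * 2 + 1 = 49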
For the diagonal-related functions, take the 2D compression as an example; it matches what we described in the earlier article.
    def compress_lut(opt, input_tensor):
        base = torch.arange(0, 257, 2 ** opt.interval)  # 0-256
        base[-1] -= 1
        L = base.size(0)

        d = opt.dw
        diag = 2 * d + 1
        N = diag * L + (1 - diag ** 2) // 4

        input_tensor = input_tensor.reshape(L * L, L, L, 1, 2, 2)
        index_i = torch.zeros((N,)).type(torch.int64)
        index_j = torch.zeros((N,)).type(torch.int64)
        cnt = 0

        ref2index = np.zeros((L, diag), dtype=np.int_) - 1
        for i in range(L):
            for j in range(L):
                if abs(i - j) <= d:
                    index_i[cnt] = i
                    index_j[cnt] = j
                    ref2index[i, j - i] = cnt
                    cnt += 1
        np.save(os.path.join(opt.expDir, 'ref2index_{}{}i{}.npy'.format(opt.cd, opt.dw, opt.si)), ref2index)

        index_compress = index_i * L + index_j
        compressed_input_tensor = input_tensor[index_compress, ...].reshape(-1, 1, 2, 2)
        return compressed_input_tensor
The author implements this by transforming input_tensor itself: we need every position of the 2D index grid that satisfies the diagonal-distance condition. Here opt.dw (the variable d) corresponds to the diagonal-width parameter introduced in the earlier article. Every position that satisfies the condition is recorded in ref2index and cnt is incremented, so the diagonal positions are preserved.
As for L, it is the interval-related count we have been using throughout the series, normally equal to 17 (4-bit sampling). N is the total number of kept index pairs that we derived earlier; you can substitute diag into the expression to confirm it matches the formula exactly (a quick numeric check is given below). With that, the compressed 2D input tensor is fully assembled and only needs to be fed through the model, so the diagonal positions are stored with priority.
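A quick numeric check of N, assuming 4-bit sampling (L = 17) and a small diagonal width such as d = 1 or d = 2 (the actual value of opt.dw comes from the experiment options):

    L = 17                                           # 2 ** (8 - interval) + 1 with interval = 4
    for d in [1, 2]:
        diag = 2 * d + 1
        N = diag * L + (1 - diag ** 2) // 4          # formula used in compress_lut
        brute = sum(1 for i in range(L) for j in range(L) if abs(i - j) <= d)
        print(d, N, brute)                           # d=1 -> 49 49, d=2 -> 79 79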
For the non-diagonal positions, look at compress_lut_larger_interval, implemented as follows:
    def compress_lut_larger_interval(opt, input_tensor):
        base = torch.arange(0, 257, 2 ** opt.interval)  # 0-256
        base[-1] -= 1
        L = base.size(0)
        input_tensor = input_tensor.reshape(L, L, L, L, 1, 2, 2)

        if opt.si == 5:
            k = 2
        elif opt.si == 6:
            k = 4
        elif opt.si == 7:
            k = 8
        else:
            raise ValueError
        compressed_input_tensor = input_tensor[::k, ::k, ::k, ::k, ...].reshape(-1, 1, 2, 2)
        return compressed_input_tensor
This one is simpler: it just samples at a larger stride. Since we already sample at 4-bit intervals, when opt.si is 5 we only need to subsample the current input_tensor with a stride of 2; the other cases follow in the same way. A rough size comparison is sketched below.
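To get a feel for the savings, here is a rough entry-count comparison for a single 4D LUT under 4-bit sampling, assuming d = 2 and opt.si = 5 (both are configurable options, not fixed values):

    L = 17
    d = 2
    diag = 2 * d + 1
    N = diag * L + (1 - diag ** 2) // 4        # 79 index pairs kept on the first two dims

    full = L ** 4                              # uncompressed 4D LUT entries: 83521
    diagonal_part = N * L * L                  # compress_lut keeps N * L * L entries: 22831
    coarse_part = 9 ** 4                       # compress_lut_larger_interval with stride 2: 6561
    print(full, diagonal_part + coarse_part)   # 83521 vs 29392 input combinations per LUT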
For the channel LUT: as noted above, the channel convolution cannot be compressed, so its input_tensor is left unchanged and is built the same way as before. The implementation is below; this procedure should be familiar if you have been following the LUT series (if not, the earlier LUT articles cover it):
    def get_input_tensor(opt):
        # 1D input
        base = torch.arange(0, 257, 2 ** opt.interval)  # 0-256
        base[-1] -= 1
        L = base.size(0)

        # 2D input
        # 256*256   0 0 0...    |1 1 1...     |...|255 255 255...
        first = base.cuda().unsqueeze(1).repeat(1, L).reshape(-1)
        # 256*256   0 1 2 .. 255|0 1 2 ... 255|...|0 1 2 ... 255
        second = base.cuda().repeat(L)
        onebytwo = torch.stack([first, second], 1)  # [256*256, 2]

        # 3D input
        # 256*256*256   0 x65536|1 x65536|...|255 x65536
        third = base.cuda().unsqueeze(1).repeat(1, L * L).reshape(-1)
        onebytwo = onebytwo.repeat(L, 1)
        onebythree = torch.cat([third.unsqueeze(1), onebytwo], 1)  # [256*256*256, 3]

        # 4D input
        # 256*256*256*256   0 x16777216|1 x16777216|...|255 x16777216
        fourth = base.cuda().unsqueeze(1).repeat(1, L * L * L).reshape(-1)
        onebythree = onebythree.repeat(L, 1)
        onebyfourth = torch.cat([fourth.unsqueeze(1), onebythree], 1)  # [256*256*256*256, 4]

        # Rearrange input: [N, 4] -> [N, C=1, H=2, W=2]
        input_tensor = onebyfourth.unsqueeze(1).unsqueeze(1).reshape(-1, 1, 2, 2).float() / 255.0
        return input_tensor
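This uncompressed tensor simply enumerates every combination of four 4-bit-sampled pixel values. A CPU-only sketch that reproduces its size (the real function keeps everything on the GPU and assumes interval = 4 here):

    import torch

    base = torch.arange(0, 257, 2 ** 4)
    base[-1] -= 1                          # sampled values 0, 16, ..., 240, 255 -> L = 17
    L = base.size(0)
    grid = torch.cartesian_prod(base, base, base, base).float() / 255.0
    input_tensor = grid.reshape(-1, 1, 2, 2)
    print(L, input_tensor.shape)           # 17, torch.Size([83521, 1, 2, 2])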

3. Fine-tuning

Compared with MuLUT, the fine-tuning stage is essentially unchanged. The main point of interest is how the author constructs the SPF_LUT model in sr/model.py:
    class SPF_LUT(nn.Module):
        """ PyTorch version of MuLUT for LUT-aware fine-tuning. """

        def __init__(self, lut_folder, stages, modes, lutName, upscale, interval, phase=None, **kwargs):
            super(SPF_LUT, self).__init__()
            self.interval = interval
            self.upscale = upscale
            self.modes = modes
            self.stages = stages
            L = 2 ** (8 - interval) + 1

            for mode in modes:
                # conv1
                lut_path = os.path.join(lut_folder, '{}_s{}c0_{}.npy'.format(lutName, 1, mode))
                key = "s{}c0_{}".format(1, mode)
                lut_arr = np.load(lut_path).reshape((-1, 2)).astype(np.float32) / 127.0
                self.register_parameter(name="weight_" + key, param=torch.nn.Parameter(torch.Tensor(lut_arr)))

                # conv2
                lut_path = os.path.join(lut_folder, '{}_s{}c0_{}.npy'.format(lutName, 2, mode))
                key = "s{}c0_{}".format(2, mode)
                lut_arr = np.load(lut_path).reshape((-1, 2)).astype(np.float32) / 127.0
                self.register_parameter(name="weight_" + key, param=torch.nn.Parameter(torch.Tensor(lut_arr)))

                # conv3
                lut_path = os.path.join(lut_folder, '{}_s{}c0_{}.npy'.format(lutName, 3, mode))
                key = "s{}c0_{}".format(3, mode)
                lut_arr = np.load(lut_path).reshape((-1, 2)).astype(np.float32) / 127.0
                self.register_parameter(name="weight_" + key, param=torch.nn.Parameter(torch.Tensor(lut_arr)))

                # conv4
                lut_path = os.path.join(lut_folder, '{}_s{}c0_{}.npy'.format(lutName, 4, mode))
                key = "s{}c0_{}".format(4, mode)
                lut_arr = np.load(lut_path).reshape((-1, 1)).astype(np.float32) / 127.0
                self.register_parameter(name="weight_" + key, param=torch.nn.Parameter(torch.Tensor(lut_arr)))

                # conv6
                for c in range(4):
                    lut_path = os.path.join(lut_folder, '{}_s{}c{}_{}.npy'.format(lutName, 6, c, mode))
                    key = "s{}c{}_{}".format(6, c, mode)
                    lut_arr = np.load(lut_path).reshape((-1, self.upscale * self.upscale)).astype(np.float32) / 127.0
                    self.register_parameter(name="weight_" + key, param=torch.nn.Parameter(torch.Tensor(lut_arr)))

            # conv5
            lut_path = os.path.join(lut_folder, '{}_s{}_channel.npy'.format(lutName, 5))
            key = "s{}_channel".format(5)
            lut_arr = np.load(lut_path).reshape((-1, 4)).astype(np.float32) / 127.0
            self.register_parameter(name="weight_" + key, param=torch.nn.Parameter(torch.Tensor(lut_arr)))
You can see that, just as in MuLUT, each LUT is registered as a trainable parameter, which is what makes LUT-aware fine-tuning possible; a toy illustration of the mechanism follows.
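Because each LUT is an nn.Parameter, it receives gradients and is updated by the optimizer like any other weight. A toy illustration of that mechanism (not the project's actual fine-tuning script):

    import numpy as np
    import torch
    import torch.nn as nn

    class TinyLUTModule(nn.Module):
        # A LUT-like array becomes a trainable nn.Parameter, exactly as in SPF_LUT.
        def __init__(self):
            super().__init__()
            lut = np.random.randint(-127, 128, size=(17 ** 4, 2)).astype(np.float32) / 127.0
            self.register_parameter('weight_s1c0_s', nn.Parameter(torch.from_numpy(lut)))

    m = TinyLUTModule()
    optimizer = torch.optim.Adam(m.parameters(), lr=1e-4)
    loss = (m.weight_s1c0_s ** 2).mean()   # stand-in for the SR reconstruction loss
    loss.backward()
    optimizer.step()                       # the LUT entries are updated directly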


4. Testing

Because the LUTs are now split into diagonal and non-diagonal parts, the final table-lookup inference also needs some changes. Taking the 2D diagonal compression as an example, the relevant code is in sr/4_test_SPF_LUT_DFC.py:
    def InterpTorchBatch_compress_xy(weight, img_in, h, w, interval, rot, d, upscale=4, out_c=1, mode='s', ref2index=None):
        q = 2 ** interval  # 16
        L = 2 ** (8 - interval) + 1  # 17
        diag = 2 * d + 1
        N = diag * L + (1 - diag ** 2) // 4

        if mode == "s":
            img_x = img_in[:, :, 0:0 + h, 0:0 + w]
            img_y = img_in[:, :, 0:0 + h, 1:1 + w]
            index_flag = (np.abs(img_x - img_y) <= d * q)

            # Extract MSBs
            img_a1 = img_in[:, :, 0:0 + h, 0:0 + w] // q
            img_b1 = img_in[:, :, 0:0 + h, 1:1 + w] // q
            img_c1 = img_in[:, :, 1:1 + h, 0:0 + w] // q
            img_d1 = img_in[:, :, 1:1 + h, 1:1 + w] // q

            # Extract LSBs
            fa = img_in[:, :, 0:0 + h, 0:0 + w] % q
            fb = img_in[:, :, 0:0 + h, 1:1 + w] % q
            fc = img_in[:, :, 1:1 + h, 0:0 + w] % q
            fd = img_in[:, :, 1:1 + h, 1:1 + w] % q
        elif mode == 'd':
            img_x = img_in[:, :, 0:0 + h, 0:0 + w]
            img_y = img_in[:, :, 0:0 + h, 2:2 + w]
            index_flag = (np.abs(img_x - img_y) <= d * q)

            img_a1 = img_in[:, :, 0:0 + h, 0:0 + w] // q
            img_b1 = img_in[:, :, 0:0 + h, 2:2 + w] // q
            img_c1 = img_in[:, :, 2:2 + h, 0:0 + w] // q
            img_d1 = img_in[:, :, 2:2 + h, 2:2 + w] // q

            fa = img_in[:, :, 0:0 + h, 0:0 + w] % q
            fb = img_in[:, :, 0:0 + h, 2:2 + w] % q
            fc = img_in[:, :, 2:2 + h, 0:0 + w] % q
            fd = img_in[:, :, 2:2 + h, 2:2 + w] % q
        elif mode == 'y':
            img_x = img_in[:, :, 0:0 + h, 0:0 + w]
            img_y = img_in[:, :, 1:1 + h, 1:1 + w]
            index_flag = (np.abs(img_x - img_y) <= d * q)

            img_a1 = img_in[:, :, 0:0 + h, 0:0 + w] // q
            img_b1 = img_in[:, :, 1:1 + h, 1:1 + w] // q
            img_c1 = img_in[:, :, 1:1 + h, 2:2 + w] // q
            img_d1 = img_in[:, :, 2:2 + h, 1:1 + w] // q

            fa = img_in[:, :, 0:0 + h, 0:0 + w] % q
            fb = img_in[:, :, 1:1 + h, 1:1 + w] % q
            fc = img_in[:, :, 1:1 + h, 2:2 + w] % q
            fd = img_in[:, :, 2:2 + h, 1:1 + w] % q
        else:
            # more sampling modes can be implemented similarly
            raise ValueError("Mode {} not implemented.".format(mode))

        img_a1 = img_a1[index_flag].flatten().astype(np.int_)
        img_b1 = img_b1[index_flag].flatten().astype(np.int_)
        img_c1 = img_c1[index_flag].flatten().astype(np.int_)
        img_d1 = img_d1[index_flag].flatten().astype(np.int_)

        fa = fa[index_flag].flatten()
        fb = fb[index_flag].flatten()
        fc = fc[index_flag].flatten()
        fd = fd[index_flag].flatten()

        img_a2 = img_a1 + 1
        img_b2 = img_b1 + 1
        img_c2 = img_c1 + 1
        img_d2 = img_d1 + 1

        k00 = ref2index[img_a1, img_b1 - img_a1]
        k01 = ref2index[img_a1, img_b2 - img_a1]
        k10 = ref2index[img_a2, img_b1 - img_a2]
        k11 = ref2index[img_a2, img_b2 - img_a2]

        p0000 = weight[k00, img_c1, img_d1].reshape((-1, out_c, upscale, upscale))
        p0001 = weight[k00, img_c1, img_d2].reshape((-1, out_c, upscale, upscale))
        p0010 = weight[k00, img_c2, img_d1].reshape((-1, out_c, upscale, upscale))
        p0011 = weight[k00, img_c2, img_d2].reshape((-1, out_c, upscale, upscale))
        p0100 = weight[k01, img_c1, img_d1].reshape((-1, out_c, upscale, upscale))
        p0101 = weight[k01, img_c1, img_d2].reshape((-1, out_c, upscale, upscale))
        p0110 = weight[k01, img_c2, img_d1].reshape((-1, out_c, upscale, upscale))
        p0111 = weight[k01, img_c2, img_d2].reshape((-1, out_c, upscale, upscale))
        p1000 = weight[k10, img_c1, img_d1].reshape((-1, out_c, upscale, upscale))
        p1001 = weight[k10, img_c1, img_d2].reshape((-1, out_c, upscale, upscale))
        p1010 = weight[k10, img_c2, img_d1].reshape((-1, out_c, upscale, upscale))
        p1011 = weight[k10, img_c2, img_d2].reshape((-1, out_c, upscale, upscale))
        p1100 = weight[k11, img_c1, img_d1].reshape((-1, out_c, upscale, upscale))
        p1101 = weight[k11, img_c1, img_d2].reshape((-1, out_c, upscale, upscale))
        p1110 = weight[k11, img_c2, img_d1].reshape((-1, out_c, upscale, upscale))
        p1111 = weight[k11, img_c2, img_d2].reshape((-1, out_c, upscale, upscale))

        # Output image holder
        out = np.zeros((img_a1.shape[0], out_c, upscale, upscale))
        sz = img_a1.shape[0]
        out = out.reshape(sz, -1)

        p0000 = p0000.reshape(sz, -1)
        p0100 = p0100.reshape(sz, -1)
        p1000 = p1000.reshape(sz, -1)
        p1100 = p1100.reshape(sz, -1)
        fa = fa.reshape(-1, 1)

        p0001 = p0001.reshape(sz, -1)
        p0101 = p0101.reshape(sz, -1)
        p1001 = p1001.reshape(sz, -1)
        p1101 = p1101.reshape(sz, -1)
        fb = fb.reshape(-1, 1)

        fc = fc.reshape(-1, 1)
        p0010 = p0010.reshape(sz, -1)
        p0110 = p0110.reshape(sz, -1)
        p1010 = p1010.reshape(sz, -1)
        p1110 = p1110.reshape(sz, -1)

        fd = fd.reshape(-1, 1)
        p0011 = p0011.reshape(sz, -1)
        p0111 = p0111.reshape(sz, -1)
        p1011 = p1011.reshape(sz, -1)
        p1111 = p1111.reshape(sz, -1)

        fab = fa > fb
        fac = fa > fc
        fad = fa > fd
        fbc = fb > fc
        fbd = fb > fd
        fcd = fc > fd

        i1 = i = np.logical_and.reduce((fab, fbc, fcd)).squeeze(1)
        out[i] = (q - fa[i]) * p0000[i] + (fa[i] - fb[i]) * p1000[i] + (fb[i] - fc[i]) * p1100[i] + (fc[i] - fd[i]) * p1110[i] + (fd[i]) * p1111[i]
        i2 = i = np.logical_and.reduce((~i1[:, None], fab, fbc, fbd)).squeeze(1)
        out[i] = (q - fa[i]) * p0000[i] + (fa[i] - fb[i]) * p1000[i] + (fb[i] - fd[i]) * p1100[i] + (fd[i] - fc[i]) * p1101[i] + (fc[i]) * p1111[i]
        i3 = i = np.logical_and.reduce((~i1[:, None], ~i2[:, None], fab, fbc, fad)).squeeze(1)
        out[i] = (q - fa[i]) * p0000[i] + (fa[i] - fd[i]) * p1000[i] + (fd[i] - fb[i]) * p1001[i] + (fb[i] - fc[i]) * p1101[i] + (fc[i]) * p1111[i]
        i4 = i = np.logical_and.reduce((~i1[:, None], ~i2[:, None], ~i3[:, None], fab, fbc)).squeeze(1)
        out[i] = (q - fd[i]) * p0000[i] + (fd[i] - fa[i]) * p0001[i] + (fa[i] - fb[i]) * p1001[i] + (fb[i] - fc[i]) * p1101[i] + (fc[i]) * p1111[i]

        i5 = i = np.logical_and.reduce((~(fbc), fab, fac, fbd)).squeeze(1)
        out[i] = (q - fa[i]) * p0000[i] + (fa[i] - fc[i]) * p1000[i] + (fc[i] - fb[i]) * p1010[i] + (fb[i] - fd[i]) * p1110[i] + (fd[i]) * p1111[i]
        i6 = i = np.logical_and.reduce((~(fbc), ~i5[:, None], fab, fac, fcd)).squeeze(1)
        out[i] = (q - fa[i]) * p0000[i] + (fa[i] - fc[i]) * p1000[i] + (fc[i] - fd[i]) * p1010[i] + (fd[i] - fb[i]) * p1011[i] + (fb[i]) * p1111[i]
        i7 = i = np.logical_and.reduce((~(fbc), ~i5[:, None], ~i6[:, None], fab, fac, fad)).squeeze(1)
        out[i] = (q - fa[i]) * p0000[i] + (fa[i] - fd[i]) * p1000[i] + (fd[i] - fc[i]) * p1001[i] + (fc[i] - fb[i]) * p1011[i] + (fb[i]) * p1111[i]
        i8 = i = np.logical_and.reduce((~(fbc), ~i5[:, None], ~i6[:, None], ~i7[:, None], fab, fac)).squeeze(1)
        out[i] = (q - fd[i]) * p0000[i] + (fd[i] - fa[i]) * p0001[i] + (fa[i] - fc[i]) * p1001[i] + (fc[i] - fb[i]) * p1011[i] + (fb[i]) * p1111[i]

        i9 = i = np.logical_and.reduce((~(fbc), ~(fac), fab, fbd)).squeeze(1)
        out[i] = (q - fc[i]) * p0000[i] + (fc[i] - fa[i]) * p0010[i] + (fa[i] - fb[i]) * p1010[i] + (fb[i] - fd[i]) * p1110[i] + (fd[i]) * p1111[i]
        # Fix the overflow bug in SR-LUT's implementation, should compare fd with fa first!
        i10 = i = np.logical_and.reduce((~(fbc), ~(fac), ~i9[:, None], fab, fad)).squeeze(1)  # c > a > d > b
        out[i] = (q - fc[i]) * p0000[i] + (fc[i] - fa[i]) * p0010[i] + (fa[i] - fd[i]) * p1010[i] + (fd[i] - fb[i]) * p1011[i] + (fb[i]) * p1111[i]
        i11 = i = np.logical_and.reduce((~(fbc), ~(fac), ~i9[:, None], ~i10[:, None], fab, fcd)).squeeze(1)  # c > d > a > b
        out[i] = (q - fc[i]) * p0000[i] + (fc[i] - fd[i]) * p0010[i] + (fd[i] - fa[i]) * p0011[i] + (fa[i] - fb[i]) * p1011[i] + (fb[i]) * p1111[i]
        i12 = i = np.logical_and.reduce((~(fbc), ~(fac), ~i9[:, None], ~i10[:, None], ~i11[:, None], fab)).squeeze(1)
        out[i] = (q - fd[i]) * p0000[i] + (fd[i] - fc[i]) * p0001[i] + (fc[i] - fa[i]) * p0011[i] + (fa[i] - fb[i]) * p1011[i] + (fb[i]) * p1111[i]

        i13 = i = np.logical_and.reduce((~(fab), fac, fcd)).squeeze(1)
        out[i] = (q - fb[i]) * p0000[i] + (fb[i] - fa[i]) * p0100[i] + (fa[i] - fc[i]) * p1100[i] + (fc[i] - fd[i]) * p1110[i] + (fd[i]) * p1111[i]
        i14 = i = np.logical_and.reduce((~(fab), ~i13[:, None], fac, fad)).squeeze(1)
        out[i] = (q - fb[i]) * p0000[i] + (fb[i] - fa[i]) * p0100[i] + (fa[i] - fd[i]) * p1100[i] + (fd[i] - fc[i]) * p1101[i] + (fc[i]) * p1111[i]
        i15 = i = np.logical_and.reduce((~(fab), ~i13[:, None], ~i14[:, None], fac, fbd)).squeeze(1)
        out[i] = (q - fb[i]) * p0000[i] + (fb[i] - fd[i]) * p0100[i] + (fd[i] - fa[i]) * p0101[i] + (fa[i] - fc[i]) * p1101[i] + (fc[i]) * p1111[i]
        i16 = i = np.logical_and.reduce((~(fab), ~i13[:, None], ~i14[:, None], ~i15[:, None], fac)).squeeze(1)
        out[i] = (q - fd[i]) * p0000[i] + (fd[i] - fb[i]) * p0001[i] + (fb[i] - fa[i]) * p0101[i] + (fa[i] - fc[i]) * p1101[i] + (fc[i]) * p1111[i]

        i17 = i = np.logical_and.reduce((~(fab), ~(fac), fbc, fad)).squeeze(1)
        out[i] = (q - fb[i]) * p0000[i] + (fb[i] - fc[i]) * p0100[i] + (fc[i] - fa[i]) * p0110[i] + (fa[i] - fd[i]) * p1110[i] + (fd[i]) * p1111[i]
        i18 = i = np.logical_and.reduce((~(fab), ~(fac), ~i17[:, None], fbc, fcd)).squeeze(1)
        out[i] = (q - fb[i]) * p0000[i] + (fb[i] - fc[i]) * p0100[i] + (fc[i] - fd[i]) * p0110[i] + (fd[i] - fa[i]) * p0111[i] + (fa[i]) * p1111[i]
        i19 = i = np.logical_and.reduce((~(fab), ~(fac), ~i17[:, None], ~i18[:, None], fbc, fbd)).squeeze(1)
        out[i] = (q - fb[i]) * p0000[i] + (fb[i] - fd[i]) * p0100[i] + (fd[i] - fc[i]) * p0101[i] + (fc[i] - fa[i]) * p0111[i] + (fa[i]) * p1111[i]
        i20 = i = np.logical_and.reduce((~(fab), ~(fac), ~i17[:, None], ~i18[:, None], ~i19[:, None], fbc)).squeeze(1)
        out[i] = (q - fd[i]) * p0000[i] + (fd[i] - fb[i]) * p0001[i] + (fb[i] - fc[i]) * p0101[i] + (fc[i] - fa[i]) * p0111[i] + (fa[i]) * p1111[i]

        i21 = i = np.logical_and.reduce((~(fab), ~(fac), ~(fbc), fad)).squeeze(1)
        out[i] = (q - fc[i]) * p0000[i] + (fc[i] - fb[i]) * p0010[i] + (fb[i] - fa[i]) * p0110[i] + (fa[i] - fd[i]) * p1110[i] + (fd[i]) * p1111[i]
        i22 = i = np.logical_and.reduce((~(fab), ~(fac), ~(fbc), ~i21[:, None], fbd)).squeeze(1)
        out[i] = (q - fc[i]) * p0000[i] + (fc[i] - fb[i]) * p0010[i] + (fb[i] - fd[i]) * p0110[i] + (fd[i] - fa[i]) * p0111[i] + (fa[i]) * p1111[i]
        i23 = i = np.logical_and.reduce((~(fab), ~(fac), ~(fbc), ~i21[:, None], ~i22[:, None], fcd)).squeeze(1)
        out[i] = (q - fc[i]) * p0000[i] + (fc[i] - fd[i]) * p0010[i] + (fd[i] - fb[i]) * p0011[i] + (fb[i] - fa[i]) * p0111[i] + (fa[i]) * p1111[i]
        i24 = i = np.logical_and.reduce((~(fab), ~(fac), ~(fbc), ~i21[:, None], ~i22[:, None], ~i23[:, None])).squeeze(1)
        out[i] = (q - fd[i]) * p0000[i] + (fd[i] - fc[i]) * p0001[i] + (fc[i] - fb[i]) * p0011[i] + (fb[i] - fa[i]) * p0111[i] + (fa[i]) * p1111[i]

        out = out / q
        return out, index_flag
As you can see, before the lookup an index_flag is computed. index_flag marks whether a pixel pair satisfies the diagonal condition: if it does, the diagonal LUT is used for the lookup; otherwise the non-diagonal (coarser-sampled) LUT is used instead. You can work through the detailed branching yourself; in practice this pure-Python implementation is rarely what you would actually deploy. A schematic of how the two branches are merged is sketched below.
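To make the dispatch concrete, here is a schematic sketch (hypothetical helper names, simplified 1D shapes, not the project's actual test code) of how the diagonal and non-diagonal lookups could be merged using index_flag:

    import numpy as np

    # interp_diag / interp_coarse stand in for the compressed and larger-interval
    # lookup routines; real code works on image patches rather than flat arrays.
    def lookup_with_fallback(pixels_x, pixels_y, d, q, interp_diag, interp_coarse):
        index_flag = np.abs(pixels_x.astype(np.int32) - pixels_y.astype(np.int32)) <= d * q
        out = np.empty(pixels_x.shape, dtype=np.float32)
        out[index_flag] = interp_diag(pixels_x[index_flag], pixels_y[index_flag])
        out[~index_flag] = interp_coarse(pixels_x[~index_flag], pixels_y[~index_flag])
        return out

    # Trivial stand-in lookups just to show which branch each pixel takes:
    out = lookup_with_fallback(
        np.array([10, 200], dtype=np.uint8), np.array([12, 40], dtype=np.uint8),
        d=2, q=16,
        interp_diag=lambda a, b: (a + b) / 2.0,
        interp_coarse=lambda a, b: (a + b) / 4.0,
    )
    print(out)   # first pixel uses the diagonal branch, second the coarse branch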

That concludes the walkthrough of the SPFLUT code implementation. If anything is unclear, feel free to raise questions.
