目录
原文概要
1. 训练
2. 压缩并转表
3. 微调
4. 测试
本文是对SPFLUT技术的代码解读,原文解读请看SPFLUT。
原文概要
SPFLUT方法重点在于对角线优先压缩策略,该方法总体流程分为4个部分:训练、转换(其中包含了压缩)、微调、测试。其代码的总体结构如下:
可以看到流程与MuLUT基本相同,只不过在第二步转换之前还有一步对LUT进行压缩的过程,即2_compress_lut_from_net.py文件。另外第三步的微调中也有针对压缩后的LUT进行微调的代码。
1. 训练
这里我们可以从sr/model.py中,获取到SPF_LUT_net模型的代码实现如下:
class SPF_LUT_net(nn.Module):
    """SPF-LUT network used for training.

    The network chains four ``ConvBlock`` stages, a channel-mixing
    ``MuLUTcUnit`` and a final upsampling ``ConvBlock``.  Stages 1-3 emit two
    channels each: the first is set aside for the channel-mixing unit and the
    remainder is passed on to the next stage.  Stage 4 emits one channel.

    Args:
        nf: number of filters forwarded to every sub-module.
        scale: upscaling factor of the final block.
        modes: sampling modes forwarded to every ``ConvBlock``.
        stages: accepted for interface compatibility; not used in this class.
    """

    def __init__(self, nf=32, scale=4, modes=['s', 'd', 'y'], stages=2):
        super(SPF_LUT_net, self).__init__()
        self.upscale = scale
        self.modes = modes

        # Settings shared by the four cascaded (non-upsampling) blocks.
        cascade_kwargs = dict(scale=None, output_quant=False, modes=modes, nf=nf)
        self.convblock1 = ConvBlock(1, 2, **cascade_kwargs)
        self.convblock2 = ConvBlock(1, 2, **cascade_kwargs)
        self.convblock3 = ConvBlock(1, 2, **cascade_kwargs)
        self.convblock4 = ConvBlock(1, 1, **cascade_kwargs)
        self.ChannelConv = MuLUTcUnit(in_c=4, out_c=4, mode='1x1', nf=nf)
        self.upblock = ConvBlock(4, 1, scale=scale, output_quant=False, modes=modes, nf=nf)

    def _requantize(self, x):
        """Average a cascade block's summed output and map it onto the 0..255 grid, scaled to [0, 1]."""
        avg_factor, bias, norm = len(self.modes) * 4, 127, 255.0
        return round_func(torch.clamp((x / avg_factor) + bias, 0, 255)) / norm

    def forward(self, x, phase='train'):
        """Run the network on ``x`` of shape (B, C, H, W).

        Every input channel is processed independently as a single-channel
        image.  The result is divided by 255 only when ``phase == 'train'``.
        The output has shape (B, C, upscale * H, upscale * W).
        """
        B, C, H, W = x.size()
        x = x.reshape((B * C, 1, H, W))

        cascade = (self.convblock1, self.convblock2, self.convblock3, self.convblock4)
        final_idx = len(cascade) - 1
        refine_list = []
        for idx, block in enumerate(cascade):
            x = self._requantize(block(x))
            if idx == final_idx:
                # The last stage has a single channel; keep all of it.
                refine_list.append(x)
            else:
                # Keep the first channel, forward the rest to the next stage.
                refine_list.append(x[:, 0:1, :, :])
                x = x[:, 1:, :, :]

        # Mix the four kept channels, then re-quantize to the 0..255 grid.
        x = torch.cat(refine_list, dim=1)
        x = round_func(torch.tanh(self.ChannelConv(x)) * 127.0)
        x = round_func(torch.clamp(x + 127, 0, 255)) / 255.0

        # Upsampling block: average over the sampling modes only.
        x = self.upblock(x)
        avg_factor, bias = len(self.modes), 0
        x = round_func((x / avg_factor) + bias)

        if phase == 'train':
            x = x / 255.0
        return x.reshape((B, C, self.upscale * H, self.upscale * W))
通过上述代码可以看出,SPFLUT模型主要由两个子模块构成,一个是ConvBlock,另一个是MuLUTcUnit,其中ConvBlock的实现如下:
class ConvBlock(nn.Module):
    """Bank of per-channel, per-mode ``MuLUTConv`` branches with rotation ensembling.

    One branch is created for every (input channel, mode) pair and stored in
    ``module_dict`` under the key ``'DepthwiseBlock{channel}_{mode}'``.

    Args:
        in_c: number of input channels (each gets its own set of branches).
        out_c: number of output channels per branch before pixel shuffling.
        scale: upscaling factor; ``None`` disables pixel shuffling.
        output_quant: if true, the output is clamped and quantized in ``forward``.
        modes: sampling modes; one branch per mode and channel.
        nf: number of filters forwarded to ``MuLUTConv``.
    """

    def __init__(self, in_c, out_c, scale=None, output_quant=False, modes=['s', 'd', 'y'], nf=64):
        super(ConvBlock, self).__init__()
        self.in_c = in_c
        self.out_c = out_c
        self.modes = modes
        self.upscale = scale
        self.output_quant = output_quant

        # With upscaling, each branch emits scale**2 values per output channel
        # so that PixelShuffle can rearrange them spatially.
        branch_out_c = out_c * (1 if scale is None else scale ** 2)
        branches = {}
        for c in range(in_c):
            for mode in modes:
                key = 'DepthwiseBlock{}_{}'.format(c, mode)
                branches[key] = MuLUTConv('{}x{}'.format(mode.upper(), 'N'),
                                          nf=nf, out_c=branch_out_c,
                                          stride=1)
        self.module_dict = nn.ModuleDict(branches)
        self.pixel_shuffle = identity if scale is None else nn.PixelShuffle(scale)

    def forward(self, x):
        """Sum the quantized outputs of all branches over channels, modes and four rotations."""
        total = 0
        for c in range(self.in_c):
            channel = x[:, c:c + 1, :, :]
            channel_sum = 0
            for mode in self.modes:
                pad = mode_pad_dict[mode]
                branch = self.module_dict['DepthwiseBlock{}_{}'.format(c, mode)]
                for r in range(4):
                    # Rotate, pad on the right/bottom, run the branch, rotate back.
                    rotated = torch.rot90(channel, r, [2, 3])
                    padded = F.pad(rotated, (0, pad, 0, pad), mode='replicate')
                    restored = torch.rot90(self.pixel_shuffle(branch(padded)), (4 - r) % 4, [2, 3])
                    channel_sum += round_func(torch.tanh(restored) * 127)
            total += channel_sum

        if self.output_quant:
            avg_factor = len(self.modes) * 4 * self.in_c
            return round_func(torch.clamp(total / avg_factor, -1, 1) * 127) / 127
        return total / self.in_c
ConvBlock也是由MuLUTConv构成的,MuLUTConv位于common/network.py中,而这个模块我们在MuLUT论文代码讲解中有提到,是一个由3种不同类型(S、D、Y)的kernel组成的RF=3x3的模块,这里还需要旋转和clamp等操作,防止每层的结果溢出。
而MuLUTcUnit即是通道上的MuLUT模块,位于common/network.py中,由于只在通道上操作,因此kernel_size是1,主要建立起特征通道之间的关联。
整体结构是比较清晰的,尤其是在对MuLUT的子模块熟悉的情况下。同样的,不清楚的读者可以初始化一个模型,逐步推理tensor的shape来熟悉。
2. 压缩并转表
这部分代码位于2_compress_lut_from_net.py中,整体流程如下:
def compress_SPFLUT(opt):
    """Export the compressed LUTs of a trained ``SPF_LUT_net`` as ``.npy`` files.

    For every sampling mode and every spatial block (conv1-conv4 and the four
    channel branches of the upsampling block, labelled stage 6) two tables are
    written to ``opt.expDir``: ``*_compress1.npy`` from the diagonal-compressed
    input and ``*_compress2.npy`` from the larger-interval input.  The channel
    unit (stage 5) is exported once from the uncompressed input.

    Raises:
        ValueError: if ``opt.cd`` is not one of ``'xyzt'``, ``'xyz'``, ``'xy'``.
    """

    def save_SPFLUT_DFC(x, lut_path, module):
        """Evaluate ``module`` on ``x`` in 100 chunks and save the int8 table."""
        # Split the input so that it does not exceed GPU memory.
        chunk = x.size(0) // 100
        pieces = []
        with torch.no_grad():
            model_G.eval()
            for b in range(100):
                # The last chunk also takes the remainder.
                stop = None if b == 99 else (b + 1) * chunk
                batch_output = module(x[b * chunk:stop])
                quantized = torch.round(torch.tanh(batch_output) * 127)
                pieces.append(quantized.cpu().data.numpy().astype(np.int8))
        table = np.concatenate(pieces, 0).reshape(x.size(0), -1)
        np.save(lut_path, table)
        print("Resulting LUT size: ", table.shape, "Saved to", lut_path)

    modes = list(opt.modes)
    model_G = Model.SPF_LUT_net(nf=opt.nf, scale=opt.scale, modes=modes, stages=opt.stages).cuda()
    checkpoint = torch.load(os.path.join(opt.expDir, 'Model_{:06d}.pth'.format(opt.loadIter)))
    model_G.load_state_dict(checkpoint, strict=True)

    input_tensor = get_input_tensor(opt)

    for mode in modes:
        # Diagonal-first compression of the sampling grid.
        if opt.cd == 'xyzt':
            input_tensor_c1 = compress_lut_xyzt(opt, input_tensor)
        elif opt.cd == 'xyz':
            input_tensor_c1 = compress_lut_xyz(opt, input_tensor)
        elif opt.cd == 'xy':
            input_tensor_c1 = compress_lut(opt, input_tensor)
        else:
            raise ValueError
        # Coarser sampling of the whole grid.
        input_tensor_c2 = compress_lut_larger_interval(opt, input_tensor)

        if mode != 's':
            input_tensor_c1 = get_mode_input_tensor(input_tensor_c1, mode)
            input_tensor_c2 = get_mode_input_tensor(input_tensor_c2, mode)

        # conv1 .. conv4
        cascade = (model_G.convblock1, model_G.convblock2, model_G.convblock3, model_G.convblock4)
        for stage, block in enumerate(cascade, start=1):
            module = block.module_dict['DepthwiseBlock{}_{}'.format(0, mode)]
            lut_path = os.path.join(opt.expDir, '{}_s{}c0_{}_compress1.npy'.format(opt.lutName, stage, mode))
            save_SPFLUT_DFC(input_tensor_c1, lut_path, module)
            lut_path = os.path.join(opt.expDir, '{}_s{}c0_{}_compress2.npy'.format(opt.lutName, stage, mode))
            save_SPFLUT_DFC(input_tensor_c2, lut_path, module)

        # conv6: one branch per input channel of the upsampling block
        for c in range(4):
            module = model_G.upblock.module_dict['DepthwiseBlock{}_{}'.format(c, mode)]
            lut_path = os.path.join(opt.expDir, '{}_s{}c{}_{}_compress1.npy'.format(opt.lutName, 6,c, mode))
            save_SPFLUT_DFC(input_tensor_c1, lut_path, module)
            lut_path = os.path.join(opt.expDir, '{}_s{}c{}_{}_compress2.npy'.format(opt.lutName, 6,c, mode))
            save_SPFLUT_DFC(input_tensor_c2, lut_path, module)

    # conv5: the channel unit is exported from the full, uncompressed grid
    channel_input = input_tensor.reshape((-1,4,1,1))
    lut_path = os.path.join(opt.expDir, '{}_s{}_channel.npy'.format(opt.lutName, 5))
    save_SPFLUT_DFC(channel_input, lut_path, model_G.ChannelConv)
这里需要关注的细节是对角线压缩相关的3个函数:compress_lut_xyzt、compress_lut_xyz、compress_lut,分别对应于4维、3维和2维的压缩过程,以及非对角线压缩相关的函数compress_lut_larger_interval。最后我们可以发现对于通道的卷积conv5,作者是没有进行压缩的,由于通道conv不满足对角线先验,故不能进行对角线优先的压缩。
针对于对角线相关的函数:以2维压缩为例,跟我们之前的讲解是一样的。
def compress_lut(opt, input_tensor):
    """Keep only the LUT inputs whose first two indices lie within ``opt.dw`` of the diagonal.

    ``input_tensor`` is viewed as an (L*L, L, L, 1, 2, 2) grid, where L is the
    number of sampling points for ``opt.interval``.  Every (i, j) pair with
    ``abs(i - j) <= opt.dw`` is kept, in row-major order.

    Side effect: saves the lookup array ``ref2index`` (shape (L, 2*dw+1),
    indexed by ``[i, j - i]`` with negative offsets wrapping around, ``-1``
    for unused slots) to
    ``opt.expDir/ref2index_{opt.cd}{opt.dw}i{opt.si}.npy``.

    Returns:
        The kept inputs reshaped to (-1, 1, 2, 2).
    """
    L = torch.arange(0, 257, 2 ** opt.interval).size(0)  # sampling points per axis
    d = opt.dw
    diag = 2 * d + 1
    # Number of (i, j) pairs inside the diagonal band.
    N = diag * L + (1 - diag ** 2) // 4

    grid = input_tensor.reshape(L * L, L, L, 1, 2, 2)

    index_i = torch.zeros((N,)).type(torch.int64)
    index_j = torch.zeros((N,)).type(torch.int64)
    ref2index = np.zeros((L, diag), dtype=np.int_) - 1

    band = ((i, j) for i in range(L) for j in range(L) if abs(i - j) <= d)
    for cnt, (i, j) in enumerate(band):
        index_i[cnt] = i
        index_j[cnt] = j
        ref2index[i, j - i] = cnt

    ref_path = os.path.join(opt.expDir, 'ref2index_{}{}i{}.npy'.format(opt.cd, opt.dw, opt.si))
    np.save(ref_path, ref2index)

    index_compress = index_i * L + index_j
    return grid[index_compress, ...].reshape(-1, 1, 2, 2)
作者是通过改变input_tensor来实现这个过程:我们需要取到2维tensor中满足对角线距离条件的所有位置,这里opt.dw(变量d)对应于我们前面讲解中提到的对角线距离阈值,即满足abs(i - j) <= d的位置会被记录到ref2index中,并使得cnt加1,这样我们就可以将对角线附近的位置保存下来。
至于L,是我们前面一直在用的与间隔interval相关的采样点个数,一般等于17(4bit采样)。而N是我们前面推理算过的索引的总个数(大家可以带入diag来计算N,这样可以跟公式完全对应),至此2维的输入tensor就全部对应完毕,送入模型计算就可以了,这样就把对角线的位置进行了优先保存。
针对于非对角线的位置:看compress_lut_larger_interval函数,实现如下。
def compress_lut_larger_interval(opt, input_tensor):
    """Subsample the full 4D LUT input grid with a coarser step.

    ``input_tensor`` is viewed as an (L, L, L, L, 1, 2, 2) grid, where L is the
    number of sampling points for ``opt.interval``.  Every k-th point is kept
    along each of the four axes, with k = 2, 4 or 8 for ``opt.si`` = 5, 6 or 7.

    Returns:
        The subsampled inputs reshaped to (-1, 1, 2, 2).

    Raises:
        ValueError: if ``opt.si`` is not 5, 6 or 7.
    """
    levels = torch.arange(0, 257, 2 ** opt.interval)  # 0-256
    levels[-1] -= 1
    L = levels.size(0)
    grid = input_tensor.reshape(L, L, L, L, 1, 2, 2)

    # Supported (opt.si, stride) pairs.
    for si, k in ((5, 2), (6, 4), (7, 8)):
        if opt.si == si:
            break
    else:
        raise ValueError

    return grid[::k, ::k, ::k, ::k, ...].reshape(-1, 1, 2, 2)
比较简单,即选用一个更大的采样间隔。由于我们前面已经使用了4bit来做间隔,那么当opt.si为5时,我们只需要对当前的input_tensor每隔2个点采样一次就可以,之后都是同理可得。
针对于通道:我们已经讲到了通道是不可以进行压缩的,因此它的input_tensor是不变的,跟之前一样,实现如下。这个过程我们是比较熟悉的(如果一直有看LUT系列的文章的话;还不了解的可以关注一下LUT专题哦):
def get_input_tensor(opt):
    """Build every combination of four sampled pixel values as LUT inputs.

    The sampling levels are ``0, 2**interval, ...`` up to 256, with the last
    level lowered to 255.  With L levels the result enumerates all L**4
    combinations, first value varying slowest and last value fastest.

    Returns:
        A float CUDA tensor of shape (L**4, 1, 2, 2) with values divided by 255.
    """
    levels = torch.arange(0, 257, 2 ** opt.interval)  # 0-256
    levels[-1] -= 1
    L = levels.size(0)
    levels = levels.cuda()

    def column(inner, outer):
        # Each level repeated `inner` times in a row, the whole run tiled `outer` times.
        return levels.unsqueeze(1).repeat(1, inner).reshape(-1).repeat(outer)

    # Column k changes every L**(3 - k) rows, so column 0 is the slowest axis.
    columns = [column(L ** (3 - axis), L ** axis) for axis in range(4)]
    combos = torch.stack(columns, 1)  # [L**4, 4]

    # Rearrange input: [N, 4] -> [N, C=1, H=2, W=2]
    return combos.reshape(-1, 1, 2, 2).float() / 255.0
3. 微调
微调的部分其实跟MuLUT对比,无明显变化,主要还是看作者如何构建SPF_LUT模型,位置在sr/model.py中,代码如下:
class SPF_LUT(nn.Module):
    """Holds exported SPF-LUT tables as trainable parameters for LUT-aware fine-tuning.

    Each ``.npy`` table found in ``lut_folder`` is loaded, reshaped to
    (-1, out_dim), divided by 127 and registered as the parameter
    ``weight_<key>``:

    * ``s{1,2,3}c0_{mode}`` with out_dim 2 and ``s4c0_{mode}`` with out_dim 1,
    * ``s6c{0..3}_{mode}`` with out_dim ``upscale * upscale``,
    * ``s5_channel`` with out_dim 4.

    ``phase`` and extra keyword arguments are accepted but not used here.
    """

    def __init__(self, lut_folder, stages, modes, lutName, upscale, interval, phase=None, **kwargs):
        super(SPF_LUT, self).__init__()
        self.interval = interval
        self.upscale = upscale
        self.modes = modes
        self.stages = stages

        def register_lut(file_name, key, out_dim):
            # Load one table and expose it as a trainable float parameter.
            lut_path = os.path.join(lut_folder, file_name)
            lut_arr = np.load(lut_path).reshape((-1, out_dim)).astype(np.float32) / 127.0
            self.register_parameter(name="weight_" + key, param=torch.nn.Parameter(torch.Tensor(lut_arr)))

        # (stage number, values per LUT entry) for conv1..conv4.
        cascade_dims = ((1, 2), (2, 2), (3, 2), (4, 1))

        for mode in modes:
            for stage, out_dim in cascade_dims:
                register_lut('{}_s{}c0_{}.npy'.format(lutName, stage, mode),
                             "s{}c0_{}".format(stage, mode),
                             out_dim)
            # conv6: one table per input channel of the upsampling block
            for c in range(4):
                register_lut('{}_s{}c{}_{}.npy'.format(lutName, 6,c, mode),
                             "s{}c{}_{}".format(6,c, mode),
                             self.upscale * self.upscale)

        # conv5: the channel unit has a single table shared by all modes
        register_lut('{}_s{}_channel.npy'.format(lutName, 5), "s{}_channel".format(5), 4)
你会发现,其实跟MuLUT一样,将LUT给register成可训练的parameter,这样子去做一个微调。
4. 测试
测试的部分由于我们的LUT做了改变,修改为了对角线和非对角线两部分,因此在最后的查表推理的部分需要做一些改变,以对角线做2维压缩为例,在sr/4_test_SPF_LUT_DFC.py中。
def InterpTorchBatch_compress_xy(weight, img_in, h, w, interval, rot, d, upscale=4, out_c=1, mode='s',ref2index=None):
    """Interpolate LUT outputs for the pixels that fall inside the diagonal band.

    Four values are sampled from ``img_in`` around every output position
    according to ``mode``.  A position is selected when the first two sampled
    values differ by at most ``d * 2**interval``.  For the selected positions
    the sixteen surrounding LUT entries are fetched from the xy-compressed
    table and blended by comparing the four fractional parts (24 cases).

    Despite the name, the body uses NumPy operations (``np.abs``,
    ``np.zeros``, ``.astype``).

    Args:
        weight: compressed LUT, indexed as ``weight[k, c, d]`` where ``k``
            comes from ``ref2index`` and ``c``/``d`` are the quantized third
            and fourth samples.  Each lookup is reshaped to
            (-1, out_c, upscale, upscale).
        img_in: input array indexed as ``[:, :, rows, cols]``; it must extend
            beyond ``h``/``w`` by the sampling offset of ``mode``.
        h, w: height and width of the sampled window.
        interval: number of low bits dropped by quantization (q = 2**interval).
        rot: not used in this function.
        d: half-width of the diagonal band, in quantization steps.
        upscale: upscaling factor of the LUT outputs.
        out_c: number of output channels of the LUT.
        mode: sampling pattern; ``'s'``, ``'d'`` or ``'y'``.
        ref2index: integer array mapping ``[a, b - a]`` to the first index of
            ``weight``.  Required in practice although it defaults to None.

    Returns:
        ``(out, index_flag)``: ``out`` has one row per selected position and
        ``out_c * upscale * upscale`` columns (already divided by q);
        ``index_flag`` is the boolean selection mask.

    Raises:
        ValueError: if ``mode`` is not one of ``'s'``, ``'d'``, ``'y'``.
    """
    q = 2 ** interval  # quantization step (16 when interval == 4)
    L = 2 ** (8 - interval) + 1  # sampling points per axis (17 when interval == 4)
    diag = 2 * d + 1
    # Number of (a, b) pairs inside the diagonal band; computed but not used below.
    N = diag * L + (1 - diag ** 2) // 4
    if mode == "s":
        # Sample offsets: (0, 0), (0, 1), (1, 0), (1, 1).
        img_x = img_in[:, :, 0:0 + h, 0:0 + w]
        img_y = img_in[:, :, 0:0 + h, 1:1 + w]
        # NOTE(review): if img_in has an unsigned integer dtype this subtraction
        # would wrap around - confirm the caller passes signed or float data.
        index_flag = (np.abs(img_x - img_y) <= d * q)
        # Extract MSBs
        img_a1 = img_in[:, :, 0:0 + h, 0:0 + w] // q
        img_b1 = img_in[:, :, 0:0 + h, 1:1 + w] // q
        img_c1 = img_in[:, :, 1:1 + h, 0:0 + w] // q
        img_d1 = img_in[:, :, 1:1 + h, 1:1 + w] // q
        # Extract LSBs
        fa = img_in[:, :, 0:0 + h, 0:0 + w] % q
        fb = img_in[:, :, 0:0 + h, 1:1 + w] % q
        fc = img_in[:, :, 1:1 + h, 0:0 + w] % q
        fd = img_in[:, :, 1:1 + h, 1:1 + w] % q
    elif mode == 'd':
        # Sample offsets: (0, 0), (0, 2), (2, 0), (2, 2).
        img_x = img_in[:, :, 0:0 + h, 0:0 + w]
        img_y = img_in[:, :, 0:0 + h, 2:2 + w]
        index_flag = (np.abs(img_x - img_y) <= d * q)
        img_a1 = img_in[:, :, 0:0 + h, 0:0 + w] // q
        img_b1 = img_in[:, :, 0:0 + h, 2:2 + w] // q
        img_c1 = img_in[:, :, 2:2 + h, 0:0 + w] // q
        img_d1 = img_in[:, :, 2:2 + h, 2:2 + w] // q
        fa = img_in[:, :, 0:0 + h, 0:0 + w] % q
        fb = img_in[:, :, 0:0 + h, 2:2 + w] % q
        fc = img_in[:, :, 2:2 + h, 0:0 + w] % q
        fd = img_in[:, :, 2:2 + h, 2:2 + w] % q
    elif mode == 'y':
        # Sample offsets: (0, 0), (1, 1), (1, 2), (2, 1).
        img_x = img_in[:, :, 0:0 + h, 0:0 + w]
        img_y = img_in[:, :, 1:1 + h, 1:1 + w]
        index_flag = (np.abs(img_x - img_y) <= d * q)
        img_a1 = img_in[:, :, 0:0 + h, 0:0 + w] // q
        img_b1 = img_in[:, :, 1:1 + h, 1:1 + w] // q
        img_c1 = img_in[:, :, 1:1 + h, 2:2 + w] // q
        img_d1 = img_in[:, :, 2:2 + h, 1:1 + w] // q
        fa = img_in[:, :, 0:0 + h, 0:0 + w] % q
        fb = img_in[:, :, 1:1 + h, 1:1 + w] % q
        fc = img_in[:, :, 1:1 + h, 2:2 + w] % q
        fd = img_in[:, :, 2:2 + h, 1:1 + w] % q
    else:
        # more sampling modes can be implemented similarly
        raise ValueError("Mode {} not implemented.".format(mode))
    # Keep only the selected positions; from here on everything is 1-D.
    img_a1 = img_a1[index_flag].flatten().astype(np.int_)
    img_b1 = img_b1[index_flag].flatten().astype(np.int_)
    img_c1 = img_c1[index_flag].flatten().astype(np.int_)
    img_d1 = img_d1[index_flag].flatten().astype(np.int_)
    fa = fa[index_flag].flatten()
    fb = fb[index_flag].flatten()
    fc = fc[index_flag].flatten()
    fd = fd[index_flag].flatten()
    # Upper neighbours on the quantization grid.
    img_a2 = img_a1 + 1
    img_b2 = img_b1 + 1
    img_c2 = img_c1 + 1
    img_d2 = img_d1 + 1
    # Map the four (a, b) corner pairs to rows of the compressed table.
    k00 = ref2index[img_a1, img_b1 - img_a1]
    k01 = ref2index[img_a1, img_b2 - img_a1]
    k10 = ref2index[img_a2, img_b1 - img_a2]
    k11 = ref2index[img_a2, img_b2 - img_a2]
    # Fetch the 16 corner entries; the digits of pABCD give the lower (0) or
    # upper (1) neighbour used for a, b, c and d respectively.
    p0000 = weight[k00,img_c1, img_d1].reshape((-1, out_c,upscale,upscale))
    p0001 = weight[k00,img_c1, img_d2].reshape((-1, out_c,upscale,upscale))
    p0010 = weight[k00,img_c2, img_d1].reshape((-1, out_c,upscale,upscale))
    p0011 = weight[k00,img_c2, img_d2].reshape((-1, out_c,upscale,upscale))
    p0100 = weight[k01,img_c1, img_d1].reshape((-1, out_c,upscale,upscale))
    p0101 = weight[k01,img_c1, img_d2].reshape((-1, out_c,upscale,upscale))
    p0110 = weight[k01,img_c2, img_d1].reshape((-1, out_c,upscale,upscale))
    p0111 = weight[k01,img_c2, img_d2].reshape((-1, out_c,upscale,upscale))
    p1000 = weight[k10,img_c1, img_d1].reshape((-1, out_c,upscale,upscale))
    p1001 = weight[k10,img_c1, img_d2].reshape((-1, out_c,upscale,upscale))
    p1010 = weight[k10,img_c2, img_d1].reshape((-1, out_c,upscale,upscale))
    p1011 = weight[k10,img_c2, img_d2].reshape((-1, out_c,upscale,upscale))
    p1100 = weight[k11,img_c1, img_d1].reshape((-1, out_c,upscale,upscale))
    p1101 = weight[k11,img_c1, img_d2].reshape((-1, out_c,upscale,upscale))
    p1110 = weight[k11,img_c2, img_d1].reshape((-1, out_c,upscale,upscale))
    p1111 = weight[k11,img_c2, img_d2].reshape((-1, out_c,upscale,upscale))
    # Output image holder
    out = np.zeros((img_a1.shape[0],out_c, upscale, upscale))
    sz = img_a1.shape[0]
    # Flatten everything to (sz, -1) rows and the fractions to (sz, 1) columns
    # so that the boolean masks below can index rows directly.
    out = out.reshape(sz, -1)
    p0000 = p0000.reshape(sz, -1)
    p0100 = p0100.reshape(sz, -1)
    p1000 = p1000.reshape(sz, -1)
    p1100 = p1100.reshape(sz, -1)
    fa = fa.reshape(-1, 1)
    p0001 = p0001.reshape(sz, -1)
    p0101 = p0101.reshape(sz, -1)
    p1001 = p1001.reshape(sz, -1)
    p1101 = p1101.reshape(sz, -1)
    fb = fb.reshape(-1, 1)
    fc = fc.reshape(-1, 1)
    p0010 = p0010.reshape(sz, -1)
    p0110 = p0110.reshape(sz, -1)
    p1010 = p1010.reshape(sz, -1)
    p1110 = p1110.reshape(sz, -1)
    fd = fd.reshape(-1, 1)
    p0011 = p0011.reshape(sz, -1)
    p0111 = p0111.reshape(sz, -1)
    p1011 = p1011.reshape(sz, -1)
    p1111 = p1111.reshape(sz, -1)
    # Pairwise comparisons of the fractional parts.
    fab = fa > fb;
    fac = fa > fc;
    fad = fa > fd
    fbc = fb > fc;
    fbd = fb > fd;
    fcd = fc > fd
    # Each mask i1..i24 below selects rows by these comparisons (excluding rows
    # already taken by earlier masks of the same group) and blends five corner
    # entries with weights that sum to q.
    i1 = i = np.logical_and.reduce((fab, fbc, fcd)).squeeze(1)
    out[i] = (q - fa[i]) * p0000[i] + (fa[i] - fb[i]) * p1000[i] + (fb[i] - fc[i]) * p1100[i] + (fc[i] - fd[i]) * p1110[
        i] + (fd[i]) * p1111[i]
    i2 = i = np.logical_and.reduce((~i1[:, None], fab, fbc, fbd)).squeeze(1)
    out[i] = (q - fa[i]) * p0000[i] + (fa[i] - fb[i]) * p1000[i] + (fb[i] - fd[i]) * p1100[i] + (fd[i] - fc[i]) * p1101[
        i] + (fc[i]) * p1111[i]
    i3 = i = np.logical_and.reduce((~i1[:, None], ~i2[:, None], fab, fbc, fad)).squeeze(1)
    out[i] = (q - fa[i]) * p0000[i] + (fa[i] - fd[i]) * p1000[i] + (fd[i] - fb[i]) * p1001[i] + (fb[i] - fc[i]) * p1101[
        i] + (fc[i]) * p1111[i]
    i4 = i = np.logical_and.reduce((~i1[:, None], ~i2[:, None], ~i3[:, None], fab, fbc)).squeeze(1)
    out[i] = (q - fd[i]) * p0000[i] + (fd[i] - fa[i]) * p0001[i] + (fa[i] - fb[i]) * p1001[i] + (fb[i] - fc[i]) * p1101[
        i] + (fc[i]) * p1111[i]
    i5 = i = np.logical_and.reduce((~(fbc), fab, fac, fbd)).squeeze(1)
    out[i] = (q - fa[i]) * p0000[i] + (fa[i] - fc[i]) * p1000[i] + (fc[i] - fb[i]) * p1010[i] + (fb[i] - fd[i]) * p1110[
        i] + (fd[i]) * p1111[i]
    i6 = i = np.logical_and.reduce((~(fbc), ~i5[:, None], fab, fac, fcd)).squeeze(1)
    out[i] = (q - fa[i]) * p0000[i] + (fa[i] - fc[i]) * p1000[i] + (fc[i] - fd[i]) * p1010[i] + (fd[i] - fb[i]) * p1011[
        i] + (fb[i]) * p1111[i]
    i7 = i = np.logical_and.reduce((~(fbc), ~i5[:, None], ~i6[:, None], fab, fac, fad)).squeeze(1)
    out[i] = (q - fa[i]) * p0000[i] + (fa[i] - fd[i]) * p1000[i] + (fd[i] - fc[i]) * p1001[i] + (fc[i] - fb[i]) * p1011[
        i] + (fb[i]) * p1111[i]
    i8 = i = np.logical_and.reduce((~(fbc), ~i5[:, None], ~i6[:, None], ~i7[:, None], fab, fac)).squeeze(1)
    out[i] = (q - fd[i]) * p0000[i] + (fd[i] - fa[i]) * p0001[i] + (fa[i] - fc[i]) * p1001[i] + (fc[i] - fb[i]) * p1011[
        i] + (fb[i]) * p1111[i]
    i9 = i = np.logical_and.reduce((~(fbc), ~(fac), fab, fbd)).squeeze(1)
    out[i] = (q - fc[i]) * p0000[i] + (fc[i] - fa[i]) * p0010[i] + (fa[i] - fb[i]) * p1010[i] + (fb[i] - fd[i]) * p1110[
        i] + (fd[i]) * p1111[i]
    # Fix the overflow bug in SR-LUT's implementation, should compare fd with fa first!
    # The superseded ordering is kept below for reference:
    # i10 = i = np.logical_and.reduce((~(fbc), ~(fac), ~i9[:,None], fab, fcd)).squeeze(1)
    # out[i] = (q-fc[i]) * p0000[i] + (fc[i]-fa[i]) * p0010[i] + (fa[i]-fd[i]) * p1010[i] + (fd[i]-fb[i]) * p1011[i] + (fb[i]) * p1111[i]
    # i11 = i = np.logical_and.reduce((~(fbc), ~(fac), ~i9[:,None], ~i10[:,None], fab, fad)).squeeze(1)
    # out[i] = (q-fc[i]) * p0000[i] + (fc[i]-fd[i]) * p0010[i] + (fd[i]-fa[i]) * p0011[i] + (fa[i]-fb[i]) * p1011[i] + (fb[i]) * p1111[i]
    i10 = i = np.logical_and.reduce((~(fbc), ~(fac), ~i9[:, None], fab, fad)).squeeze(1)  # c > a > d > b
    out[i] = (q - fc[i]) * p0000[i] + (fc[i] - fa[i]) * p0010[i] + (fa[i] - fd[i]) * p1010[i] + (fd[i] - fb[i]) * p1011[
        i] + (fb[i]) * p1111[i]
    i11 = i = np.logical_and.reduce((~(fbc), ~(fac), ~i9[:, None], ~i10[:, None], fab, fcd)).squeeze(1)  # c > d > a > b
    out[i] = (q - fc[i]) * p0000[i] + (fc[i] - fd[i]) * p0010[i] + (fd[i] - fa[i]) * p0011[i] + (fa[i] - fb[i]) * p1011[
        i] + (fb[i]) * p1111[i]
    i12 = i = np.logical_and.reduce((~(fbc), ~(fac), ~i9[:, None], ~i10[:, None], ~i11[:, None], fab)).squeeze(1)
    out[i] = (q - fd[i]) * p0000[i] + (fd[i] - fc[i]) * p0001[i] + (fc[i] - fa[i]) * p0011[i] + (fa[i] - fb[i]) * p1011[
        i] + (fb[i]) * p1111[i]
    i13 = i = np.logical_and.reduce((~(fab), fac, fcd)).squeeze(1)
    out[i] = (q - fb[i]) * p0000[i] + (fb[i] - fa[i]) * p0100[i] + (fa[i] - fc[i]) * p1100[i] + (fc[i] - fd[i]) * p1110[
        i] + (fd[i]) * p1111[i]
    i14 = i = np.logical_and.reduce((~(fab), ~i13[:, None], fac, fad)).squeeze(1)
    out[i] = (q - fb[i]) * p0000[i] + (fb[i] - fa[i]) * p0100[i] + (fa[i] - fd[i]) * p1100[i] + (fd[i] - fc[i]) * p1101[
        i] + (fc[i]) * p1111[i]
    i15 = i = np.logical_and.reduce((~(fab), ~i13[:, None], ~i14[:, None], fac, fbd)).squeeze(1)
    out[i] = (q - fb[i]) * p0000[i] + (fb[i] - fd[i]) * p0100[i] + (fd[i] - fa[i]) * p0101[i] + (fa[i] - fc[i]) * p1101[
        i] + (fc[i]) * p1111[i]
    i16 = i = np.logical_and.reduce((~(fab), ~i13[:, None], ~i14[:, None], ~i15[:, None], fac)).squeeze(1)
    out[i] = (q - fd[i]) * p0000[i] + (fd[i] - fb[i]) * p0001[i] + (fb[i] - fa[i]) * p0101[i] + (fa[i] - fc[i]) * p1101[
        i] + (fc[i]) * p1111[i]
    i17 = i = np.logical_and.reduce((~(fab), ~(fac), fbc, fad)).squeeze(1)
    out[i] = (q - fb[i]) * p0000[i] + (fb[i] - fc[i]) * p0100[i] + (fc[i] - fa[i]) * p0110[i] + (fa[i] - fd[i]) * p1110[
        i] + (fd[i]) * p1111[i]
    i18 = i = np.logical_and.reduce((~(fab), ~(fac), ~i17[:, None], fbc, fcd)).squeeze(1)
    out[i] = (q - fb[i]) * p0000[i] + (fb[i] - fc[i]) * p0100[i] + (fc[i] - fd[i]) * p0110[i] + (fd[i] - fa[i]) * p0111[
        i] + (fa[i]) * p1111[i]
    i19 = i = np.logical_and.reduce((~(fab), ~(fac), ~i17[:, None], ~i18[:, None], fbc, fbd)).squeeze(1)
    out[i] = (q - fb[i]) * p0000[i] + (fb[i] - fd[i]) * p0100[i] + (fd[i] - fc[i]) * p0101[i] + (fc[i] - fa[i]) * p0111[
        i] + (fa[i]) * p1111[i]
    i20 = i = np.logical_and.reduce((~(fab), ~(fac), ~i17[:, None], ~i18[:, None], ~i19[:, None], fbc)).squeeze(1)
    out[i] = (q - fd[i]) * p0000[i] + (fd[i] - fb[i]) * p0001[i] + (fb[i] - fc[i]) * p0101[i] + (fc[i] - fa[i]) * p0111[
        i] + (fa[i]) * p1111[i]
    i21 = i = np.logical_and.reduce((~(fab), ~(fac), ~(fbc), fad)).squeeze(1)
    out[i] = (q - fc[i]) * p0000[i] + (fc[i] - fb[i]) * p0010[i] + (fb[i] - fa[i]) * p0110[i] + (fa[i] - fd[i]) * p1110[
        i] + (fd[i]) * p1111[i]
    i22 = i = np.logical_and.reduce((~(fab), ~(fac), ~(fbc), ~i21[:, None], fbd)).squeeze(1)
    out[i] = (q - fc[i]) * p0000[i] + (fc[i] - fb[i]) * p0010[i] + (fb[i] - fd[i]) * p0110[i] + (fd[i] - fa[i]) * p0111[
        i] + (fa[i]) * p1111[i]
    i23 = i = np.logical_and.reduce((~(fab), ~(fac), ~(fbc), ~i21[:, None], ~i22[:, None], fcd)).squeeze(1)
    out[i] = (q - fc[i]) * p0000[i] + (fc[i] - fd[i]) * p0010[i] + (fd[i] - fb[i]) * p0011[i] + (fb[i] - fa[i]) * p0111[
        i] + (fa[i]) * p1111[i]
    i24 = i = np.logical_and.reduce((~(fab), ~(fac), ~(fbc), ~i21[:, None], ~i22[:, None], ~i23[:, None])).squeeze(1)
    out[i] = (q - fd[i]) * p0000[i] + (fd[i] - fc[i]) * p0001[i] + (fc[i] - fb[i]) * p0011[i] + (fb[i] - fa[i]) * p0111[
        i] + (fa[i]) * p1111[i]
    # The blend weights sum to q, so dividing by q normalizes the result.
    out = out / q
    return out,index_flag
可以看到查表之前,需要计算一个index_flag,index_flag的定义即是否满足对角线条件,如果满足对角线条件就通过对角线LUT去查表,否则采取非对角线的LUT去查表,具体的逻辑大家可以去捋一捋,博主认为实际运行也很少会使用python去跑。
以上针对于SPFLUT代码实现的部分讲解完毕,如果有不清楚的问题欢迎大家提出。
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。
|