当前位置: 首页 > java >正文

【LUT技术专题】SPFLUT代码解读

目录

原文概要

1. 训练

2. 压缩并转表

3. 微调

4. 测试


本文是对SPFLUT技术的代码解读,原文解读请看SPFLUT。 

原文概要

SPFLUT方法重点在于对角线优先压缩策略,该方法总体流程分为4个部分,训练、转换(里面包含了压缩)、微调、测试。其代码的总体结构如下:

可以看到流程与MULUT基本一致,只不过在第二步转换之前还有一步对LUT进行压缩的过程,即2_compress_lut_from_net.py文件。另外第三步的微调中也有针对压缩后的LUT进行微调的代码。


1. 训练

这里我们可以从sr/model.py中,获取到SPF_LUT_net模型的代码实现如下:

class SPF_LUT_net(nn.Module):def __init__(self, nf=32, scale=4, modes=['s', 'd', 'y'], stages=2):super(SPF_LUT_net, self).__init__()self.upscale = scaleself.modes = modesself.convblock1 = ConvBlock(1, 2, scale=None, output_quant=False, modes=modes, nf=nf)self.convblock2 = ConvBlock(1, 2, scale=None, output_quant=False, modes=modes, nf=nf)self.convblock3 = ConvBlock(1, 2, scale=None, output_quant=False, modes=modes, nf=nf)self.convblock4 = ConvBlock(1, 1, scale=None, output_quant=False, modes=modes, nf=nf)self.ChannelConv = MuLUTcUnit(in_c=4, out_c=4, mode='1x1', nf=nf)self.upblock = ConvBlock(4, 1, scale=scale, output_quant=False, modes=modes, nf=nf)def forward(self, x, phase='train'):B, C, H, W = x.size()x = x.reshape((B * C, 1, H, W))refine_list = []# block1x = self.convblock1(x)avg_factor, bias, norm = len(self.modes) * 4, 127, 255.0x = round_func(torch.clamp((x / avg_factor) + bias, 0, 255)) / normrefine_list.append(x[:, 0:1, :, :])x = x[:, 1:, :, :]# block2x = self.convblock2(x)avg_factor, bias, norm = len(self.modes) * 4, 127, 255.0x = round_func(torch.clamp((x / avg_factor) + bias, 0, 255)) / normrefine_list.append(x[:, 0:1, :, :])x = x[:, 1:, :, :]# block3x = self.convblock3(x)avg_factor, bias, norm = len(self.modes) * 4, 127, 255.0x = round_func(torch.clamp((x / avg_factor) + bias, 0, 255)) / normrefine_list.append(x[:, 0:1, :, :])x = x[:, 1:, :, :]# block4x = self.convblock4(x)avg_factor, bias, norm = len(self.modes) * 4, 127, 255.0x = round_func(torch.clamp((x / avg_factor) + bias, 0, 255)) / normrefine_list.append(x)x = torch.cat(refine_list, dim=1)x = round_func(torch.tanh(self.ChannelConv(x)) * 127.0)x = round_func(torch.clamp(x + 127, 0, 255)) / 255.0x = self.upblock(x)avg_factor, bias, norm = len(self.modes), 0, 1x = round_func((x / avg_factor) + bias)if phase == 'train':x = x / 255.0x = x.reshape((B, C, self.upscale * H, self.upscale * W))return x

通过上述代码可以看出,SPFLUT模型主要由两个子模块构成,一个是ConvBlock,另一个是MuLUTcUnit,其中ConvBlock的实现如下:

class ConvBlock(nn.Module):def __init__(self, in_c, out_c, scale=None, output_quant=False, modes=['s', 'd', 'y'], nf=64):super(ConvBlock, self).__init__()self.in_c = in_cself.out_c = out_cself.modes = modesself.module_dict = dict()self.upscale = scaleself.output_quant = output_quantscale_factor = 1 if scale is None else scale ** 2for c in range(in_c):for mode in modes:self.module_dict['DepthwiseBlock{}_{}'.format(c, mode)] = MuLUTConv('{}x{}'.format(mode.upper(), 'N'),nf=nf, out_c=out_c * scale_factor,stride=1)self.module_dict = nn.ModuleDict(self.module_dict)if scale is None:self.pixel_shuffle = identityelse:self.pixel_shuffle = nn.PixelShuffle(scale)def forward(self, x):modes = self.modesx_out = 0for c in range(self.in_c):x_c = x[:, c:c + 1, :, :]pred = 0for mode in modes:pad = mode_pad_dict[mode]sub_module = self.module_dict['DepthwiseBlock{}_{}'.format(c, mode)]for r in [0, 1, 2, 3]:pred += round_func(torch.tanh(torch.rot90(self.pixel_shuffle(sub_module(F.pad(torch.rot90(x_c, r, [2, 3]), (0, pad, 0, pad), mode='replicate'))),(4 - r) % 4, [2, 3])) * 127)x_out += predif self.output_quant:avg_factor = len(modes) * 4 * self.in_cx = round_func(torch.clamp(x_out / avg_factor, -1, 1) * 127) / 127else:x = x_out / self.in_creturn x

也是由MuLUTConv构成的,位于common/network.py中,而这个模块我们在MuLUT论文代码讲解中有提到,是一个由3种不同类型S、D、Y的kernel组成的一个RF=3x3的模块,这里还需要旋转和clamp等操作,防止每层的结果溢出。

而MuLUTcUnit即是通道上的MuLUT模块,位于common/network.py中,因为只在通道上操作,因此kernel_size上是1,主要建立起特征通道之间的关联。

整体结构是比较清晰的,尤其是对MuLUT的子模块熟悉的情况下,同样的,不清楚的读者可以初始化一个模型来逐步推理tensor的shape来熟悉。


2. 压缩并转表

这部分代码位于2_compress_lut_from_net.py中,整体流程如下:

def compress_SPFLUT(opt):def save_SPFLUT_DFC(x, lut_path, module):# Split input to not over GPU memoryB = x.size(0) // 100outputs = []# Extract input-output pairswith torch.no_grad():model_G.eval()for b in range(100):if b == 99:batch_input = x[b * B:]else:batch_input = x[b * B:(b + 1) * B]batch_output = module(batch_input)results = torch.round(torch.tanh(batch_output) * 127).cpu().data.numpy().astype(np.int8)outputs += [results]results = np.concatenate(outputs, 0)results = results.reshape(x.size(0), -1)np.save(lut_path, results)print("Resulting LUT size: ", results.shape, "Saved to", lut_path)modes = [i for i in opt.modes]stages = opt.stagesmodel = getattr(Model, 'SPF_LUT_net')model_G = model(nf=opt.nf, scale=opt.scale, modes=modes, stages=stages).cuda()lm = torch.load(os.path.join(opt.expDir, 'Model_{:06d}.pth'.format(opt.loadIter)))model_G.load_state_dict(lm, strict=True)input_tensor = get_input_tensor(opt)for mode in modes:if opt.cd == 'xyzt':input_tensor_c1 = compress_lut_xyzt(opt, input_tensor)elif opt.cd == 'xyz':input_tensor_c1 = compress_lut_xyz(opt, input_tensor)elif opt.cd == 'xy':input_tensor_c1 = compress_lut(opt, input_tensor)else:raise ValueErrorinput_tensor_c2 = compress_lut_larger_interval(opt, input_tensor)if mode != 's':input_tensor_c1 = get_mode_input_tensor(input_tensor_c1, mode)input_tensor_c2 = get_mode_input_tensor(input_tensor_c2, mode)# conv1module = model_G.convblock1.module_dict['DepthwiseBlock{}_{}'.format(0, mode)]lut_path = os.path.join(opt.expDir, '{}_s{}c0_{}_compress1.npy'.format(opt.lutName, 1, mode))save_SPFLUT_DFC(input_tensor_c1, lut_path, module)lut_path = os.path.join(opt.expDir, '{}_s{}c0_{}_compress2.npy'.format(opt.lutName, 1, mode))save_SPFLUT_DFC(input_tensor_c2, lut_path, module)# conv2module = model_G.convblock2.module_dict['DepthwiseBlock{}_{}'.format(0, mode)]lut_path = os.path.join(opt.expDir, '{}_s{}c0_{}_compress1.npy'.format(opt.lutName, 2, mode))save_SPFLUT_DFC(input_tensor_c1, lut_path, module)lut_path = os.path.join(opt.expDir, '{}_s{}c0_{}_compress2.npy'.format(opt.lutName, 2, mode))save_SPFLUT_DFC(input_tensor_c2, lut_path, module)# conv3module = model_G.convblock3.module_dict['DepthwiseBlock{}_{}'.format(0, mode)]lut_path = os.path.join(opt.expDir, '{}_s{}c0_{}_compress1.npy'.format(opt.lutName, 3, mode))save_SPFLUT_DFC(input_tensor_c1, lut_path, module)lut_path = os.path.join(opt.expDir, '{}_s{}c0_{}_compress2.npy'.format(opt.lutName, 3, mode))save_SPFLUT_DFC(input_tensor_c2, lut_path, module)# conv4module = model_G.convblock4.module_dict['DepthwiseBlock{}_{}'.format(0, mode)]lut_path = os.path.join(opt.expDir, '{}_s{}c0_{}_compress1.npy'.format(opt.lutName, 4, mode))save_SPFLUT_DFC(input_tensor_c1, lut_path, module)lut_path = os.path.join(opt.expDir, '{}_s{}c0_{}_compress2.npy'.format(opt.lutName, 4, mode))save_SPFLUT_DFC(input_tensor_c2, lut_path, module)# conv6for c in range(4):module = model_G.upblock.module_dict['DepthwiseBlock{}_{}'.format(c, mode)]lut_path = os.path.join(opt.expDir, '{}_s{}c{}_{}_compress1.npy'.format(opt.lutName, 6,c, mode))save_SPFLUT_DFC(input_tensor_c1, lut_path, module)lut_path = os.path.join(opt.expDir, '{}_s{}c{}_{}_compress2.npy'.format(opt.lutName, 6,c, mode))save_SPFLUT_DFC(input_tensor_c2, lut_path, module)# conv5input_tensor = input_tensor.reshape((-1,4,1,1))module = model_G.ChannelConvlut_path = os.path.join(opt.expDir, '{}_s{}_channel.npy'.format(opt.lutName, 5))save_SPFLUT_DFC(input_tensor, lut_path, module)

这里需要关注的细节是对角线压缩相关的3个函数:compress_lut_xyzt、compress_lut_xyz、compress_lut,对应于4维、3维和2维的压缩过程,以及非对角线压缩相关的函数compress_lut_larger_interval,最后我们可以发现对于通道的卷积conv5,作者是没有进行压缩的,因为通道conv不满足对角线先验,故不能进行对角线优先的压缩

针对于对角线相关的函数:以2维压缩为例,跟我们之前的讲解是一样的。

def compress_lut(opt, input_tensor):base = torch.arange(0, 257, 2 ** opt.interval)  # 0-256base[-1] -= 1L = base.size(0)d = opt.dwdiag = 2 * d + 1N = diag * L + (1 - diag ** 2) // 4input_tensor = input_tensor.reshape(L * L, L, L, 1, 2, 2)index_i = torch.zeros((N,)).type(torch.int64)index_j = torch.zeros((N,)).type(torch.int64)cnt = 0ref2index = np.zeros((L, diag), dtype=np.int_) - 1for i in range(L):for j in range(L):if abs(i - j) <= d:index_i[cnt] = iindex_j[cnt] = jref2index[i, j - i] = cntcnt += 1np.save(os.path.join(opt.expDir, 'ref2index_{}{}i{}.npy'.format(opt.cd, opt.dw, opt.si)),ref2index)index_compress = index_i * L + index_jcompressed_input_tensor = input_tensor[index_compress, ...].reshape(-1, 1, 2, 2)return compressed_input_tensor

作者是通过改变input_tensor来实现这个过程,我们需要取到2维tensor,满足对角线距离条件的所有位置,那这里opt.dw(变量d)对应于我们前面讲解中提到的\lambda,满足条件的将其放入ref2index中,并使得cnt加1,这样我们可以将对角线的位置进行保存。

至于L,是我们前面一直在用的与间隔interval相关的个数,一般等于17(4bit采样)。而N是我们前面推理算过的索引的总个数K(大家可以带入diag来计算N,这样可以跟公式完全对应),至此2维的一个输入tensor就全部对应完毕,送入模型计算就可以了,这样子把对角线的位置进行了优先保存。

针对于非对角线的位置:看compress_lut_larger_interval函数,实现如下。


def compress_lut_larger_interval(opt, input_tensor):base = torch.arange(0, 257, 2 ** opt.interval)  # 0-256base[-1] -= 1L = base.size(0)input_tensor = input_tensor.reshape(L, L, L, L, 1, 2, 2)if opt.si==5:k = 2elif opt.si==6:k = 4elif opt.si==7:k = 8else:raise ValueErrorcompressed_input_tensor = input_tensor[::k, ::k, ::k, ::k, ...].reshape(-1, 1, 2, 2)return compressed_input_tensor

比较简单,即选用一个更大的比例,因为我们前面已经使用了4bit来做间隔,那么当opt.si为5时,我们需要对当前的input_tensor做2间隔的采样就可以,之后都是同理可得。

针对于通道:那我们已经讲到了通道是不可以进行压缩的,因此它的input_tensor是不变的,跟之前一样,实现如下,这个过程我们是比较熟悉的,(如果一直有看LUT系列的文章。还不了解的可以关注一下LUT专题哦):

def get_input_tensor(opt):# 1D inputbase = torch.arange(0, 257, 2 ** opt.interval)  # 0-256base[-1] -= 1L = base.size(0)# 2D input# 256*256   0 0 0...    |1 1 1...     |...|255 255 255...first = base.cuda().unsqueeze(1).repeat(1, L).reshape(-1)# 256*256   0 1 2 .. 255|0 1 2 ... 255|...|0 1 2 ... 255second = base.cuda().repeat(L)onebytwo = torch.stack([first, second], 1)  # [256*256, 2]# 3D input# 256*256*256   0 x65536|1 x65536|...|255 x65536third = base.cuda().unsqueeze(1).repeat(1, L * L).reshape(-1)onebytwo = onebytwo.repeat(L, 1)onebythree = torch.cat([third.unsqueeze(1), onebytwo], 1)  # [256*256*256, 3]# 4D inputfourth = base.cuda().unsqueeze(1).repeat(1, L * L * L).reshape(-1)  # 256*256*256*256   0 x16777216|1 x16777216|...|255 x16777216onebythree = onebythree.repeat(L, 1)# [256*256*256*256, 4]onebyfourth = torch.cat([fourth.unsqueeze(1), onebythree], 1)# Rearange input: [N, 4] -> [N, C=1, H=2, W=2]input_tensor = onebyfourth.unsqueeze(1).unsqueeze(1).reshape(-1, 1, 2, 2).float() / 255.0return input_tensor

3. 微调

微调的部分其实跟MuLUT对比,无明显变化,主要还是看作者如何构建SPF_LUT模型,位置在sr/model.py中,代码如下:

class SPF_LUT(nn.Module):""" PyTorch version of MuLUT for LUT-aware fine-tuning. """def __init__(self, lut_folder, stages, modes, lutName, upscale, interval, phase=None, **kwargs):super(SPF_LUT, self).__init__()self.interval = intervalself.upscale = upscaleself.modes = modesself.stages = stagesL = 2 ** (8 - interval) + 1for mode in modes:# conv1lut_path = os.path.join(lut_folder, '{}_s{}c0_{}.npy'.format(lutName, 1, mode))# lut_path = os.path.join(lut_folder, 'LUT_x4_4bit_int8_s{}c0_{}.npy'.format(1, mode))key = "s{}c0_{}".format(1, mode)lut_arr = np.load(lut_path).reshape((-1, 2)).astype(np.float32) / 127.0self.register_parameter(name="weight_" + key, param=torch.nn.Parameter(torch.Tensor(lut_arr)))# conv2lut_path = os.path.join(lut_folder, '{}_s{}c0_{}.npy'.format(lutName, 2, mode))# lut_path = os.path.join(lut_folder, 'LUT_x4_4bit_int8_s{}c0_{}.npy'.format(2, mode))key = "s{}c0_{}".format(2, mode)lut_arr = np.load(lut_path).reshape((-1, 2)).astype(np.float32) / 127.0self.register_parameter(name="weight_" + key, param=torch.nn.Parameter(torch.Tensor(lut_arr)))# conv3lut_path = os.path.join(lut_folder, '{}_s{}c0_{}.npy'.format(lutName, 3, mode))# lut_path = os.path.join(lut_folder, 'LUT_x4_4bit_int8_s{}c0_{}.npy'.format(3, mode))key = "s{}c0_{}".format(3, mode)lut_arr = np.load(lut_path).reshape((-1, 2)).astype(np.float32) / 127.0self.register_parameter(name="weight_" + key, param=torch.nn.Parameter(torch.Tensor(lut_arr)))# conv4lut_path = os.path.join(lut_folder, '{}_s{}c0_{}.npy'.format(lutName, 4, mode))# lut_path = os.path.join(lut_folder, 'LUT_x4_4bit_int8_s{}c0_{}.npy'.format(4, mode))key = "s{}c0_{}".format(4, mode)lut_arr = np.load(lut_path).reshape((-1, 1)).astype(np.float32) / 127.0self.register_parameter(name="weight_" + key, param=torch.nn.Parameter(torch.Tensor(lut_arr)))for c in range(4):# conv6lut_path = os.path.join(lut_folder, '{}_s{}c{}_{}.npy'.format(lutName, 6,c, mode))# lut_path = os.path.join(lut_folder, 'LUT_x4_4bit_int8_s{}c{}_{}.npy'.format(6,c, mode))key = "s{}c{}_{}".format(6,c, mode)lut_arr = np.load(lut_path).reshape((-1, self.upscale * self.upscale)).astype(np.float32) / 127.0self.register_parameter(name="weight_" + key, param=torch.nn.Parameter(torch.Tensor(lut_arr)))# conv5lut_path = os.path.join(lut_folder, '{}_s{}_channel.npy'.format(lutName, 5))# lut_path = os.path.join(lut_folder, 'LUT_x4_4bit_int8_s{}_channel.npy'.format(5))key = "s{}_channel".format(5)lut_arr = np.load(lut_path).reshape((-1, 4)).astype(np.float32) / 127.0self.register_parameter(name="weight_" + key, param=torch.nn.Parameter(torch.Tensor(lut_arr)))

你会发现,其实跟MuLUT一样,将LUT给register成可训练的parameter,这样子去做一个微调。


4. 测试

测试的部分因为我们的LUT做了改变,修改为了对角线和非对角线,因此在最后的查表推理的部分需要做一些改变,以对角线做2维压缩为例,在sr/4_test_SPF_LUT_DFC.py中。

def InterpTorchBatch_compress_xy(weight, img_in, h, w, interval, rot, d, upscale=4, out_c=1, mode='s',ref2index=None):q = 2 ** interval  # 16L = 2 ** (8 - interval) + 1  # 17diag = 2 * d + 1N = diag * L + (1 - diag ** 2) // 4if mode == "s":img_x = img_in[:, :, 0:0 + h, 0:0 + w]img_y = img_in[:, :, 0:0 + h, 1:1 + w]index_flag = (np.abs(img_x - img_y) <= d * q)# Extract MSBsimg_a1 = img_in[:, :, 0:0 + h, 0:0 + w] // qimg_b1 = img_in[:, :, 0:0 + h, 1:1 + w] // qimg_c1 = img_in[:, :, 1:1 + h, 0:0 + w] // qimg_d1 = img_in[:, :, 1:1 + h, 1:1 + w] // q# Extract LSBsfa = img_in[:, :, 0:0 + h, 0:0 + w] % qfb = img_in[:, :, 0:0 + h, 1:1 + w] % qfc = img_in[:, :, 1:1 + h, 0:0 + w] % qfd = img_in[:, :, 1:1 + h, 1:1 + w] % qelif mode == 'd':img_x = img_in[:, :, 0:0 + h, 0:0 + w]img_y = img_in[:, :, 0:0 + h, 2:2 + w]index_flag = (np.abs(img_x - img_y) <= d * q)img_a1 = img_in[:, :, 0:0 + h, 0:0 + w] // qimg_b1 = img_in[:, :, 0:0 + h, 2:2 + w] // qimg_c1 = img_in[:, :, 2:2 + h, 0:0 + w] // qimg_d1 = img_in[:, :, 2:2 + h, 2:2 + w] // qfa = img_in[:, :, 0:0 + h, 0:0 + w] % qfb = img_in[:, :, 0:0 + h, 2:2 + w] % qfc = img_in[:, :, 2:2 + h, 0:0 + w] % qfd = img_in[:, :, 2:2 + h, 2:2 + w] % qelif mode == 'y':img_x = img_in[:, :, 0:0 + h, 0:0 + w]img_y = img_in[:, :, 1:1 + h, 1:1 + w]index_flag = (np.abs(img_x - img_y) <= d * q)img_a1 = img_in[:, :, 0:0 + h, 0:0 + w] // qimg_b1 = img_in[:, :, 1:1 + h, 1:1 + w] // qimg_c1 = img_in[:, :, 1:1 + h, 2:2 + w] // qimg_d1 = img_in[:, :, 2:2 + h, 1:1 + w] // qfa = img_in[:, :, 0:0 + h, 0:0 + w] % qfb = img_in[:, :, 1:1 + h, 1:1 + w] % qfc = img_in[:, :, 1:1 + h, 2:2 + w] % qfd = img_in[:, :, 2:2 + h, 1:1 + w] % qelse:# more sampling modes can be implemented similarlyraise ValueError("Mode {} not implemented.".format(mode))img_a1 = img_a1[index_flag].flatten().astype(np.int_)img_b1 = img_b1[index_flag].flatten().astype(np.int_)img_c1 = img_c1[index_flag].flatten().astype(np.int_)img_d1 = img_d1[index_flag].flatten().astype(np.int_)fa = fa[index_flag].flatten()fb = fb[index_flag].flatten()fc = fc[index_flag].flatten()fd = fd[index_flag].flatten()img_a2 = img_a1 + 1img_b2 = img_b1 + 1img_c2 = img_c1 + 1img_d2 = img_d1 + 1k00 = ref2index[img_a1, img_b1 - img_a1]k01 = ref2index[img_a1, img_b2 - img_a1]k10 = ref2index[img_a2, img_b1 - img_a2]k11 = ref2index[img_a2, img_b2 - img_a2]p0000 = weight[k00,img_c1, img_d1].reshape((-1, out_c,upscale,upscale))p0001 = weight[k00,img_c1, img_d2].reshape((-1, out_c,upscale,upscale))p0010 = weight[k00,img_c2, img_d1].reshape((-1, out_c,upscale,upscale))p0011 = weight[k00,img_c2, img_d2].reshape((-1, out_c,upscale,upscale))p0100 = weight[k01,img_c1, img_d1].reshape((-1, out_c,upscale,upscale))p0101 = weight[k01,img_c1, img_d2].reshape((-1, out_c,upscale,upscale))p0110 = weight[k01,img_c2, img_d1].reshape((-1, out_c,upscale,upscale))p0111 = weight[k01,img_c2, img_d2].reshape((-1, out_c,upscale,upscale))p1000 = weight[k10,img_c1, img_d1].reshape((-1, out_c,upscale,upscale))p1001 = weight[k10,img_c1, img_d2].reshape((-1, out_c,upscale,upscale))p1010 = weight[k10,img_c2, img_d1].reshape((-1, out_c,upscale,upscale))p1011 = weight[k10,img_c2, img_d2].reshape((-1, out_c,upscale,upscale))p1100 = weight[k11,img_c1, img_d1].reshape((-1, out_c,upscale,upscale))p1101 = weight[k11,img_c1, img_d2].reshape((-1, out_c,upscale,upscale))p1110 = weight[k11,img_c2, img_d1].reshape((-1, out_c,upscale,upscale))p1111 = weight[k11,img_c2, img_d2].reshape((-1, out_c,upscale,upscale))# Output image holderout = np.zeros((img_a1.shape[0],out_c, upscale, upscale))sz = img_a1.shape[0]out = out.reshape(sz, -1)p0000 = p0000.reshape(sz, -1)p0100 = p0100.reshape(sz, -1)p1000 = p1000.reshape(sz, -1)p1100 = p1100.reshape(sz, -1)fa = fa.reshape(-1, 1)p0001 = p0001.reshape(sz, -1)p0101 = p0101.reshape(sz, -1)p1001 = p1001.reshape(sz, -1)p1101 = p1101.reshape(sz, -1)fb = fb.reshape(-1, 1)fc = fc.reshape(-1, 1)p0010 = p0010.reshape(sz, -1)p0110 = p0110.reshape(sz, -1)p1010 = p1010.reshape(sz, -1)p1110 = p1110.reshape(sz, -1)fd = fd.reshape(-1, 1)p0011 = p0011.reshape(sz, -1)p0111 = p0111.reshape(sz, -1)p1011 = p1011.reshape(sz, -1)p1111 = p1111.reshape(sz, -1)fab = fa > fb;fac = fa > fc;fad = fa > fdfbc = fb > fc;fbd = fb > fd;fcd = fc > fdi1 = i = np.logical_and.reduce((fab, fbc, fcd)).squeeze(1)# print(p0000[i].shape,fa[i].shape,i.shape,out_c)out[i] = (q - fa[i]) * p0000[i] + (fa[i] - fb[i]) * p1000[i] + (fb[i] - fc[i]) * p1100[i] + (fc[i] - fd[i]) * p1110[i] + (fd[i]) * p1111[i]i2 = i = np.logical_and.reduce((~i1[:, None], fab, fbc, fbd)).squeeze(1)out[i] = (q - fa[i]) * p0000[i] + (fa[i] - fb[i]) * p1000[i] + (fb[i] - fd[i]) * p1100[i] + (fd[i] - fc[i]) * p1101[i] + (fc[i]) * p1111[i]i3 = i = np.logical_and.reduce((~i1[:, None], ~i2[:, None], fab, fbc, fad)).squeeze(1)out[i] = (q - fa[i]) * p0000[i] + (fa[i] - fd[i]) * p1000[i] + (fd[i] - fb[i]) * p1001[i] + (fb[i] - fc[i]) * p1101[i] + (fc[i]) * p1111[i]i4 = i = np.logical_and.reduce((~i1[:, None], ~i2[:, None], ~i3[:, None], fab, fbc)).squeeze(1)out[i] = (q - fd[i]) * p0000[i] + (fd[i] - fa[i]) * p0001[i] + (fa[i] - fb[i]) * p1001[i] + (fb[i] - fc[i]) * p1101[i] + (fc[i]) * p1111[i]i5 = i = np.logical_and.reduce((~(fbc), fab, fac, fbd)).squeeze(1)out[i] = (q - fa[i]) * p0000[i] + (fa[i] - fc[i]) * p1000[i] + (fc[i] - fb[i]) * p1010[i] + (fb[i] - fd[i]) * p1110[i] + (fd[i]) * p1111[i]i6 = i = np.logical_and.reduce((~(fbc), ~i5[:, None], fab, fac, fcd)).squeeze(1)out[i] = (q - fa[i]) * p0000[i] + (fa[i] - fc[i]) * p1000[i] + (fc[i] - fd[i]) * p1010[i] + (fd[i] - fb[i]) * p1011[i] + (fb[i]) * p1111[i]i7 = i = np.logical_and.reduce((~(fbc), ~i5[:, None], ~i6[:, None], fab, fac, fad)).squeeze(1)out[i] = (q - fa[i]) * p0000[i] + (fa[i] - fd[i]) * p1000[i] + (fd[i] - fc[i]) * p1001[i] + (fc[i] - fb[i]) * p1011[i] + (fb[i]) * p1111[i]i8 = i = np.logical_and.reduce((~(fbc), ~i5[:, None], ~i6[:, None], ~i7[:, None], fab, fac)).squeeze(1)out[i] = (q - fd[i]) * p0000[i] + (fd[i] - fa[i]) * p0001[i] + (fa[i] - fc[i]) * p1001[i] + (fc[i] - fb[i]) * p1011[i] + (fb[i]) * p1111[i]i9 = i = np.logical_and.reduce((~(fbc), ~(fac), fab, fbd)).squeeze(1)out[i] = (q - fc[i]) * p0000[i] + (fc[i] - fa[i]) * p0010[i] + (fa[i] - fb[i]) * p1010[i] + (fb[i] - fd[i]) * p1110[i] + (fd[i]) * p1111[i]# Fix the overflow bug in SR-LUT's implementation, should compare fd with fa first!# i10 = i = np.logical_and.reduce((~(fbc), ~(fac), ~i9[:,None], fab, fcd)).squeeze(1)# out[i] = (q-fc[i]) * p0000[i] + (fc[i]-fa[i]) * p0010[i] + (fa[i]-fd[i]) * p1010[i] + (fd[i]-fb[i]) * p1011[i] + (fb[i]) * p1111[i]# i11 = i = np.logical_and.reduce((~(fbc), ~(fac), ~i9[:,None], ~i10[:,None], fab, fad)).squeeze(1)# out[i] = (q-fc[i]) * p0000[i] + (fc[i]-fd[i]) * p0010[i] + (fd[i]-fa[i]) * p0011[i] + (fa[i]-fb[i]) * p1011[i] + (fb[i]) * p1111[i]i10 = i = np.logical_and.reduce((~(fbc), ~(fac), ~i9[:, None], fab, fad)).squeeze(1)  # c > a > d > bout[i] = (q - fc[i]) * p0000[i] + (fc[i] - fa[i]) * p0010[i] + (fa[i] - fd[i]) * p1010[i] + (fd[i] - fb[i]) * p1011[i] + (fb[i]) * p1111[i]i11 = i = np.logical_and.reduce((~(fbc), ~(fac), ~i9[:, None], ~i10[:, None], fab, fcd)).squeeze(1)  # c > d > a > bout[i] = (q - fc[i]) * p0000[i] + (fc[i] - fd[i]) * p0010[i] + (fd[i] - fa[i]) * p0011[i] + (fa[i] - fb[i]) * p1011[i] + (fb[i]) * p1111[i]i12 = i = np.logical_and.reduce((~(fbc), ~(fac), ~i9[:, None], ~i10[:, None], ~i11[:, None], fab)).squeeze(1)out[i] = (q - fd[i]) * p0000[i] + (fd[i] - fc[i]) * p0001[i] + (fc[i] - fa[i]) * p0011[i] + (fa[i] - fb[i]) * p1011[i] + (fb[i]) * p1111[i]i13 = i = np.logical_and.reduce((~(fab), fac, fcd)).squeeze(1)out[i] = (q - fb[i]) * p0000[i] + (fb[i] - fa[i]) * p0100[i] + (fa[i] - fc[i]) * p1100[i] + (fc[i] - fd[i]) * p1110[i] + (fd[i]) * p1111[i]i14 = i = np.logical_and.reduce((~(fab), ~i13[:, None], fac, fad)).squeeze(1)out[i] = (q - fb[i]) * p0000[i] + (fb[i] - fa[i]) * p0100[i] + (fa[i] - fd[i]) * p1100[i] + (fd[i] - fc[i]) * p1101[i] + (fc[i]) * p1111[i]i15 = i = np.logical_and.reduce((~(fab), ~i13[:, None], ~i14[:, None], fac, fbd)).squeeze(1)out[i] = (q - fb[i]) * p0000[i] + (fb[i] - fd[i]) * p0100[i] + (fd[i] - fa[i]) * p0101[i] + (fa[i] - fc[i]) * p1101[i] + (fc[i]) * p1111[i]i16 = i = np.logical_and.reduce((~(fab), ~i13[:, None], ~i14[:, None], ~i15[:, None], fac)).squeeze(1)out[i] = (q - fd[i]) * p0000[i] + (fd[i] - fb[i]) * p0001[i] + (fb[i] - fa[i]) * p0101[i] + (fa[i] - fc[i]) * p1101[i] + (fc[i]) * p1111[i]i17 = i = np.logical_and.reduce((~(fab), ~(fac), fbc, fad)).squeeze(1)out[i] = (q - fb[i]) * p0000[i] + (fb[i] - fc[i]) * p0100[i] + (fc[i] - fa[i]) * p0110[i] + (fa[i] - fd[i]) * p1110[i] + (fd[i]) * p1111[i]i18 = i = np.logical_and.reduce((~(fab), ~(fac), ~i17[:, None], fbc, fcd)).squeeze(1)out[i] = (q - fb[i]) * p0000[i] + (fb[i] - fc[i]) * p0100[i] + (fc[i] - fd[i]) * p0110[i] + (fd[i] - fa[i]) * p0111[i] + (fa[i]) * p1111[i]i19 = i = np.logical_and.reduce((~(fab), ~(fac), ~i17[:, None], ~i18[:, None], fbc, fbd)).squeeze(1)out[i] = (q - fb[i]) * p0000[i] + (fb[i] - fd[i]) * p0100[i] + (fd[i] - fc[i]) * p0101[i] + (fc[i] - fa[i]) * p0111[i] + (fa[i]) * p1111[i]i20 = i = np.logical_and.reduce((~(fab), ~(fac), ~i17[:, None], ~i18[:, None], ~i19[:, None], fbc)).squeeze(1)out[i] = (q - fd[i]) * p0000[i] + (fd[i] - fb[i]) * p0001[i] + (fb[i] - fc[i]) * p0101[i] + (fc[i] - fa[i]) * p0111[i] + (fa[i]) * p1111[i]i21 = i = np.logical_and.reduce((~(fab), ~(fac), ~(fbc), fad)).squeeze(1)out[i] = (q - fc[i]) * p0000[i] + (fc[i] - fb[i]) * p0010[i] + (fb[i] - fa[i]) * p0110[i] + (fa[i] - fd[i]) * p1110[i] + (fd[i]) * p1111[i]i22 = i = np.logical_and.reduce((~(fab), ~(fac), ~(fbc), ~i21[:, None], fbd)).squeeze(1)out[i] = (q - fc[i]) * p0000[i] + (fc[i] - fb[i]) * p0010[i] + (fb[i] - fd[i]) * p0110[i] + (fd[i] - fa[i]) * p0111[i] + (fa[i]) * p1111[i]i23 = i = np.logical_and.reduce((~(fab), ~(fac), ~(fbc), ~i21[:, None], ~i22[:, None], fcd)).squeeze(1)out[i] = (q - fc[i]) * p0000[i] + (fc[i] - fd[i]) * p0010[i] + (fd[i] - fb[i]) * p0011[i] + (fb[i] - fa[i]) * p0111[i] + (fa[i]) * p1111[i]i24 = i = np.logical_and.reduce((~(fab), ~(fac), ~(fbc), ~i21[:, None], ~i22[:, None], ~i23[:, None])).squeeze(1)out[i] = (q - fd[i]) * p0000[i] + (fd[i] - fc[i]) * p0001[i] + (fc[i] - fb[i]) * p0011[i] + (fb[i] - fa[i]) * p0111[i] + (fa[i]) * p1111[i]out = out / qreturn out,index_flag

可以看到查表之前,需要计算一个index_flag,index_flag的定义即是否满足对角线条件,如果满足对角线条件就是通过对角线LUT去查表,否则我们是采用非对角线的LUT去查表,具体的逻辑大家可以去捋一捋,博主认为实际运行也很少会使用python去跑。


以上针对于SPFLUT代码实现的部分讲解完毕,如果有不清楚的问题欢迎大家提出。

http://www.xdnf.cn/news/5848.html

相关文章:

  • Mirror的多人连接管理及房间系统
  • github 上的 CI/CD 的尝试
  • 掌握Multi-Agent实践(五):基于KIMAs的多智能体知识集成系统构建与应用实践
  • 每日算法刷题计划Day5 5.13:leetcode数组3道题,用时1h
  • AFFS2 的 `yaffs_ext_tags` 数据结构详解
  • 大模型MCP_MCP从流式SSE到流式HTTP_1.8.0支持流式HTTP交互_介绍_从应用到最优--人工智能工作笔记0245
  • C++修炼:继承
  • API的学习总结(上)
  • # 08_Elastic Stack 从入门到实践(八)---1
  • 每日Prompt:发光线条解剖图
  • 生信小白学Rust-03
  • 机器学习之决策树模型:从基础概念到条件类型详解
  • 【WIN】笔记本电脑忘记密码解决办法/笔记本电脑重装系统笔记/bitlocker忘记密码的解决办法
  • UDS诊断----------$27诊断服务
  • BFS算法篇——从晨曦到星辰,BFS算法在多源最短路径问题中的诗意航行(上)
  • 3.1 泰勒公式出发点
  • 人脸识别门禁系统技术文档
  • 运行Spark程序-在shell中运行 --SparkConf 和 SparkContext
  • Hadoop和Spark生态系统
  • Java详解LeetCode 热题 100(15):LeetCode 189. 轮转数组(Rotate Array)详解
  • 跨境电商定价革命:亚马逊“逆向提价“策略背后的价值重构逻辑
  • 鸿蒙接入flutter环境变量配置windows-命令行或者手动配置-到项目的创建-运行demo项目
  • (七)深度学习---神经网络原理与实现
  • 在VirtualBox中安装虚拟机后不能全屏显示的问题及解决办法
  • 软考 系统架构设计师系列知识点之杂项集萃(58)
  • 基于Java和PostGIS的AOI面数据球面面积计算实践
  • Kaamel隐私合规洞察:Facebook美容定向广告事件分析
  • Docker环境下的Apache NiFi安装实践踩坑记录
  • 蓝桥杯 16. 外卖店优先级
  • 数据结构——例题1