视频去模糊-STFAN, 达到速度,精度和模型大小同步考虑的SOTA性能

描述

该文是商汤研究院、南京科技以及哈工大联合提出的一种采用动态滤波器卷积进行视频去模糊的方法。

由于相机抖动、目标运动以及景深变化会导致视频中存在spatially variant blur现象。现有的去模糊方法通常在模糊视频中估计光流进行对齐。这种方法会因光流估计的不够精确导致生成的视频存在伪影,或无法有效去除模糊。为克服光流估计的局限性,作者提出了一种的STFAN框架(同时进行对齐和去模糊)。它采用了类似FRVSR的思路进行视频去模糊(它前一帧的模糊以及去模糊图像联合当前帧模糊图像作为输出,经CNN后输出当前帧去模糊图像),其CNN架构采用了空间自使用滤波器进行对齐与去模糊。

作者提出一种新的FAC操作进行对齐,然后从当前帧移除空间可变模糊。最后,作者设计了一个重建网络用于复原清晰的图像。
作者在合成数据与真实数据进行了量化对比分析,所提方法取得了SOTA性能(同时考虑了精度、速度以及模型大小)。

文章作者: Happy

 

Abstract

论文将视频去模糊中的近邻帧对齐与非均匀模糊移除问题建模为element-wise filter adaptive convolution processes。论文的创新点包含:

  • 提出一个滤波器自适应卷积计算单元,它用于特征域的对齐与去模糊;
  • 提出一个新颖的STFAN用于视频去模糊,它集成帧对齐、去模糊到同一个框架中,而无需明显的运动估计;
  • 从精度、速度以及模型大小方面对所提方法进行了量化评估,取得了SOTA性能。

Method


​上图给出了文中所用到的网络架构示意图。从中可以看出,它包含三个子模块:特征提取、STFAN以及重建模块。它的输入包含三个图像(前一帧模糊图像,前一帧去模糊图像以及当前帧模糊图像),由STFAN生成对齐滤波器与去模糊滤波器,然后采用FAC操作进行特征对齐与模糊移除。最后采用重建模块进行清晰图像生成。

FAC

 

 

​关于如何进行代码实现,见文末,这里就不再进行更多的介绍。

网络架构

​如前所述,该网络架构包含三个模块,这里将分别针对三个子模块进行简单的介绍。

  • 特征提取网络。该模块用于从模糊图像$B_t$中提取特征$E_t$,它由三个卷积模块构成,每个卷积模块包含一个stride=2的卷积以及两个残差模块(激活函数选择LeakyReLU)。这里所提取的特征被送入到STFAN模块中采用FAC操作进行去模糊。该网络的配置参数如下:

 

注:上述两个滤波器生成模块均包含一个卷积和两个残差模块并后接一个1x1卷积用于得到期望的输出通道。基于所得到的两组滤波器,采用FAC对前一帧的去模糊特征与当前帧的帧特征进行对齐,同时在特征层面进行模糊移除。最后,将两者进行拼接送入到重建模块中。该模块的配置参数如下:

  • 重建模块。该模块以STFAN中的融合特征作为输入,输出清晰的去模糊图像。该模块的配置参数如下所示。

 

损失函数

​为更有效的训练所提网络,作者考虑如下两种损失函数:

  • MSE Loss。它用于度量去模糊图像R与真实图像S之间的差异。
  • Perceptional Loss。它用于度量去模糊图像R与真实图像S在特征层面的相似性。

总体的损失定义为:
$$ L_{deblur} = L_{mse} + 0.01 * L_{perceptual}$$

Experiments

Dataset

​训练数据源自《Deep video deblurring for hand-held camars》,它包含71个视频(6708对数据),被划分为61用于训练(5708对数据),10个用于测试(1000对数据)。

​在数据增广方面,作者将每个视频划分为长度为20的序列。

​为增广运动的多样性,对序列数据进行随机逆序。对每个序列数据,执行相同的图像变换(包含亮度、对比度调整(从[0.8, 1.2]范围内均匀采样))、几何变换(包含随机水平/垂直镜像)以及裁剪。

​为提升模块在真实场景的鲁棒性,对序列图像添加了高斯(N(0, 0.01))随机噪声。

Training Settings

​整个网络采用kaiming方式进行初始化,采用Adam优化器( [公式] ),初始学习率设为1e-4,每400k迭代乘以0.1。总计迭代次数为900k。

Results

​下面两图给出了不同去模糊方法的量化对比以及视觉效果对比。更多实验结果以及相关分析请查阅原文。
 

小结

​本文提出一种新颖的基于动态滤波器卷积的时空网络用于视频去模块。该网路可以动态的生成用于对齐与去模糊的滤波器。基于所生成滤波器以及FAC单元,该网络可以执行时序对齐与特征去模糊。这种无明显运动估计的方法使得它可以处理动态场景中的空间可变的模块现象。

结合论文所提供的网络架构,以及论文附加文档中所提供的相关参数,简单的整理代码如下:了精度、速度以及模型大小)。

参考代码

结合论文所提供的网络架构,以及论文附加文档中所提供的相关参数,参考代码如下:

import torch
import torch.nn as nn
import torch.nn.functional as F

# 论文中各个模块中设计的残差模块(文中未提到是否有BN层,因此这里未添加BN)。
class ResBlock(nn.Module):
    def __init__(self, inc):
        super(ResBlock, self).__init__()
        self.conv1 = nn.Conv2d(inc, inc, 3, 1, 1)
        self.lrelu = nn.LeakyReLU(0.1)
        self.conv2 = nn.Conv2d(inc, inc, 3, 1, 1)
    def forward(self, x):
        res = self.conv2(self.lrelu(self.conv1(x)))
        return res + x

# 特征提取模块
class FeatureExtract(nn.Module):
    def __init__(self):
        super(FeatureExtract, self).__init__()
        self.net = nn.Sequential(nn.Conv2d(3, 32, 3, 1, 1),
                                ResBlock(32),
                                ResBlock(32),
                                nn.Conv2d(32, 64, 3, 2, 1),
                                ResBlock(64),
                                ResBlock(64),
                                nn.Conv2d(64, 128, 3, 2, 1),
                                ResBlock(128),
                                ResBlock(128))
    def forward(self, x):
        return self.net(x)
    
# 重建模块
# (注:ConvTranspose的参数参考附件文档设置,pad参数是估计所得,只有这组参数能够满足论文的相关特征之间的尺寸关系,如有问题,请反馈更新)
class ReconBlock(nn.Module):
    def __init__(self):
        super(ReconBlock, self).__init__()
        self.up1  = nn.ConvTranspose2d(256, 64, 4, 2, 1)
        self.res2 = ResBlock(64)
        self.res3 = ResBlock(64)
        self.up4  = nn.ConvTranspose2d(64, 32, 3, 2, 1)
        self.res5 = ResBlock(32)
        self.res6 = ResBlock(32)
        self.conv7 = nn.Conv2d(32, 3, 3, 1, 1)

    def forward(self, feat, inputs):
        N, C, H, W = feat.size()
        up1  = self.up1(feat, output_size=(H * 2, W * 2))
        res2 = self.res2(up1)
        res3 = self.res3(res2)
        
        up4  = self.up4(res3, output_size=(H * 4, W * 4))
        res5 = self.res5(up4)
        res6 = self.res6(res5)
        
        conv7 = self.conv7(res6)
        return conv7 + inputs
    
# FAC(以下代码源自本人博客[动态滤波器卷积在CV中的应用]中的相关分析)
def unfold_and_permute(tensor, kernel, stride=1, pad=-1):
    if pad < 0:
        pad = (kernel - 1) // 2
    tensor = F.pad(tensor, (pad, pad, pad, pad))
    tensor = tensor.unfold(2, kernel, stride)
    tensor = tensor.unfold(3, kernel, stride)
    N, C, H, W, _, _ = tensor.size()
    tensor = tensor.reshape(N, C, H, W, -1)
    tensor = tensor.permute(0, 2, 3, 1, 4)
    return tensor

def weight_permute_reshape(tensor, F, S2):
    N, C, H, W = tensor.size()
    tensor = tensor.permute(0, 2, 3, 1)
    tensor = tensor.reshape(N, H, W, F, S2)
    return tensor

# Filter_adaptive_convolution
def FAC(feat, filters, kernel_size):
    N, C, H, W = feat.size()
    pad = (kernel_size - 1) // 2
    feat = unfold_and_permute(feat, kernel_size, 1, pad)
    weight = weight_permute_reshape(filters, C, kernel_size**2)

    output = feat * weight
    output = output.sum(-1)
    output = output.permute(0,3,1,2)
    return output

# Architecture
class STFAN(nn.Module):
    def __init__(self, kernel_size=5):
        super(STFAN, self).__init__()
        filters = 128 * kernel_size ** 2
        self.ext = nn.Sequential(nn.Conv2d(9,32,3,1,1),
                                 ResBlock(32),
                                 ResBlock(32),
                                 nn.Conv2d(32,64,3,2,1),
                                 ResBlock(64),
                                 ResBlock(64),
                                 nn.Conv2d(64,128,3,2,1),
                                 ResBlock(128),
                                 ResBlock(128))
        self.align = nn.Sequential(nn.Conv2d(128,128,3,1,1),
                                   ResBlock(128),
                                   ResBlock(128),
                                   nn.Conv2d(128, filters, 1))
        self.deblur1 = nn.Conv2d(filters, 128, 1, 1)
        self.deblur2 = nn.Sequential(nn.Conv2d(256,128,3,1,1),
                                     ResBlock(128),
                                     ResBlock(128),
                                     nn.Conv2d(128, filters, 1))
        self.conv22 = nn.Conv2d(256, 128, 3, 1, 1)

    def forward(self, pre, cur, dpre):
        ext = self.ext(torch.cat([dpre, pre, cur], dim=1))
        align = self.align(ext)

        d1 = self.deblur1(align)
        deblur = self.deblur2(torch.cat([ext, d1], dim=1))
        return align, deblur


class Net(nn.Module):
    def __init__(self, kernel_size=5):
        super(Net, self).__init__()
        self.kernel_size = kernel_size
        self.feat = FeatureExtract()
        self.stfan = STFAN()
        self.fusion = nn.Conv2d(256, 128, 3, 1, 1)
        self.recon = ReconBlock()

    def forward(self, pre, cur, dpre, fpre):
        # Feature Extract.
        feat = self.feat(cur)

        # STFAN
        align, deblur = self.stfan(pre, cur, dpre)

        # Align and Deblur
        falign  = FAC(fpre, align, self.kernel_size)
        fdeblur = FAC(feat, deblur, self.kernel_size)
        fusion = torch.cat([falign, fdeblur], dim=1)

        # fpre for next
        fpre  = self.fusion(fusion)

        # Reconstruction
        out = self.recon(fusion, cur)
        return out, fpre
    
def demo():
    pre = torch.randn(4, 3, 64, 64)
    cur = torch.randn(4, 3, 64, 64)
    dpre = torch.randn(4, 3, 64, 64)
    fpre = torch.randn(4, 128, 16, 16)

    model = Net()
    model.eval()

    with torch.no_grad():
        output = model(pre, cur, dpre, fpre)
    print(output[0].size())
    print(output[1].size())

if __name__ == "__main__":
    demo()



 

本文章著作权归作者所有,任何形式的转载都请注明出处。更多动态滤波,图像质量,超分辨相关请关注我的专栏深度学习从入门到精通
审核编辑 黄昊宇

打开APP阅读更多精彩内容
声明:本文内容及配图由入驻作者撰写或者入驻合作网站授权转载。文章观点仅代表作者本人,不代表电子发烧友网立场。文章及其配图仅供工程师学习之用,如有内容侵权或者其他违规问题,请联系本站处理。 举报投诉

全部0条评论

快来发表一下你的评论吧 !

×
20
完善资料,
赚取积分