利用PyTorch实现NeRF代码详解

描述

作者:大森林 | 来源:3DCV

1. NeRF定义

神经辐射场(NeRF)是一种利用神经网络来表示和渲染复杂的三维场景的方法。它可以从一组二维图片中学习出一个连续的三维函数,这个函数可以给出空间中任意位置和方向上的颜色和密度。通过体积渲染的技术,NeRF可以从任意视角合成出逼真的图像,包括透明和半透明物体,以及复杂的光线传播效果。

2. NeRF优势

NeRF模型相比于其他新的视图合成和场景表示方法有以下几个优势:

1)NeRF不需要离散化的三维表示,如网格或体素,因此可以避免模型精度和细节程度受到限制。NeRF也可以自适应地处理不同形状和大小的场景,而不需要人工调整参数。

2)NeRF使用位置编码的方式将位置和角度信息映射到高频域,使得网络能够更好地捕捉场景的细微结构和变化。NeRF还使用视角相关的颜色预测,能够生成不同视角下不同的光照效果。

3)NeRF使用分段随机采样的方式来近似体积渲染的积分,这样可以保证采样位置的连续性,同时避免网络过拟合于离散点的信息。NeRF还使用多层级体素采样的技巧,以提高渲染效率和质量。

3. NeRF实现步骤

1)定义一个全连接的神经网络,它的输入是空间位置和视角方向,输出是颜色和密度。

2)使用位置编码的方式将输入映射到高频域,以便网络能够捕捉细微的结构和变化。

3)使用分段随机采样的方式从每条光线上采样一些点,然后用神经网络预测这些点的颜色和密度。

4)使用体积渲染的公式计算每条光线上的颜色和透明度,作为最终的图像输出。

5)使用渲染损失函数来优化神经网络的参数,使得渲染的图像与输入的图像尽可能接近。

 

import torch
import torch.nn as nn
import torch.nn.functional as F

# 定义一个全连接的神经网络,它的输入是空间位置和视角方向,输出是颜色和密度。
class NeRF(nn.Module):
    def __init__(self, D=8, W=256, input_ch=3, input_ch_views=3, output_ch=4, skips=[4]):
        super().__init__()
        # 定义位置编码后的位置信息的线性层,如果层数在skips列表中,则将原始位置信息与隐藏层拼接
        self.pts_linears = nn.ModuleList(
            [nn.Linear(input_ch, W)] + [nn.Linear(W, W) if i not in skips else nn.Linear(W + input_ch, W) for i in range(D-1)])
        # 定义位置编码后的视角方向信息的线性层
        self.views_linears = nn.ModuleList([nn.Linear(W + input_ch_views, W//2)] + [nn.Linear(W//2, W//2) for i in range(1)])
        # 定义特征向量的线性层
        self.feature_linear = nn.Linear(W//2, W)
        # 定义透明度(alpha)值的线性层
        self.alpha_linear = nn.Linear(W, 1)
        # 定义RGB颜色的线性层
        self.rgb_linear = nn.Linear(W + input_ch_views, 3)

    def forward(self, x):
        # x: (B, input_ch + input_ch_views)
        # 提取位置和视角方向信息
        p = x[:, :3] # (B, 3)
        d = x[:, 3:] # (B, 3)

        # 对输入进行位置编码,将低频信号映射到高频域
        p = positional_encoding(p) # (B, input_ch)
        d = positional_encoding(d) # (B, input_ch_views)

        # 将位置信息输入网络
        h = p
        for i, l in enumerate(self.pts_linears):
            h = l(h)
            h = F.relu(h)
            if i in skips:
                h = torch.cat([h, p], -1) # 如果层数在skips列表中,则将原始位置信息与隐藏层拼接

        # 将视角方向信息与隐藏层拼接,并输入网络
        h = torch.cat([h, d], -1)
        for i, l in enumerate(self.views_linears):
            h = l(h)
            h = F.relu(h)

        # 预测特征向量和透明度(alpha)值
        feature = self.feature_linear(h) # (B, W)
        alpha = self.alpha_linear(feature) # (B, 1)
        
        # 使用特征向量和视角方向信息预测RGB颜色
        rgb = torch.cat([feature, d], -1) 
        rgb = self.rgb_linear(rgb) # (B, 3)

        return torch.cat([rgb, alpha], -1) # (B, 4)

# 定义位置编码函数
def positional_encoding(x):
    # x: (B, C)
    B, C = x.shape
    L = int(C // 2) # 计算位置编码的长度
    freqs = torch.logspace(0., L - 1, steps=L).to(x.device) * math.pi # 计算频率系数,呈指数增长
    freqs = freqs[None].repeat(B, 1) # (B, L)
    x_pos_enc_low = torch.sin(x[:, :L] * freqs) # 对前一半的输入进行正弦变换,得到低频部分 (B, L)
    x_pos_enc_high = torch.cos(x[:, :L] * freqs) # 对前一半的输入进行余弦变换,得到高频部分 (B, L)
    x_pos_enc = torch.cat([x_pos_enc_low, x_pos_enc_high], dim=-1) # 将低频和高频部分拼接,得到位置编码后的输入 (B, C)
    return x_pos_enc

# 定义体积渲染函数
def volume_rendering(rays_o, rays_d, model):
    # rays_o: (B, 3), 每条光线的起点
    # rays_d: (B, 3), 每条光线的方向
    B = rays_o.shape[0]

    # 在每条光线上采样一些点
    near, far = 0., 1. # 近平面和远平面
    N_samples = 64 # 每条光线的采样数
    t_vals = torch.linspace(near, far, N_samples).to(rays_o.device) # (N_samples,)
    t_vals = t_vals.expand(B, N_samples) # (B, N_samples)
    z_vals = near * (1. - t_vals) + far * t_vals # 计算每个采样点的深度值 (B, N_samples)
    z_vals = z_vals.unsqueeze(-1) # (B, N_samples, 1)
    pts = rays_o.unsqueeze(1) + rays_d.unsqueeze(1) * z_vals # 计算每个采样点的空间位置 (B, N_samples, 3)

    # 将采样点和视角方向输入网络
    pts_flat = pts.reshape(-1, 3) # (B*N_samples, 3)
    rays_d_flat = rays_d.unsqueeze(1).expand(-1, N_samples, -1).reshape(-1, 3) # (B*N_samples, 3)
    x_flat = torch.cat([pts_flat, rays_d_flat], -1) # (B*N_samples, 6)
    y_flat = model(x_flat) # (B*N_samples, 4)
    y = y_flat.reshape(B, N_samples, 4) # (B, N_samples, 4)

    # 提取RGB颜色和透明度(alpha)值
    rgb = y[..., :3] # (B, N_samples, 3)
    alpha = y[..., 3] # (B, N_samples)

    # 计算每个采样点的权重
    dists = torch.cat([z_vals[..., 1:] - z_vals[..., :-1], torch.tensor([1e10]).to(z_vals.device).expand(B, 1)], -1) # 计算相邻采样点之间的距离,最后一个距离设为很大的值 (B, N_samples)
    alpha = 1. - torch.exp(-alpha * dists) # 计算每个采样点的不透明度,即1减去透明度的指数衰减 (B, N_samples)
    weights = alpha * torch.cumprod(torch.cat([torch.ones((B, 1)).to(alpha.device), 1. - alpha + 1e-10], -1), -1)[:, :-1] # 计算每个采样点的权重,即不透明度乘以之前所有采样点的透明度累积积,最后一个权重设为0 (B, N_samples)

    # 计算每条光线的最终颜色和透明度
    rgb_map = torch.sum(weights.unsqueeze(-1) * rgb, -2) # 加权平均每个采样点的RGB颜色,得到每条光线的颜色 (B, 3)
    depth_map = torch.sum(weights * z_vals.squeeze(-1), -1) # 加权平均每个采样点的深度值,得到每条光线的深度 (B,)
    acc_map = torch.sum(weights, -1) # 累加每个采样点的权重,得到每条光线的不透明度 (B,)
    
    return rgb_map, depth_map, acc_map

# 定义渲染损失函数
def rendering_loss(rgb_map_pred, rgb_map_gt):
    return ((rgb_map_pred - rgb_map_gt)**2).mean() # 计算预测的颜色与真实颜色之间的均方误差

 

综上所述,本代码实现了NeRF的核心结构,具体实现内容包括以下四个部分。

1)定义了NeRF网络结构,包含位置编码和多层全连接网络,输入是位置和视角,输出是颜色和密度。

2)实现了位置编码函数,通过正弦和余弦变换引入高频信息。

3)实现了体积渲染函数,在光线上采样点,查询NeRF网络预测颜色和密度,然后通过加权平均实现整体渲染。

4)定义了渲染损失函数,计算预测颜色和真实颜色的均方误差。

当然,本方案只是实现NeRF的一个基础方案,更多的细节还需要进行优化。

当然,为了方便下载,我们已经将上述两个源代码打包好了。

  审核编辑:汤梓红

打开APP阅读更多精彩内容
声明:本文内容及配图由入驻作者撰写或者入驻合作网站授权转载。文章观点仅代表作者本人,不代表电子发烧友网立场。文章及其配图仅供工程师学习之用,如有内容侵权或者其他违规问题,请联系本站处理。 举报投诉

全部0条评论

快来发表一下你的评论吧 !

×
20
完善资料,
赚取积分