×

PyTorch教程6.2之参数管理

消耗积分:0 | 格式:pdf | 大小:0.13 MB | 2023-06-05

廉鼎琮

分享资料个

一旦我们选择了一个架构并设置了我们的超参数,我们就进入训练循环,我们的目标是找到最小化损失函数的参数值。训练后,我们将需要这些参数来进行未来的预测。此外,我们有时会希望提取参数以在其他上下文中重用它们,将我们的模型保存到磁盘以便它可以在其他软件中执行,或者进行检查以期获得科学理解。

大多数时候,我们将能够忽略参数声明和操作的具体细节,依靠深度学习框架来完成繁重的工作。然而,当我们远离具有标准层的堆叠架构时,我们有时需要陷入声明和操作参数的困境。在本节中,我们将介绍以下内容:

  • 访问用于调试、诊断和可视化的参数。

  • 跨不同模型组件共享参数。

import torch
from torch import nn
from mxnet import init, np, npx
from mxnet.gluon import nn

npx.set_np()
import jax
from flax import linen as nn
from jax import numpy as jnp
from d2l import jax as d2l
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
import tensorflow as tf

我们首先关注具有一个隐藏层的 MLP。

net = nn.Sequential(nn.LazyLinear(8),
          nn.ReLU(),
          nn.LazyLinear(1))

X = torch.rand(size=(2, 4))
net(X).shape
torch.Size([2, 1])
net = nn.Sequential()
net.add(nn.Dense(8, activation='relu'))
net.add(nn.Dense(1))
net.initialize() # Use the default initialization method

X = np.random.uniform(size=(2, 4))
net(X).shape
(2, 1)
net = nn.Sequential([nn.Dense(8), nn.relu, nn.Dense(1)])

X = jax.random.uniform(d2l.get_key(), (2, 4))
params = net.init(d2l.get_key(), X)
net.apply(params, X).shape
(2, 1)
net = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(),
  tf.keras.layers.Dense(4, activation=tf.nn.relu),
  tf.keras.layers.Dense(1),
])

X = tf.random.uniform((2, 4))
net(X).shape
TensorShape([2, 1])

6.2.1. 参数访问

让我们从如何从您已知的模型中访问参数开始。

当通过类定义模型时Sequential,我们可以首先通过索引模型来访问任何层,就好像它是一个列表一样。每个层的参数都方便地位于其属性中。

When a model is defined via the Sequential class, we can first access any layer by indexing into the model as though it were a list. Each layer’s parameters are conveniently located in its attribute.

Flax and JAX decouple the model and the parameters as you might have observed in the models defined previously. When a model is defined via the Sequential class, we first need to initialize the network to generate the parameters dictionary. We can access any layer’s parameters through the keys of this dictionary.

When a model is defined via the Sequential class, we can first access any layer by indexing into the model as though it were a list. Each layer’s parameters are conveniently located in its attribute.

我们可以如下检查第二个全连接层的参数。

net[2].state_dict()
OrderedDict([('weight',
       tensor([[-0.2523, 0.2104, 0.2189, -0.0395, -0.0590, 0.3360, -0.0205, -0.1507]])),
       ('bias', tensor([0.0694]))])
net[1].params
dense1_ (
 Parameter dense1_weight (shape=(1, 8), dtype=float32)
 Parameter dense1_bias (shape=(1,), dtype=float32)
)
params['params']['layers_2']
FrozenDict({
  kernel: Array([[-0.20739523],
      [ 0.16546965],
      [-0.03713543],
      [-0.04860032],
      [-0.2102929 ],
      [ 0.163712 ],
      [ 0.27240783],
      [-0.4046879 ]], dtype=float32),
  bias: Array([0.], dtype=float32),
})
net.layers[2].weights
[<tf.Variable 'dense_1/kernel:0' shape=(4, 1) dtype=float32, numpy=
 array([[-0.52124995],
    [-0.22314149],
    [ 0.20780373],
    [ 0.6839919 ]], dtype=float32)>,
 <tf.Variable 'dense_1/bias:0' shape=(1,) dtype=float32, numpy=array([0.], dtype=float32)>]

我们可以看到这个全连接层包含两个参数,分别对应于该层的权重和偏差。

6.2.1.1. 目标参数

请注意,每个参数都表示为参数类的一个实例。要对参数做任何有用的事情,我们首先需要访问基础数值。做这件事有很多种方法。有些更简单,有些则更通用。以下代码从返回参数类实例的第二个神经网络层中提取偏差,并进一步访问该参数的值。

type(net[2].bias), net[2].bias.data
(torch.nn.parameter.Parameter, tensor([0.0694]))

参数是复杂的对象,包含值、梯度和附加信息。这就是为什么我们需要显式请求该值。

除了值之外,每个参数还允许我们访问梯度。因为我们还没有为这个网络调用反向传播,所以它处于初始状态。

net[2].weight.grad == None
True
type(net[1].bias), net[1].bias.data()
(mxnet.gluon.parameter.Parameter, array([0.]))

Parameters are complex objects, containing values, gradients, and additional information. That is why we need to request the value explicitly.

In addition to the value, each parameter also allows us to access the gradient. Because we have not invoked backpropagation for this network yet, it is in its initial state.

net[1].weight.grad()
array([[0., 0., 0., 0., 0., 0., 0., 0.]])
bias = params['params']['layers_2']['bias']
type(bias), bias
(jaxlib.xla_extension.Array, Array([0.], dtype=float32))

Unlike the other frameworks, JAX does not keep a track of the gradients over the neural network parameters, instead the parameters and the network are decoupled. It allows the user to express their computation as a Python function, and use the grad transformation for the same purpose.

type(net.layers[2].weights[1]), tf.convert_to_tensor(net.layers[2].weights[1])
(tensorflow.python.ops.resource_variable_ops.ResourceVariable,
 <tf.Tensor: shape=(1,), dtype=float32, numpy=array([0.], dtype=float32)>)

6.2.1.2. 一次所有参数

当我们需要对所有参数执行操作时,一个一个地访问它们会变得乏味。当我们使用更复杂的模块(例如,嵌套模块)时,情况会变得特别笨拙,因为我们需要递归遍历整个树以提取每个子模块的参数。下面我们演示访问所有层的参数。

[(name, param.shape) for name, param in net.named_parameters()]
[('0.weight', torch

声明:本文内容及配图由入驻作者撰写或者入驻合作网站授权转载。文章观点仅代表作者本人,不代表电子发烧友网立场。文章及其配图仅供工程师学习之用,如有内容侵权或者其他违规问题,请联系本站处理。 举报投诉

评论(0)
发评论

下载排行榜

全部0条评论

快来发表一下你的评论吧 !