一旦我们选择了一个架构并设置了我们的超参数,我们就进入训练循环,我们的目标是找到最小化损失函数的参数值。训练后,我们将需要这些参数来进行未来的预测。此外,我们有时会希望提取参数以在其他上下文中重用它们,将我们的模型保存到磁盘以便它可以在其他软件中执行,或者进行检查以期获得科学理解。
大多数时候,我们将能够忽略参数声明和操作的具体细节,依靠深度学习框架来完成繁重的工作。然而,当我们远离具有标准层的堆叠架构时,我们有时需要陷入声明和操作参数的困境。在本节中,我们将介绍以下内容:
-
访问用于调试、诊断和可视化的参数。
-
跨不同模型组件共享参数。
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
我们首先关注具有一个隐藏层的 MLP。
torch.Size([2, 1])
(2, 1)
(2, 1)
6.2.1. 参数访问
让我们从如何从您已知的模型中访问参数开始。
当通过类定义模型时Sequential
,我们可以首先通过索引模型来访问任何层,就好像它是一个列表一样。每个层的参数都方便地位于其属性中。
When a model is defined via the Sequential
class, we can first access any layer by indexing into the model as though it were a list. Each layer’s parameters are conveniently located in its attribute.
Flax and JAX decouple the model and the parameters as you might have observed in the models defined previously. When a model is defined via the Sequential
class, we first need to initialize the network to generate the parameters dictionary. We can access any layer’s parameters through the keys of this dictionary.
When a model is defined via the Sequential
class, we can first access any layer by indexing into the model as though it were a list. Each layer’s parameters are conveniently located in its attribute.
我们可以如下检查第二个全连接层的参数。
OrderedDict([('weight',
tensor([[-0.2523, 0.2104, 0.2189, -0.0395, -0.0590, 0.3360, -0.0205, -0.1507]])),
('bias', tensor([0.0694]))])
dense1_ (
Parameter dense1_weight (shape=(1, 8), dtype=float32)
Parameter dense1_bias (shape=(1,), dtype=float32)
)
FrozenDict({
kernel: Array([[-0.20739523],
[ 0.16546965],
[-0.03713543],
[-0.04860032],
[-0.2102929 ],
[ 0.163712 ],
[ 0.27240783],
[-0.4046879 ]], dtype=float32),
bias: Array([0.], dtype=float32),
})
我们可以看到这个全连接层包含两个参数,分别对应于该层的权重和偏差。
6.2.1.1. 目标参数
请注意,每个参数都表示为参数类的一个实例。要对参数做任何有用的事情,我们首先需要访问基础数值。做这件事有很多种方法。有些更简单,有些则更通用。以下代码从返回参数类实例的第二个神经网络层中提取偏差,并进一步访问该参数的值。
(torch.nn.parameter.Parameter, tensor([0.0694]))
参数是复杂的对象,包含值、梯度和附加信息。这就是为什么我们需要显式请求该值。
除了值之外,每个参数还允许我们访问梯度。因为我们还没有为这个网络调用反向传播,所以它处于初始状态。
True
(mxnet.gluon.parameter.Parameter, array([0.]))
Parameters are complex objects, containing values, gradients, and additional information. That is why we need to request the value explicitly.
In addition to the value, each parameter also allows us to access the gradient. Because we have not invoked backpropagation for this network yet, it is in its initial state.
array([[0., 0., 0., 0., 0., 0., 0., 0.]])
(jaxlib.xla_extension.Array, Array([0.], dtype=float32))
Unlike the other frameworks, JAX does not keep a track of the gradients over the neural network parameters, instead the parameters and the network are decoupled. It allows the user to express their computation as a Python function, and use the grad
transformation for the same purpose.