在本节中,我们将讨论优化与深度学习之间的关系以及在深度学习中使用优化的挑战。对于一个深度学习问题,我们通常会先定义一个 损失函数。一旦我们有了损失函数,我们就可以使用优化算法来尝试最小化损失。在优化中,损失函数通常被称为优化问题的目标函数。按照传统和约定,大多数优化算法都与最小化有关。如果我们需要最大化目标,有一个简单的解决方案:只需翻转目标上的标志。
12.1.1。优化目标
尽管优化为深度学习提供了一种最小化损失函数的方法,但从本质上讲,优化和深度学习的目标是根本不同的。前者主要关注最小化目标,而后者关注在给定有限数据量的情况下找到合适的模型。在 第 3.6 节中,我们详细讨论了这两个目标之间的区别。例如,训练误差和泛化误差通常不同:由于优化算法的目标函数通常是基于训练数据集的损失函数,因此优化的目标是减少训练误差。然而,深度学习(或更广泛地说,统计推断)的目标是减少泛化误差。为了完成后者,除了使用优化算法来减少训练误差外,我们还需要注意过度拟合。
为了说明上述不同的目标,让我们考虑经验风险和风险。如第 4.7.3.1 节所述 ,经验风险是训练数据集的平均损失,而风险是整个数据群的预期损失。下面我们定义两个函数:风险函数f
和经验风险函数g
。假设我们只有有限数量的训练数据。结果,这里g
不如 平滑f
。
下图说明了训练数据集上经验风险的最小值可能与风险的最小值(泛化误差)位于不同的位置。
def annotate(text, xy, xytext): #@save
d2l.plt.gca().annotate(text, xy=xy, xytext=xytext,
arrowprops=dict(arrowstyle='->'))
x = torch.arange(0.5, 1.5, 0.01)
d2l.set_figsize((4.5, 2.5))
d2l.plot(x, [f(x), g(x)], 'x', 'risk')
annotate('min of\nempirical risk', (1.0, -1.2), (0.5, -1.1))
annotate('min of risk', (1.1, -1.05), (0.95, -0.5))
def annotate(text, xy, xytext): #@save
d2l.plt.gca().annotate(text, xy=xy, xytext=xytext,
arrowprops=dict(arrowstyle='->'))
x = np.arange(0.5, 1.5, 0.01)
d2l.set_figsize((4.5, 2.5))
d2l.plot(x, [f(x), g(x)], 'x', 'risk')
annotate('min of\nempirical risk', (1.0, -1.2), (0.5, -1.1))
annotate('min of risk', (1.1, -1.05), (0.95, -0.5))
def annotate(text, xy, xytext): #@save
d2l.plt.gca().annotate(text, xy=xy, xytext=xytext,
arrowprops=dict(arrowstyle='->'))
x = tf.range(0.5, 1.5, 0.01)
d2l.set_figsize((4.5, 2.5))
d2l.plot(x, [f(x), g(x)], 'x', 'risk')
annotate('min of\nempirical risk', (1.0, -1.2), (0.5, -1.1))
annotate('min of risk', (1.1, -1.05), (0.95, -0.5))
12.1.2。深度学习中的优化挑战
在本章中,我们将特别关注优化算法在最小化目标函数方面的性能,而不是模型的泛化误差。在 3.1 节中,我们区分了优化问题中的解析解和数值解。在深度学习中,大多数目标函数都很复杂,没有解析解。相反,我们必须使用数值优化算法。本章的优化算法都属于这一类。
深度学习优化有很多挑战。一些最令人烦恼的是局部最小值、鞍点和梯度消失。让我们来看看它们。
12.1.2.1。局部最小值
对于任何目标函数f(x), 如果值f(x)在 x小于的值f(x)在附近的任何其他点x, 然后f(x)可能是局部最小值。如果值f(x)在x是整个域内目标函数的最小值,则f(x)是全局最小值。
例如,给定函数
我们可以逼近这个函数的局部最小值和全局最小值。
x = np.arange(-1.0, 2.0, 0.01)
d2l.plot(x, [f(x), ], 'x', 'f(x)')
annotate('local minimum', (-0.3, -0.25), (-0.77, -1.0))
annotate('global minimum',