深入GRU:解锁模型测试新维度

描述

之前带大家一起使用Keras训练了一个GRU模型,并使用mnist的手写字体数据集进行了验证。本期小编将继续带来一篇扩展,即GRU模型的测试方法。尽管我们将其当作和CNN类似的方式,一次性传给他固定长度的数据,但在具体实现上来说,还是另有门道的。让我们慢慢讲来。

首先回顾前面我们最终训练并导出的测试模型:

Gru

注意红色标注的位置,这就是一个典型的GRU节点:

Gru

模型的输入是28*28,代表的含义是:时间步*特征维度,简单来说,就是一次性送入模型多久的数据,即时间步实际上是一个时间单位。例如,我们想测试1s的数据进行检测。那么可以将其提取出来10*28的特征向量,那么每一个特征代表的就是1s/10即100ms的音频特征。了解了时间步,再回到模型本身,这里就是其中一种模型推理形式。即一次性将所有的数据都送进去,即1s对应的特征数据。然后计算得出一个结果。好处是:所见所得,和CNN类似,缺点是:必须等待1s数据,且循环time_step次。

那有没有替代方案呢?当然,那就是小编要提到的另一种,我们在导出模型时候将time_step设置为1,并且设置stateful=True,同时将time_step=N的模型权重设置到新模型上。这里的stateful=True要注意,因为我们将之前连续的time_step拆成了独立的,因此需要让模型记住前一次的中间状态。

可能这里大家有些疑问,为什么两个模型time_step不同,权重竟然通用。这就要说GRU模型的特殊性了,我们刚才看到的被展开的小GRU节点,其实是权重共享的。也就是说,不管展开多少次,他们的权重不会变(这个读者可以打开模型自行查看验证)。因此,就可以用如下代码生成新模型,并设置权重:

# 构建新模型
new_model =  Sequential()
new_model.add(GRU(128, batch_input_shape=(1, 1, 28), unroll=True, stateful=True))
new_model.add(Dense(10, activation='softmax'))
new_model.set_weights(model.get_weights())

让我们看看模型的样子:

Gru

是不是看起来非常清爽,请注意右下角那个AssignVariable,这个就是为了保存当前状态,在下一次推理可以直接使用上一次的状态。需要注意的是,由于模型输入变成了一个time step,即1*28,在送入模型前,要注意一下。后续处理部分,依旧是FullyConnected+Softmax的形式,其他没有改变,照常即可。

至此,所有关于GRU模型的介绍以及使用就全部讲完了。MCU端的部署要靠大家自行体验了,因为模型本身实际上可以使用和CNN一样的推理方案,只是内部结构不同而已,希望对大家有所帮助!      

恩智浦致力于打造安全的连接和基础设施解决方案,为智慧生活保驾护航。

打开APP阅读更多精彩内容
声明:本文内容及配图由入驻作者撰写或者入驻合作网站授权转载。文章观点仅代表作者本人,不代表电子发烧友网立场。文章及其配图仅供工程师学习之用,如有内容侵权或者其他违规问题,请联系本站处理。 举报投诉

全部0条评论

快来发表一下你的评论吧 !

×
20
完善资料,
赚取积分