如何进行tflite模型量化

描述

开发AI应用分享来到了本系列最后一期分享。

在windows上,如果我们按照上一期的方式安装了tflite2pb,是不能直接运行的。因为命令行工具是为linux编译的。那我们的广大windows用户就无缘了吗?那可未必,小编给大家带来一个好办法,当然不仅仅适用于当前这个tflite2pb工具,话不多说,开始:

首先让我们找到tflite2pb的安装路径:

WINDOWS

打开tflite2tensorflow.py文件,并定位到5641行:

WINDOWS

我们可以看到,这是函数的主入口,并且添加了对于命令行参数的解析,而且既然是def定义的函数,那我们就可以通过import来进行导入。那接下来小编就要利用这个函数做一些文章,首先在我们编写的转换函数中导入这个叫做main的函数:

 

fromtflite2tensorflow.tflite2tensorflow import main

 

现在已经导入了进来,接下来是怎么传入main函数需要处理的参数,换句话来说,main函数实际上是处理系统命令行参数,那我们需要做的就是伪造一个系统命令行参数,那简单了,首先导入sys模块,然后开始系统参数伪造:

 

sys.argv = ['main.py', f'--model_path={tflite_model_name}.tflite', r'--flatc_path=flatc.exe', '--schema_path=schema.fbs', '--output_pb']

 

这里通过直接给sys.argv参数赋值,实际上相当于像系统命令行传入了参数,接下来直接调用main:

 

main()

 

果不其然,成功运行,和在linux上运行效果一致,我们也获得了saved_model文件夹以及模型:

WINDOWS

下面就是进行模型的量化:

 

# Convert the model
converter = tf.lite.TFLiteConverter.from_saved_model("saved_model")
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.representative_dataset = gen_representative_data_gen(represent_data)
converter.inference_input_type = tf.int8
converter.inference_output_type = tf.int8
tflite_model = converter.convert()
# Save the i8 model.
with open(f"{tflite_model_name}_i8_opt.tflite", "wb") as f:
   f.write(tflite_model)

 

熟悉tflite模型量化的伙伴应该是很熟悉了吧?首先声明converter导入我们刚才生成的pb格式模型的文件夹,接下来指定量化类型为:

 

tf.lite.OpsSet.TFLITE_BUILTINS_INT8

 

指定代表性数据集:

 

converter.representative_dataset = gen_representative_data_gen(represent_data)

 

为了快速验证模型,我们可以先随机出来一些数据来指导模型量化:

 

defgen_representative_data_gen(datas_path):
   datas = np.load(datas_path)
   datas = (np.random.randint(0, 255, size=(10, 1, 192,192,3), dtype='uint8') / 255).astype("float32")
   defrepresentative_data_gen(samples=datas):
       for sample in samples:
           yield [sample]
   returnrepresentative_data_gen

 

并且指定输入输出格式:

 

converter.inference_input_type = tf.int8
       converter.inference_output_type = tf.int8

 

最终保存量化好的模型:

 

with open(f"{tflite_model_name}_i8_opt.tflite", "wb") as f:
f.write(tflite_model)

 

当然,我们还可以打印出来模型的算子信息:

 

intepreter = tf.lite.Interpreter(model_path=f"{tflite_model_name}_i8_opt.tflite")
   op_names = sorted(set([x['op_name'] for x inintepreter._get_ops_details()]))
   print("len ops: ", len(op_names))
   print(op_names)

 

好了,至此,记站在巨人的肩膀上开发AI应用就到此完结了,此次系列分享给大家分享了小编的一次项目开发历程,尤其是涉及到如何将网上找到的,不包含推理代码的模型,如何一步步的进行分析,并最终转换为可以为我们所用的模型,希望能够帮助大家!

打开APP阅读更多精彩内容
声明:本文内容及配图由入驻作者撰写或者入驻合作网站授权转载。文章观点仅代表作者本人,不代表电子发烧友网立场。文章及其配图仅供工程师学习之用,如有内容侵权或者其他违规问题,请联系本站处理。 举报投诉

全部0条评论

快来发表一下你的评论吧 !

×
20
完善资料,
赚取积分