如何在TensorFlow Lite Micro中添加自定义操作符(1)

描述

相信大家在部署嵌入式端的AI应用时,一定使用过TensorFlow Lite Micro,以下简称TFLm。TFLm 是专为微控制器和嵌入式设备设计的轻量级机器学习推理框架,它通过模块化的操作符系统来支持各种神经网络层的计算。也就是说,我们不仅可以使用内嵌的算子运算,还可以自己注册一个新的算子,更加的灵活。本期就将用两期的文章以 `reshape.cpp` 为例,详细说明如何在 TensorFlow Lite Micro 中添加一个新的操作符。

操作符注册不仅是模型推理的基础,更是优化性能、减少内存占用的关键环节。掌握这一机制,开发者可以更灵活地定制算子,满足特定硬件和应用需求。

在 TFLite Micro 中,每个操作符都需要经过以下几个关键步骤:

1. 内核实现:定义操作符的具体计算逻辑

2. 参数解析:从 FlatBuffer 格式中解析操作符参数

3. 操作符注册:将操作符注册到解析器中,使其可被模型调用

4. 内存管理:处理张量的内存分配和释放

操作符实现的核心组件

1. 文件结构说明

添加新操作符需要修改以下几个关键文件,每个文件都有其特定的作用:

micro/kernels/reshape.cpp #

 操作符的核心计算逻辑实现

micro/micro_mutable_op_resolver.h# 

可变操作符解析器,用于动态注册操作符

core/api/flatbuffer_conversions.h # 

FlatBuffer 参数解析函数的声明

core/api/flatbuffer_conversions.cpp # 

FlatBuffer 参数解析函数的具体实现

micro/all_ops_resolver.cpp # 

全局操作符解析器,包含所有支持的操作符

文件作用详解:

`micro/kernels/` 目录:

存放所有操作符的具体实现,每个操作符一个文件

`micro_mutable_op_resolver.h`:

提供灵活的操作符注册接口,允许用户选择性地添加操作符

`flatbuffer_conversions.*`:

处理模型文件中的参数解析,将 FlatBuffer 格式转换为 C++ 结构体

`all_ops_resolver.cpp`:

预定义了所有标准操作符的注册,适用于需要完整操作符支持的场景

2. 核心实现文件分析

 2.1 头文件引入

文件位置:`micro/kernels/reshape.cpp`

 

#include 
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/op_macros.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "tensorflow/lite/micro/memory_helpers.h"
#include "tensorflow/lite/micro/micro_utils.h"

 

头文件说明:

`builtin_op_data.h`:包含所有内置操作符的参数结构体定义

`common.h`:TFLite 的基础数据类型和状态码定义

`tensor_ctypes.h`:张量数据类型相关的工具函数

`kernel_util.h`:操作符实现的通用工具函数

`op_macros.h`:操作符实现中常用的宏定义

`micro/kernels/kernel_util.h`:Micro 版本特有的内核工具函数

`memory_helpers.h`:内存管理相关的辅助函数

`micro_utils.h`:Micro 版本的通用工具函数

 2.2 命名空间和常量定义

 

namespace tflite {
namespace ops {
namespace micro {
namespace reshape {
constexpr int kInputTensor = 0;
constexpr int kOutputTensor = 0;

 

命名空间说明:

`tflite::reshape`:四层命名空间确保了代码的组织性和避免命名冲突

常量定义:`kInputTensor` 和 `kOutputTensor` 定义了输入输出张量的索引,Reshape 操作只有一个输入和一个输出

 2.3 核心函数实现

ReshapeOutput 函数 - 形状计算逻辑

 

TfLiteStatus ReshapeOutput(TfLiteContext* context, TfLiteNode* node) {
  MicroContext* micro_context = GetMicroContext(context);
  // 获取输入和输出张量 - 使用临时分配避免持久内存占用
  TfLiteTensor* input = micro_context->AllocateTempInputTensor(node, kInputTensor);
  TfLiteTensor* output = micro_context->AllocateTempOutputTensor(node, kOutputTensor);
  // 计算输入元素总数 - 用于验证 reshape 操作的合法性
  int num_input_elements = NumElements(input);
  TfLiteIntArray* output_shape = output->dims;
  // 处理特殊情况:-1 维度自动计算
  // TensorFlow 允许一个维度设为 -1,表示根据其他维度自动推断
  int num_output_elements = 1;
  int stretch_dim = -1;
  for (int i = 0; i < output_shape->size; ++i) {
    int value = output_shape->data[i];
    if (value == -1) {
      TF_LITE_ENSURE_EQ(context, stretch_dim, -1);  // 确保只有一个 -1 维度
      stretch_dim = i;
    } else {
      num_output_elements *= value;
    }
  }
  // 如果存在 -1 维度,自动计算其大小
  if (stretch_dim != -1) {
    TfLiteEvalTensor* output_eval = 
        tflite::micro::GetEvalOutput(context, node, kOutputTensor);
    TF_LITE_ENSURE_STATUS(tflite::micro::CreateWritableTensorDimsWithCopy(
        context, output, output_eval));
    output_shape = output->dims;  // 更新形状指针
    output_shape->data[stretch_dim] = num_input_elements / num_output_elements;
    num_output_elements *= output_shape->data[stretch_dim];
  }
  // 确保输入输出元素数量一致 - Reshape 不改变元素总数
  TF_LITE_ENSURE_EQ(context, num_input_elements, num_output_elements);
  TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);  // 确保数据类型一致
  // 释放临时张量 - 避免内存泄漏
  micro_context->DeallocateTempTfLiteTensor(input);
  micro_context->DeallocateTempTfLiteTensor(output);
  return kTfLiteOk;
}

 

函数作用详解:

临时张量分配:使用 `AllocateTempInputTensor` 和 `AllocateTempOutputTensor` 获取张量信息,这些是临时分配,不会占用持久内存

形状验证:确保 reshape 操作的合法性,输入输出元素总数必须相等

自动维度推断:处理 -1 维度的特殊情况,这是 TensorFlow 的标准特性

内存管理:及时释放临时张量,这在内存受限的微控制器环境中非常重要

Prepare 函数-操作符准备阶段:

 

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  // 验证输入输出数量 - Reshape 可以有 1 个或 2 个输入(第二个输入是可选的形状参数)
  TF_LITE_ENSURE(context, NumInputs(node) == 1 || NumInputs(node) == 2);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);  // 只有一个输出
  // 执行输出形状重塑 - 在准备阶段确定最终的输出形状
  TF_LITE_ENSURE_EQ(context, ReshapeOutput(context, node), kTfLiteOk);
  return kTfLiteOk;
}

 

Prepare函数说明:

输入验证:Reshape 操作支持 1-2 个输入,第二个输入是可选的形状张量

形状计算:在准备阶段就确定输出形状,避免在执行阶段重复计算

错误检查:使用 `TF_LITE_ENSURE` 宏进行参数验证,失败时会返回错误状态

Eval 函数-操作符执行阶段:

 

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  // 获取输入输出张量 - 使用 EvalTensor 进行实际计算
  const TfLiteEvalTensor* input = 
      tflite::GetEvalInput(context, node, kInputTensor);
  TfLiteEvalTensor* output = 
      tflite::GetEvalOutput(context, node, kOutputTensor);
  // 计算输入数据大小 - 需要拷贝的字节数
  size_t input_bytes;
  TF_LITE_ENSURE_STATUS(TfLiteTypeSizeOf(input->type, &input_bytes));
  input_bytes *= ElementCount(*input->dims);
  // 执行数据拷贝(如果不是原地操作)
  // 原地操作:输入输出使用同一块内存,无需拷贝
  if (input->data.raw != output->data.raw) {
    memcpy(output->data.raw, input->data.raw, input_bytes);
  }
  return kTfLiteOk;
}

 

Eval函数说明:

EvalTensor 使用:在执行阶段使用 `TfLiteEvalTensor`,它包含实际的数据指针

原地操作优化:检查输入输出是否共享内存,避免不必要的数据拷贝

内存拷贝:Reshape 操作本质上只是改变数据的解释方式,不改变数据内容

 2.4 操作符注册函数

 

TfLiteRegistration_V1 Register_RESHAPE() {
  return tflite::micro::RegisterOp(nullptr, reshape::Prepare, reshape::Eval);
}
}  // namespace reshape
}  // namespace micro
}  // namespace ops
}  // namespace tflite

 

注册函数说明:

RegisterOp 函数:创建操作符注册结构体,包含初始化、准备和执行函数指针

nullptr 参数:第一个参数是初始化函数,Reshape 不需要特殊初始化,所以传入 nullptr

函数指针:传入 Prepare 和 Eval 函数指针,框架会在适当时机调用这些函数

 3. 参数解析实现

 3.1 解析函数声明

文件位置:`core/api/flatbuffer_conversions.h`

 

TfLiteStatus ParseReshape(const Operator* op, ErrorReporter* error_reporter,
                          BuiltinDataAllocator* allocator, void builtin_data);

 

声明说明:

Operator* op:来自 FlatBuffer 的操作符定义,包含所有参数信息

ErrorReporter:用于报告解析过程中的错误

BuiltinDataAllocator:专用的内存分配器,用于分配参数结构体

builtin_data:输出参数,指向解析后的参数结构体

 3.2 解析函数实现

文件位置:`core/api/flatbuffer_conversions.cpp`

 

TfLiteStatus ParseReshape(const Operator* op, ErrorReporter* error_reporter,
                          BuiltinDataAllocator* allocator,
                          void builtin_data) ;

 

解析函数详解:

参数验证:`CheckParsePointerParams` 确保所有指针参数有效

安全分配器:`SafeBuiltinDataAllocator` 提供异常安全的内存分配

FlatBuffer 解析:从序列化的模型文件中提取 reshape 参数

格式转换:将 FlatBuffer 格式转换为 TFLite 内部使用的 C 结构体格式

所有权转移:使用 `release()` 将参数结构体的所有权转移给框架

 3.3 在解析开关中添加对应的case

文件位置:`core/api/flatbuffer_conversions.cpp`

在 `ParseOpData` 函数的 switch 语句中添加:

 

case BuiltinOperator_RESHAPE: {
  return ParseReshape(op, error_reporter, allocator, builtin_data);
}

 

开关语句说明:

这个 switch 语句是 TFLite 参数解析的核心分发机制

根据操作符类型调用相应的解析函数

`BuiltinOperator_RESHAPE` 是在FlatBuffer schema中定义的枚举值

通过本指南,我们深入了解了 TensorFlow Lite Micro 的操作符注册机制,包括其设计理念、实现方式以及在嵌入式场景中的重要性。

未来,随着边缘计算和微控制器 AI 的快速发展,理解并运用这些底层机制将成为构建高效、可扩展 AI 系统的核心能力。建议读者在实践中尝试自定义算子注册,并结合实际项目进行优化,以真正发挥 TensorFlow Lite Micro 的潜力。

打开APP阅读更多精彩内容
声明:本文内容及配图由入驻作者撰写或者入驻合作网站授权转载。文章观点仅代表作者本人,不代表电子发烧友网立场。文章及其配图仅供工程师学习之用,如有内容侵权或者其他违规问题,请联系本站处理。 举报投诉

全部0条评论

快来发表一下你的评论吧 !

×
20
完善资料,
赚取积分