相信大家在部署嵌入式端的AI应用时,一定使用过TensorFlow Lite Micro,以下简称TFLm。TFLm 是专为微控制器和嵌入式设备设计的轻量级机器学习推理框架,它通过模块化的操作符系统来支持各种神经网络层的计算。也就是说,我们不仅可以使用内嵌的算子运算,还可以自己注册一个新的算子,更加的灵活。本期就将用两期的文章以 `reshape.cpp` 为例,详细说明如何在 TensorFlow Lite Micro 中添加一个新的操作符。
操作符注册不仅是模型推理的基础,更是优化性能、减少内存占用的关键环节。掌握这一机制,开发者可以更灵活地定制算子,满足特定硬件和应用需求。
在 TFLite Micro 中,每个操作符都需要经过以下几个关键步骤:
1. 内核实现:定义操作符的具体计算逻辑
2. 参数解析:从 FlatBuffer 格式中解析操作符参数
3. 操作符注册:将操作符注册到解析器中,使其可被模型调用
4. 内存管理:处理张量的内存分配和释放
操作符实现的核心组件
1. 文件结构说明
添加新操作符需要修改以下几个关键文件,每个文件都有其特定的作用:
micro/kernels/reshape.cpp #
操作符的核心计算逻辑实现
micro/micro_mutable_op_resolver.h#
可变操作符解析器,用于动态注册操作符
core/api/flatbuffer_conversions.h #
FlatBuffer 参数解析函数的声明
core/api/flatbuffer_conversions.cpp #
FlatBuffer 参数解析函数的具体实现
micro/all_ops_resolver.cpp #
全局操作符解析器,包含所有支持的操作符
文件作用详解:
`micro/kernels/` 目录:
存放所有操作符的具体实现,每个操作符一个文件
`micro_mutable_op_resolver.h`:
提供灵活的操作符注册接口,允许用户选择性地添加操作符
`flatbuffer_conversions.*`:
处理模型文件中的参数解析,将 FlatBuffer 格式转换为 C++ 结构体
`all_ops_resolver.cpp`:
预定义了所有标准操作符的注册,适用于需要完整操作符支持的场景
2. 核心实现文件分析
2.1 头文件引入
文件位置:`micro/kernels/reshape.cpp`
#include#include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" #include "tensorflow/lite/kernels/kernel_util.h" #include "tensorflow/lite/kernels/op_macros.h" #include "tensorflow/lite/micro/kernels/kernel_util.h" #include "tensorflow/lite/micro/memory_helpers.h" #include "tensorflow/lite/micro/micro_utils.h"
头文件说明:
`builtin_op_data.h`:包含所有内置操作符的参数结构体定义
`common.h`:TFLite 的基础数据类型和状态码定义
`tensor_ctypes.h`:张量数据类型相关的工具函数
`kernel_util.h`:操作符实现的通用工具函数
`op_macros.h`:操作符实现中常用的宏定义
`micro/kernels/kernel_util.h`:Micro 版本特有的内核工具函数
`memory_helpers.h`:内存管理相关的辅助函数
`micro_utils.h`:Micro 版本的通用工具函数
2.2 命名空间和常量定义
namespace tflite {
namespace ops {
namespace micro {
namespace reshape {
constexpr int kInputTensor = 0;
constexpr int kOutputTensor = 0;
命名空间说明:
`tflite::reshape`:四层命名空间确保了代码的组织性和避免命名冲突
常量定义:`kInputTensor` 和 `kOutputTensor` 定义了输入输出张量的索引,Reshape 操作只有一个输入和一个输出
2.3 核心函数实现
ReshapeOutput 函数 - 形状计算逻辑
TfLiteStatus ReshapeOutput(TfLiteContext* context, TfLiteNode* node) {
MicroContext* micro_context = GetMicroContext(context);
// 获取输入和输出张量 - 使用临时分配避免持久内存占用
TfLiteTensor* input = micro_context->AllocateTempInputTensor(node, kInputTensor);
TfLiteTensor* output = micro_context->AllocateTempOutputTensor(node, kOutputTensor);
// 计算输入元素总数 - 用于验证 reshape 操作的合法性
int num_input_elements = NumElements(input);
TfLiteIntArray* output_shape = output->dims;
// 处理特殊情况:-1 维度自动计算
// TensorFlow 允许一个维度设为 -1,表示根据其他维度自动推断
int num_output_elements = 1;
int stretch_dim = -1;
for (int i = 0; i < output_shape->size; ++i) {
int value = output_shape->data[i];
if (value == -1) {
TF_LITE_ENSURE_EQ(context, stretch_dim, -1); // 确保只有一个 -1 维度
stretch_dim = i;
} else {
num_output_elements *= value;
}
}
// 如果存在 -1 维度,自动计算其大小
if (stretch_dim != -1) {
TfLiteEvalTensor* output_eval =
tflite::micro::GetEvalOutput(context, node, kOutputTensor);
TF_LITE_ENSURE_STATUS(tflite::micro::CreateWritableTensorDimsWithCopy(
context, output, output_eval));
output_shape = output->dims; // 更新形状指针
output_shape->data[stretch_dim] = num_input_elements / num_output_elements;
num_output_elements *= output_shape->data[stretch_dim];
}
// 确保输入输出元素数量一致 - Reshape 不改变元素总数
TF_LITE_ENSURE_EQ(context, num_input_elements, num_output_elements);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type); // 确保数据类型一致
// 释放临时张量 - 避免内存泄漏
micro_context->DeallocateTempTfLiteTensor(input);
micro_context->DeallocateTempTfLiteTensor(output);
return kTfLiteOk;
}
函数作用详解:
临时张量分配:使用 `AllocateTempInputTensor` 和 `AllocateTempOutputTensor` 获取张量信息,这些是临时分配,不会占用持久内存
形状验证:确保 reshape 操作的合法性,输入输出元素总数必须相等
自动维度推断:处理 -1 维度的特殊情况,这是 TensorFlow 的标准特性
内存管理:及时释放临时张量,这在内存受限的微控制器环境中非常重要
Prepare 函数-操作符准备阶段:
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// 验证输入输出数量 - Reshape 可以有 1 个或 2 个输入(第二个输入是可选的形状参数)
TF_LITE_ENSURE(context, NumInputs(node) == 1 || NumInputs(node) == 2);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); // 只有一个输出
// 执行输出形状重塑 - 在准备阶段确定最终的输出形状
TF_LITE_ENSURE_EQ(context, ReshapeOutput(context, node), kTfLiteOk);
return kTfLiteOk;
}
Prepare函数说明:
输入验证:Reshape 操作支持 1-2 个输入,第二个输入是可选的形状张量
形状计算:在准备阶段就确定输出形状,避免在执行阶段重复计算
错误检查:使用 `TF_LITE_ENSURE` 宏进行参数验证,失败时会返回错误状态
Eval 函数-操作符执行阶段:
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
// 获取输入输出张量 - 使用 EvalTensor 进行实际计算
const TfLiteEvalTensor* input =
tflite::GetEvalInput(context, node, kInputTensor);
TfLiteEvalTensor* output =
tflite::GetEvalOutput(context, node, kOutputTensor);
// 计算输入数据大小 - 需要拷贝的字节数
size_t input_bytes;
TF_LITE_ENSURE_STATUS(TfLiteTypeSizeOf(input->type, &input_bytes));
input_bytes *= ElementCount(*input->dims);
// 执行数据拷贝(如果不是原地操作)
// 原地操作:输入输出使用同一块内存,无需拷贝
if (input->data.raw != output->data.raw) {
memcpy(output->data.raw, input->data.raw, input_bytes);
}
return kTfLiteOk;
}
Eval函数说明:
EvalTensor 使用:在执行阶段使用 `TfLiteEvalTensor`,它包含实际的数据指针
原地操作优化:检查输入输出是否共享内存,避免不必要的数据拷贝
内存拷贝:Reshape 操作本质上只是改变数据的解释方式,不改变数据内容
2.4 操作符注册函数
TfLiteRegistration_V1 Register_RESHAPE() {
return tflite::micro::RegisterOp(nullptr, reshape::Prepare, reshape::Eval);
}
} // namespace reshape
} // namespace micro
} // namespace ops
} // namespace tflite
注册函数说明:
RegisterOp 函数:创建操作符注册结构体,包含初始化、准备和执行函数指针
nullptr 参数:第一个参数是初始化函数,Reshape 不需要特殊初始化,所以传入 nullptr
函数指针:传入 Prepare 和 Eval 函数指针,框架会在适当时机调用这些函数
3. 参数解析实现
3.1 解析函数声明
文件位置:`core/api/flatbuffer_conversions.h`
TfLiteStatus ParseReshape(const Operator* op, ErrorReporter* error_reporter, BuiltinDataAllocator* allocator, void builtin_data);
声明说明:
Operator* op:来自 FlatBuffer 的操作符定义,包含所有参数信息
ErrorReporter:用于报告解析过程中的错误
BuiltinDataAllocator:专用的内存分配器,用于分配参数结构体
builtin_data:输出参数,指向解析后的参数结构体
3.2 解析函数实现
文件位置:`core/api/flatbuffer_conversions.cpp`
TfLiteStatus ParseReshape(const Operator* op, ErrorReporter* error_reporter, BuiltinDataAllocator* allocator, void builtin_data) ;
解析函数详解:
参数验证:`CheckParsePointerParams` 确保所有指针参数有效
安全分配器:`SafeBuiltinDataAllocator` 提供异常安全的内存分配
FlatBuffer 解析:从序列化的模型文件中提取 reshape 参数
格式转换:将 FlatBuffer 格式转换为 TFLite 内部使用的 C 结构体格式
所有权转移:使用 `release()` 将参数结构体的所有权转移给框架
3.3 在解析开关中添加对应的case
文件位置:`core/api/flatbuffer_conversions.cpp`
在 `ParseOpData` 函数的 switch 语句中添加:
case BuiltinOperator_RESHAPE: {
return ParseReshape(op, error_reporter, allocator, builtin_data);
}
开关语句说明:
这个 switch 语句是 TFLite 参数解析的核心分发机制
根据操作符类型调用相应的解析函数
`BuiltinOperator_RESHAPE` 是在FlatBuffer schema中定义的枚举值
通过本指南,我们深入了解了 TensorFlow Lite Micro 的操作符注册机制,包括其设计理念、实现方式以及在嵌入式场景中的重要性。
未来,随着边缘计算和微控制器 AI 的快速发展,理解并运用这些底层机制将成为构建高效、可扩展 AI 系统的核心能力。建议读者在实践中尝试自定义算子注册,并结合实际项目进行优化,以真正发挥 TensorFlow Lite Micro 的潜力。
全部0条评论
快来发表一下你的评论吧 !