如何为TensorFlow Lite Micro添加多输入多输出支持(二)

描述

在上一篇文章中,我们已经带大家了解了多输入多输出(MIMO)能力的架构设计思路。

今天,小编将继续深入解析如何将架构设计真正落地到可运行代码,并带来一套可复用的核心实现。会介绍多输入多输出支持框架的关键组成部分。通过清晰的结构化设计、类型安全的接口抽象,为复杂的嵌入式 AI 模型建立一个高扩展性、高可维护性的基础底座。

下面,我们就将通过头文件设计、基础数据结构构建、生命周期管理等内容一步步展示一个完整的MIMO支持框架是如何搭建起来的。

 话不多说,上代码!(代码超载预警)

  头文件设计:构建类型安全的基础

首先,我们需要一个类型和接口定义完备、可扩展性强的头文件 model.h。
这一部分为后续的MIMO管理、张量访问、预处理、模型统计等功能奠定了坚实基础。

 

#ifndef MODEL_H
#define MODEL_H
 
#include "tensorflow/lite/c/common.h"
//   =============================================================================
// 配置常量
//   =============================================================================
#define MAX_INPUT_TENSORS 8     // 最大输入张量数量
#define MAX_OUTPUT_TENSORS 8    // 最大输出张量数量  
#define MAX_TENSOR_DIMS 6       // 最大张量维度数
#define MODEL_NAME_MAX_LEN 64   // 模型名称最大长度
//   =============================================================================
// 状态码定义
//   =============================================================================
typedef enum {
    kStatus_Success   = 0,
    kStatus_Fail =   1,
     kStatus_InvalidParam = 2,
     kStatus_OutOfRange = 3,
     kStatus_NotInitialized = 4,
     kStatus_InsufficientMemory = 5
} status_t;
//   =============================================================================
// 张量相关类型定义
//   =============================================================================
typedef enum {
     kTensorType_FLOAT32 = 0,
     kTensorType_UINT8 = 1,
    kTensorType_INT8   = 2,
     kTensorType_INT32 = 3,
    kTensorType_BOOL   = 4,
     kTensorType_UNKNOWN = 255
} tensor_type_t;
typedef struct {
    int size;                           // 维度数量
    int   data[MAX_TENSOR_DIMS];         // 各维度的大小
} tensor_dims_t;
// 单个张量的完整信息
typedef struct {
    int index;                         // 张量索引
    tensor_dims_t   dims;               // 维度信息
    tensor_type_t   type;               // 数据类型
    uint8_t*   data;                    // 数据指针
    size_t   size_bytes;                // 数据大小(字节)
    const char*   name;                 // 张量名称(可选)
} tensor_info_t;
// 多张量信息结构
typedef struct {
    int count;                                    // 张量数量
    tensor_info_t   tensors[MAX_INPUT_TENSORS];    // 张量信息数组
} multi_tensor_info_t;
//   =============================================================================
// 模型统计信息
//   =============================================================================
typedef struct {
    size_t   arena_used_bytes;          // 已使用的内存
    size_t   arena_total_bytes;         // 总内存大小
    int   input_count;                  // 输入张量数量
    int   output_count;                 // 输出张量数量
    const char*   model_name;           // 模型名称
} model_stats_t;
//   =============================================================================
// 核心接口声明
//   =============================================================================
// 模型生命周期管理
status_t MODEL_Init(void);
status_t MODEL_Deinit(void);
status_t MODEL_RunInference(void);
// 模型信息查询
int MODEL_GetInputTensorCount(void);
int MODEL_GetOutputTensorCount(void);
status_t MODEL_GetModelStats(model_stats_t* stats);
const char* MODEL_GetModelName(void);
// 单张量操作接口
uint8_t* MODEL_GetInputTensorData(int index,   tensor_dims_t* dims, tensor_type_t* type);
uint8_t* MODEL_GetOutputTensorData(int index,   tensor_dims_t* dims, tensor_type_t* type);
// 增强的单张量接口
status_t MODEL_GetInputTensorInfo(int index,   tensor_info_t* info);
status_t MODEL_GetOutputTensorInfo(int index,   tensor_info_t* info);
// 批量操作接口
status_t   MODEL_GetAllInputTensors(multi_tensor_info_t* input_info);
status_t   MODEL_GetAllOutputTensors(multi_tensor_info_t* output_info);
// 数据预处理接口
status_t   MODEL_ConvertInput(int tensor_index, uint8_t* data,
                           const   tensor_dims_t* dims, tensor_type_t type);
// 工具函数
size_t   MODEL_GetTensorSizeBytes(const tensor_dims_t* dims, tensor_type_t type);
const char*   MODEL_GetTensorTypeName(tensor_type_t type);
status_t   MODEL_ValidateTensorDims(const tensor_dims_t* dims);
#endif //   MODEL_H
  核心实现:从设计到代码

 

接下来,进入到实际实现部分。为了提高代码可读性,整体实现拆分为以下模块:

全局变量与初始化

内部工具函数

生命周期管理(Init / Deinit / Invoke)

全局变量和初始化:

 

#include   "tensorflow/lite/micro/kernels/micro_ops.h"
#include   "tensorflow/lite/micro/micro_interpreter.h"
#include   "tensorflow/lite/micro/micro_op_resolver.h"
#include   "tensorflow/lite/schema/schema_generated.h"
 
#include "fsl_debug_console.h"
#include "model.h"
#include "model_data.h"
 
//   =============================================================================
// 全局变量
//   =============================================================================
static const tflite::Model* s_model = nullptr;
static tflite::MicroInterpreter* s_interpreter = nullptr;
static bool s_model_initialized = false;
 
// 张量内存区域 - 根据具体模型调整大小
static uint8_t s_tensorArena[kTensorArenaSize]   __ALIGNED(16);
 
// 外部函数声明
extern tflite::MicroOpResolver   &MODEL_GetOpsResolver();
 
//   =============================================================================
// 内部辅助函数
//   =============================================================================
 
// 获取数据类型的字节大小
static size_t GetTypeSize(tensor_type_t type)
{
    switch (type) {
        case   kTensorType_FLOAT32:
        case   kTensorType_INT32:
            return   4;
        case   kTensorType_UINT8:
        case   kTensorType_INT8:
        case   kTensorType_BOOL:
            return   1;
        default:
            return   0;
    }
}
 
// TensorFlow Lite类型转换为我们的类型
static tensor_type_t ConvertTfLiteType(TfLiteType tf_type)
{
    switch (tf_type)   {
        case   kTfLiteFloat32:
            return   kTensorType_FLOAT32;
        case   kTfLiteUInt8:
            return   kTensorType_UINT8;
        case   kTfLiteInt8:
            return   kTensorType_INT8;
        case   kTfLiteInt32:
            return   kTensorType_INT32;
        case   kTfLiteBool:
            return   kTensorType_BOOL;
        default:
            return   kTensorType_UNKNOWN;
    }
}
 
// 从TensorFlow Lite张量提取信息
static status_t ExtractTensorInfo(TfLiteTensor* tf_tensor,   int index, tensor_info_t* info)
{
    if (tf_tensor ==   nullptr || info == nullptr) {
        return   kStatus_InvalidParam;
    }
 
    // 基本信息
    info->index =   index;
    info->type =   ConvertTfLiteType(tf_tensor->type);
    info->data =   tf_tensor->data.uint8;
    
    if   (info->type == kTensorType_UNKNOWN) {
         PRINTF("Unsupported tensor type: %d
",   tf_tensor->type);
        return   kStatus_Fail;
    }
 
    // 维度信息
     info->dims.size = tf_tensor->dims->size;
    if   (info->dims.size > MAX_TENSOR_DIMS) {
         PRINTF("Tensor dimensions exceed maximum: %d > %d
",
                info->dims.size, MAX_TENSOR_DIMS);
        return   kStatus_OutOfRange;
    }
 
    size_t   total_elements = 1;
    for (int i = 0;   i < info->dims.size; i++) {
         info->dims.data[i] = tf_tensor->dims->data[i];
         total_elements *= info->dims.data[i];
    }
 
    // 计算数据大小
     info->size_bytes = total_elements * GetTypeSize(info->type);
    
    // 张量名称(如果可用)
    info->name =   nullptr;  // TensorFlow Lite Micro通常不保存名称
 
    return   kStatus_Success;
}
  模型生命周期管理

 

这部分主要包括:

模型初始化(加载模型 / 创建解释器 / 分配张量内存)

模型反初始化

执行推理(Invoke)

 

//
模型生命周期管理
//
 
status_t MODEL_Init(void)
{
    if   (s_model_initialized) {
         PRINTF("Model already initialized
");
        return   kStatus_Success;
    }
 
    // 加载模型
    s_model =   tflite::GetModel(model_data);
    if   (s_model->version() != TFLITE_SCHEMA_VERSION) {
         PRINTF("Model schema version %d not supported (expected   %d)
",
                s_model->version(), TFLITE_SCHEMA_VERSION);
        return   kStatus_Fail;
    }
 
    // 获取操作解析器
     tflite::MicroOpResolver µ_op_resolver =   MODEL_GetOpsResolver();
 
    // 创建解释器
    static   tflite::MicroInterpreter static_interpreter(
        s_model,   micro_op_resolver, s_tensorArena, kTensorArenaSize);
    s_interpreter =   &static_interpreter;
 
    // 分配张量内存
    TfLiteStatus   allocate_status = s_interpreter->AllocateTensors();
    if   (allocate_status != kTfLiteOk) {
         PRINTF("AllocateTensors() failed with status: %d
",   allocate_status);
        return   kStatus_InsufficientMemory;
    }
 
     s_model_initialized = true;
 
    // 打印模型信息
     PRINTF("Model '%s' initialized successfully:
",   MODEL_GetModelName());
    PRINTF("-   Input tensors: %d
", s_interpreter->inputs_size());
    PRINTF("-   Output tensors: %d
", s_interpreter->outputs_size());
    PRINTF("-   Arena used: %zu bytes
", s_interpreter->arena_used_bytes());
 
    return   kStatus_Success;
}
 
status_t MODEL_Deinit(void)
{
    if   (!s_model_initialized) {
        return   kStatus_NotInitialized;
    }
 
    // TensorFlow   Lite Micro使用静态内存,无需显式释放
    s_model =   nullptr;
    s_interpreter =   nullptr;
     s_model_initialized = false;
 
     PRINTF("Model deinitialized
");
    return   kStatus_Success;
}
 
status_t MODEL_RunInference(void)
{
    if   (!s_model_initialized || s_interpreter == nullptr) {
         PRINTF("Model not initialized
");
        return   kStatus_NotInitialized;
    }
 
    TfLiteStatus   invoke_status = s_interpreter->Invoke();
    if   (invoke_status != kTfLiteOk) {
         PRINTF("Model inference failed with status: %d
",   invoke_status);
        return   kStatus_Fail;
    }
 
    return   kStatus_Success;
}
 
  信息查询接口

 

包含:

输入/输出张量数量查询

模型统计信息读取

模型名称查询

 

//
 模型信息查询
//
 
int MODEL_GetInputTensorCount(void)
{
    if   (!s_model_initialized || s_interpreter == nullptr) {
        return 0;
    }
    return   s_interpreter->inputs_size();
}
 
int MODEL_GetOutputTensorCount(void)
{
    if   (!s_model_initialized || s_interpreter == nullptr) {
        return 0;
    }
    return   s_interpreter->outputs_size();
}
 
status_t   MODEL_GetModelStats(model_stats_t* stats)
{
    if (stats ==   nullptr) {
        return   kStatus_InvalidParam;
    }
 
    if   (!s_model_initialized || s_interpreter == nullptr) {
        return   kStatus_NotInitialized;
    }
 
     stats->arena_used_bytes = s_interpreter->arena_used_bytes();
     stats->arena_total_bytes = kTensorArenaSize;
     stats->input_count = s_interpreter->inputs_size();
     stats->output_count = s_interpreter->outputs_size();
     stats->model_name = MODEL_GetModelName();
 
    return   kStatus_Success;
}
 
const char* MODEL_GetModelName(void)
{
    return   MODEL_NAME;
}
 
  下期预告

 

由于篇幅有限,本篇重点展示了:

头文件设计:类型安全、结构清晰

核心实现框架:生命周期管理 + 内部工具函数

基本模型信息查询接口

在下一篇(系列最终章)中,我们将重点讲解:

张量数据访问接口(Input/Output Data APIs)完整实现

批量张量操作的高效实现方案

更实际的代码示例与最佳实践

打开APP阅读更多精彩内容
声明:本文内容及配图由入驻作者撰写或者入驻合作网站授权转载。文章观点仅代表作者本人,不代表电子发烧友网立场。文章及其配图仅供工程师学习之用,如有内容侵权或者其他违规问题,请联系本站处理。 举报投诉

全部0条评论

快来发表一下你的评论吧 !

×
20
完善资料,
赚取积分