Training a Yes/No Keyword-Spotting Model with TensorFlow

Based on the official TensorFlow Yes/No keyword-spotting training tutorial, the script described here generates a C file that can directly replace the model data in the stock micro_speech example.
Reference: https://github.com/tensorflow/tflite-micro/blob/main/tensorflow/lite/micro/examples/micro_speech/train/train_micro_speech_model.ipynb
The resulting model is about 20 KB.
Using the training script speech_trainer.py
The full source of speech_trainer.py is at the end of this article.

Script overview
speech_trainer.py provides an end-to-end pipeline:
- Prepares the dataset automatically (downloads and extracts Speech Commands v0.02 on first run or when missing)
- Clones the tensorflow repository and runs the official train.py for training
- Freezes the model into a SavedModel (v1-style, via the official freeze.py)
- Generates float and quantized TFLite models and verifies their accuracy
- Generates an MCU-ready C source file, micro_speech_quantized_model_data.c

Directories and outputs:
- dataset/: the dataset (managed automatically, never cleaned)
- train/: training output (checkpoints, etc.); supports resuming
- logs/: logs and event files
- models/:
  - saved_model/: the frozen SavedModel
  - micro_speech_quantized.tflite: quantized TFLite model
  - micro_speech_float.tflite: float TFLite model
  - micro_speech_quantized_model_data.c: TFLite Micro C array file
Directory layout
train/
├── dataset/                                # dataset (managed automatically)
├── logs/                                   # training logs and event files
├── models/                                 # exported models
│   ├── saved_model/                        # frozen SavedModel
│   ├── micro_speech_quantized.tflite       # quantized TFLite model
│   ├── micro_speech_float.tflite           # float TFLite model
│   └── micro_speech_quantized_model_data.c # MCU-ready C array
├── tensorflow/                             # cloned TF repository (with the official scripts)
└── speech_trainer.py                       # this script
Environment requirements
- Python 3.10 (on Windows, prefer py -3.10)
- Install the dependencies:

Windows (PowerShell):
# Install Python 3.10
winget install -e --id Python.Python.3.10
# Create a Python virtual environment
py -3.10 -m venv .venvpy310_win
# Activate the virtual environment
.venvpy310_win\Scripts\Activate.ps1
# Upgrade pip
python -m pip install --upgrade pip
# Install the dependency packages
pip install -r requirements.txt

Linux/macOS (bash):
python3.10 -m venv .venvpy310
source .venvpy310/bin/activate
python -m pip install --upgrade pip
pip install -r requirements.txt
Quick start (3 steps)
- Prepare the environment (see above) and activate the venv
- Start training with a single command:
python train/speech_trainer.py
- When it finishes, collect micro_speech_quantized_model_data.c, micro_speech_quantized.tflite, and the other artifacts from train/models/
Dataset management (automatic)
- If dataset/ does not exist: it is created, speech_commands_v0.02.tar.gz is downloaded into it, and the archive is extracted.
- If dataset/ already exists:
  - the archive is downloaded first if it is missing
  - the archive is extracted if it has not been already (judged by the presence of yes/, no/, and _background_noise_/)
No manual intervention is needed; the script ensures the dataset is ready at the start of the pipeline.
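The readiness rules above can be sketched as a small decision function. `dataset_actions` is a hypothetical helper written for illustration (it is not part of speech_trainer.py, which performs these checks inline in `ensure_dataset`):

```python
from pathlib import Path

ARCHIVE_NAME = "speech_commands_v0.02.tar.gz"
# Presence of any of these subdirectories implies "already extracted"
MARKER_DIRS = ["yes", "no", "_background_noise_"]

def dataset_actions(dataset_dir: Path) -> list[str]:
    """Return the steps the pipeline would take to make the dataset ready."""
    steps = []
    if not dataset_dir.exists():
        steps.append("mkdir")
    if not (dataset_dir / ARCHIVE_NAME).exists():
        steps.append("download")
    if not any((dataset_dir / d).exists() for d in MARKER_DIRS):
        steps.append("extract")
    return steps
```

On a fresh checkout this returns all three steps; once the marker directories exist, it returns an empty list and the pipeline proceeds straight to training.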
Resource and disk recommendations:
- Disk space: ≥ 3 GB (dataset plus intermediate files)
- Memory: ≥ 4 GB (more memory is more stable)
- Network: access to storage.googleapis.com (configure a proxy or download manually if your network is restricted)
Command-line arguments
- --wanted_words: words to train on (comma-separated). Default: yes,no
- --training_steps: training steps as a comma-separated string of segments. Default: 12000,3000
- --learning_rate: learning rates as a comma-separated string of segments. Default: 0.001,0.0001
- --model_architecture: model architecture, one of single_fc, conv, low_latency_conv, low_latency_svdf, tiny_embedding_conv, tiny_conv (default: tiny_conv)
- --skip_training: skip training; download the official pretrained model and go straight to conversion/export
- --resume: resume the previous run (keeps train/ and uses the most recent checkpoint as --start_checkpoint)
- --test_env: only test the environment (dependency/path checks); no training
Notes:
- The total number of training steps is the sum of the segments (for example, 12000,3000 gives 15000 total steps).
- When resuming, if no checkpoint is found, training starts from scratch (the script prints a notice).
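How the script derives parameters from these comma-separated strings can be shown in two lines. `total_steps` and `silence_unknown_percentage` are illustrative helpers; the script does the same arithmetic inline in `_calculate_derived_params` (the evenly split percentage covers the wanted words plus the implicit "silence" and "unknown" classes):

```python
def total_steps(training_steps: str) -> int:
    """Total training steps are the sum of the comma-separated segments."""
    return sum(int(s) for s in training_steps.split(","))

def silence_unknown_percentage(wanted_words: str) -> int:
    """Split the label share evenly across wanted words + silence + unknown."""
    number_of_labels = wanted_words.count(",") + 1
    return int(100.0 / (number_of_labels + 2))
```

With the defaults, `total_steps("12000,3000")` is 15000 and `silence_unknown_percentage("yes,no")` is 25 (four classes sharing 100%).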
Run examples
- Basic training (recommended)
python train/speech_trainer.py
- Specify the word list and steps
python train/speech_trainer.py \
  --wanted_words yes,no \
  --training_steps 12000,3000 \
  --learning_rate 0.001,0.0001 \
  --model_architecture tiny_conv
- Resume the previous run
python train/speech_trainer.py --resume
- Use the pretrained model (skip training)
python train/speech_trainer.py --skip_training
- Test the environment only
python train/speech_trainer.py --test_env
Resuming vs. retraining from scratch
- Resume: pass --resume; the script automatically picks the *.ckpt-*.index file in train/ with the highest step count
- Retrain from scratch: delete train/ and logs/ and rerun; or leave the directories in place and simply run without --resume
- Re-export only: delete models/ and rerun (training is skipped; the model is frozen and exported directly)
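The "highest step count" selection used by --resume can be sketched as follows. `find_latest_checkpoint` is an illustrative standalone version of the script's `_find_latest_checkpoint` method; checkpoints are identified by their `<arch>.ckpt-<step>.index` files and the `.index` suffix is stripped from the returned path:

```python
import re
from pathlib import Path

def find_latest_checkpoint(train_dir: Path, arch: str = "tiny_conv"):
    """Return the checkpoint prefix with the highest step, or None."""
    best = None
    for index_file in train_dir.glob(f"{arch}.ckpt-*.index"):
        m = re.match(rf"{re.escape(arch)}\.ckpt-(\d+)\.index$", index_file.name)
        if m:
            step = int(m.group(1))
            if best is None or step > best[0]:
                # Drop the ".index" suffix: TF expects the bare prefix
                best = (step, index_file.with_suffix(""))
    return None if best is None else str(best[1])
```

If this returns None, the script logs a notice and training restarts from step 0.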
Pipeline stages
- Ensure the dataset: automatically download/extract Speech Commands v0.02
- Clone the tensorflow repository (skipped if it already exists)
- Run the official train.py for training (resumable)
- Run the official freeze.py to produce saved_model/
- Generate micro_speech_float.tflite and micro_speech_quantized.tflite (quantized), and evaluate accuracy
- Generate micro_speech_quantized_model_data.c (TFLite Micro C array)
- Print the path and size of each output file
Integrating into an MCU project
- After training, models/micro_speech_quantized_model_data.c is the model data file you can integrate directly.
- Copy it into the corresponding directory of your project (for example kws/), replace the old model file, and rebuild.
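The generated file is nothing more than the quantized .tflite bytes rendered as a C byte array. A simplified sketch of that conversion is shown below; `tflite_to_c_array` is a hypothetical helper, and the script's own `_generate_c_file_python` produces slightly different separator formatting:

```python
def tflite_to_c_array(model_bytes: bytes,
                      name: str = "micro_speech_quantized_tflite") -> str:
    """Render model bytes as a C source snippet, 16 bytes per line."""
    lines = [f"const unsigned char {name}[] = {{"]
    for i in range(0, len(model_bytes), 16):
        chunk = ", ".join(f"0x{b:02x}" for b in model_bytes[i:i + 16])
        lines.append(f"    {chunk},")
    lines.append("};")
    # Length symbol the micro_speech example reads alongside the array
    lines.append(f"const unsigned int {name}_len = {len(model_bytes)};")
    return "\n".join(lines)
```

Because the output is plain C with external linkage, swapping the model only requires replacing this one file and rebuilding; no other source changes are needed as long as the label set is unchanged.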
FAQ
- "Export directory already exists": the script no longer pre-creates the saved_model subdirectory; if the error still occurs, delete models/saved_model/ and retry.
- Resume finds no checkpoint: confirm that train/ contains files named like tiny_conv.ckpt-*.index; otherwise training starts from scratch.
- Dataset download fails: check the network, or manually download speech_commands_v0.02.tar.gz into dataset/ and rerun.
- Quantized model accuracy drops noticeably: increase the number of representative-dataset samples, or adjust the training steps and learning rate.

Other tips:
- Windows execution policy: if activating the venv is blocked by policy, run in an administrator PowerShell:
Set-ExecutionPolicy -Scope CurrentUser RemoteSigned
- Slow or failing downloads (e.g. behind restrictive networks): pre-download speech_commands_v0.02.tar.gz into train/dataset/ before running.
- TF CPU instruction-set messages (SSE/AVX, etc.): performance hints only; safe to ignore.
speech_trainer.py source code
#!/usr/bin/env python3
"""
Speech recognition model trainer.
A simple audio recognition training script based on TensorFlow.
Generates TensorFlow Lite models for microcontroller deployment.
"""
import os
import sys
import subprocess
import argparse
import logging
from pathlib import Path
import locale

# Set up encoding to handle Chinese paths on Windows
if sys.platform == 'win32':
    import codecs
    # Configure console encoding
    if sys.stdout.encoding != 'utf-8':
        sys.stdout = codecs.getwriter('utf-8')(sys.stdout.buffer, 'strict')
    if sys.stderr.encoding != 'utf-8':
        sys.stderr = codecs.getwriter('utf-8')(sys.stderr.buffer, 'strict')

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Deferred imports
numpy = None
tensorflow = None


class SpeechRecognitionTrainer:
    def __init__(self, config=None):
        """Initialize the trainer configuration."""
        # Check and install dependencies
        self._check_and_install_dependencies()
        # Import globally
        global numpy, tensorflow
        import numpy
        import tensorflow
        # Default configuration
        self.config = {
            # Training parameters
            'wanted_words': 'yes,no',
            'training_steps': '1000,1000',
            'learning_rate': '0.001,0.0001',
            # Model parameters
            'preprocess': 'micro',
            'window_stride': 20,
            'model_architecture': 'tiny_conv',
            # Training control
            'verbosity': 'INFO',
            'eval_step_interval': '1000',
            'save_step_interval': '1000',
            # Data parameters
            'sample_rate': 16000,
            'clip_duration_ms': 1000,
            'window_size_ms': 30.0,
            'feature_bin_count': 40,
            'background_frequency': 0.8,
            'background_volume_range': 0.1,
            'time_shift_ms': 100.0,
            'validation_percentage': 10,
            'testing_percentage': 10,
            # Quantization parameters
            'quant_input_min': 0.0,
            'quant_input_max': 26.0,
        }
        # Apply overrides
        if config:
            self.config.update(config)
        # Compute derived parameters
        self._calculate_derived_params()
        # Set up directory paths
        self._setup_directories()
        # Runtime state
        self.resume = bool(self.config.get('resume', False))

    def _find_latest_checkpoint(self):
        """Return the newest checkpoint path (without extension) in the train dir, or None."""
        from pathlib import Path
        train_dir = Path(self.directories['train'])
        model_prefix = self.config['model_architecture'] + '.ckpt-'
        candidates = []
        for index_file in train_dir.glob(f"{self.config['model_architecture']}.ckpt-*.index"):
            name = index_file.name  # e.g. tiny_conv.ckpt-12345.index
            try:
                step_str = name.split('.ckpt-')[-1].split('.index')[0]
                step = int(step_str)
                candidates.append((step, index_file))
            except Exception:
                continue
        if not candidates:
            return None
        candidates.sort(key=lambda x: x[0], reverse=True)
        latest_index = candidates[0][1]
        # Strip the .index extension
        return str(latest_index.with_suffix(''))

    def _check_and_install_dependencies(self):
        """Check for and install the required dependencies."""
        logger.info("Checking and installing dependencies...")
        required_packages = [
            'numpy',
            'tensorflow',
            'matplotlib',
            'scipy',
        ]
        missing_packages = []
        for package in required_packages:
            try:
                __import__(package)
                logger.info(f"✓ {package} is installed")
            except ImportError:
                missing_packages.append(package)
                logger.warning(f"✗ {package} is not installed")
        if missing_packages:
            logger.info(f"Installing missing packages: {', '.join(missing_packages)}")
            for package in missing_packages:
                try:
                    subprocess.check_call([sys.executable, '-m', 'pip', 'install', package])
                    logger.info(f"✓ {package} installed")
                except subprocess.CalledProcessError as e:
                    logger.error(f"✗ failed to install {package}: {e}")
                    raise
        logger.info("Dependency check complete")

    def _calculate_derived_params(self):
        """Compute derived parameters."""
        steps = self.config['training_steps'].split(',')
        self.config['total_steps'] = str(sum(int(step) for step in steps))
        number_of_labels = self.config['wanted_words'].count(',') + 1
        number_of_total_labels = number_of_labels + 2
        equal_percentage = int(100.0 / number_of_total_labels)
        self.config['silent_percentage'] = equal_percentage
        self.config['unknown_percentage'] = equal_percentage
        self.config['quant_input_range'] = self.config['quant_input_max'] - self.config['quant_input_min']

    def _setup_directories(self):
        """Set up directory paths - absolute paths avoid relative-path issues."""
        # Absolute path of the current working directory
        base_dir = os.path.abspath(os.getcwd())
        self.directories = {
            'dataset': os.path.join(base_dir, 'dataset'),
            'logs': os.path.join(base_dir, 'logs'),
            'train': os.path.join(base_dir, 'train'),
            'models': os.path.join(base_dir, 'models'),
            'tensorflow_repo': os.path.join(base_dir, 'tensorflow')
        }
        # Model file paths
        models_dir = self.directories['models']
        self.model_paths = {
            'saved_model': os.path.join(models_dir, 'saved_model'),
            'model_tflite': os.path.join(models_dir, 'micro_speech_quantized.tflite'),
            'float_model_tflite': os.path.join(models_dir, 'micro_speech_float.tflite'),
            'model_tflite_micro': os.path.join(models_dir, 'micro_speech_quantized_model_data.c'),
        }

    def clean_previous_data(self):
        """Clean up data from previous runs."""
        logger.info("Cleaning previous training data...")
        import shutil
        for name, directory in self.directories.items():
            if name in ('tensorflow_repo', 'dataset'):
                logger.info(f"Skipping directory: {name} -> {directory}")
                continue
            if self.resume and name == 'train':
                logger.info(f"--resume given, keeping training directory: {directory}")
                continue
            if os.path.exists(directory):
                try:
                    shutil.rmtree(directory)
                    logger.info(f"Removed directory: {directory}")
                except Exception as e:
                    logger.warning(f"Could not remove {directory}: {e}")
        # Create the required directories
        os.makedirs(self.directories['models'], exist_ok=True)
        os.makedirs(self.directories['dataset'], exist_ok=True)

    def ensure_dataset(self):
        """Make sure dataset/ exists and contains speech_commands_v0.02.

        Rules:
        1) If dataset/ does not exist, create it, download the archive into it, then extract.
        2) If dataset/ exists: download the archive if missing, then extract if not yet extracted.
        """
        import urllib.request
        import tarfile
        dataset_dir = Path(self.directories['dataset'])
        dataset_dir.mkdir(parents=True, exist_ok=True)
        archive_name = 'speech_commands_v0.02.tar.gz'
        archive_path = dataset_dir / archive_name
        data_url = 'https://storage.googleapis.com/download.tensorflow.org/data/speech_commands_v0.02.tar.gz'

        # Roughly detect extraction via well-known subdirectories
        def is_extracted(path: Path) -> bool:
            common_subdirs = ['yes', 'no', '_background_noise_']
            return any((path / sub).exists() for sub in common_subdirs)

        # Download if the archive is missing
        if not archive_path.exists():
            logger.info(f"Downloading dataset to: {archive_path}")
            try:
                urllib.request.urlretrieve(data_url, str(archive_path))
            except Exception as e:
                logger.error(f"Dataset download failed: {e}")
                raise
        # Extract if not yet extracted
        if not is_extracted(dataset_dir):
            logger.info("Extracting dataset...")
            try:
                with tarfile.open(str(archive_path), 'r:gz') as tar:
                    tar.extractall(str(dataset_dir))
            except Exception as e:
                logger.error(f"Extraction failed: {e}")
                raise
            logger.info("Dataset extracted")

    def setup_tensorflow_repo(self):
        """Clone the TensorFlow repository."""
        if not os.path.exists(self.directories['tensorflow_repo']):
            logger.info("Cloning the TensorFlow repository...")
            subprocess.run(['git', 'clone', '-q', '--depth', '1',
                            'https://github.com/tensorflow/tensorflow'], check=True)
            logger.info("TensorFlow repository cloned")
        else:
            logger.info("TensorFlow repository already exists")

    def train_model(self):
        """Train the model."""
        logger.info("Starting model training...")
        logger.info(f"Wanted words: {self.config['wanted_words']}")
        logger.info(f"Training steps: {self.config['training_steps']}")
        logger.info(f"Learning rate: {self.config['learning_rate']}")
        logger.info(f"Total steps: {self.config['total_steps']}")
        # Use Path objects for robust path handling
        train_script = (Path(self.directories['tensorflow_repo']) / 'tensorflow' /
                        'examples' / 'speech_commands' / 'train.py')
        train_cmd = [
            sys.executable,
            str(train_script),  # convert to string
            f"--data_dir={self.directories['dataset']}",
            f"--wanted_words={self.config['wanted_words']}",
            f"--silence_percentage={self.config['silent_percentage']}",
            f"--unknown_percentage={self.config['unknown_percentage']}",
            f"--preprocess={self.config['preprocess']}",
            f"--window_stride={self.config['window_stride']}",
            f"--model_architecture={self.config['model_architecture']}",
            f"--how_many_training_steps={self.config['training_steps']}",
            f"--learning_rate={self.config['learning_rate']}",
            f"--train_dir={self.directories['train']}",
            f"--summaries_dir={self.directories['logs']}",
            f"--verbosity={self.config['verbosity']}",
            f"--eval_step_interval={self.config['eval_step_interval']}",
            f"--save_step_interval={self.config['save_step_interval']}"
        ]
        # When resuming, pass the most recent checkpoint
        if self.resume:
            latest_ckpt = self._find_latest_checkpoint()
            if latest_ckpt:
                train_cmd.append(f"--start_checkpoint={latest_ckpt}")
                logger.info(f"Resuming from checkpoint: {latest_ckpt}")
            else:
                logger.info("--resume given but no checkpoint found; training from scratch")
        # Environment variables
        env = os.environ.copy()
        speech_commands_path = str(Path(self.directories['tensorflow_repo']) / 'tensorflow' /
                                   'examples' / 'speech_commands')
        env['PYTHONPATH'] = speech_commands_path + os.pathsep + env.get('PYTHONPATH', '')
        # Windows-specific: force UTF-8
        if sys.platform == 'win32':
            env['PYTHONIOENCODING'] = 'utf-8'
        try:
            logger.info("Running the training command...")
            result = subprocess.run(
                train_cmd, check=True, env=env,
                capture_output=False, text=True,
                encoding='utf-8' if sys.platform == 'win32' else None)
            logger.info("Model training complete")
        except subprocess.CalledProcessError as e:
            logger.error(f"Training failed: {e}")
            raise

    def freeze_model(self):
        """Freeze the model."""
        logger.info("Freezing the model...")
        # Make sure the parent directory exists
        saved_model_dir = Path(self.model_paths['saved_model'])
        saved_model_parent = saved_model_dir.parent
        saved_model_parent.mkdir(parents=True, exist_ok=True)
        # Remove an existing saved_model directory
        if saved_model_dir.exists():
            import shutil
            shutil.rmtree(saved_model_dir)
        # Note: do not pre-create the saved_model directory or its children.
        # SavedModelBuilder creates it on save; otherwise it raises
        # "Export directory already exists, and isn't empty".
        freeze_script = (Path(self.directories['tensorflow_repo']) / 'tensorflow' /
                         'examples' / 'speech_commands' / 'freeze.py')
        # Prefer the most recent checkpoint; fall back to the expected step count
        latest_ckpt = self._find_latest_checkpoint()
        if latest_ckpt:
            checkpoint_path = Path(latest_ckpt)
            logger.info(f"Freezing from checkpoint: {checkpoint_path}")
        else:
            checkpoint_path = (Path(self.directories['train']) /
                               f"{self.config['model_architecture']}.ckpt-{self.config['total_steps']}")
            logger.info(f"No checkpoint found; trying the expected path: {checkpoint_path}")
        freeze_cmd = [
            sys.executable,
            str(freeze_script),
            f"--wanted_words={self.config['wanted_words']}",
            f"--window_stride_ms={self.config['window_stride']}",
            f"--preprocess={self.config['preprocess']}",
            f"--model_architecture={self.config['model_architecture']}",
            f"--start_checkpoint={str(checkpoint_path)}",
            f"--save_format=saved_model",
            f"--output_file={str(saved_model_dir)}"  # absolute path
        ]
        # Environment variables
        env = os.environ.copy()
        speech_commands_path = str(Path(self.directories['tensorflow_repo']) / 'tensorflow' /
                                   'examples' / 'speech_commands')
        env['PYTHONPATH'] = speech_commands_path + os.pathsep + env.get('PYTHONPATH', '')
        # Windows-specific settings
        if sys.platform == 'win32':
            env['PYTHONIOENCODING'] = 'utf-8'
            env['PYTHONUTF8'] = '1'
        try:
            logger.info("Running the freeze command...")
            logger.info(f"Output path: {saved_model_dir}")
            # Use Popen for better control over encoding
            process = subprocess.Popen(
                freeze_cmd, env=env,
                stdout=subprocess.PIPE, stderr=subprocess.PIPE,
                text=True, encoding='utf-8',
                errors='replace')  # replace undecodable characters
            stdout, stderr = process.communicate()
            if process.returncode != 0:
                logger.error(f"Freeze failed, return code: {process.returncode}")
                if stdout:
                    logger.error(f"stdout: {stdout}")
                if stderr:
                    logger.error(f"stderr: {stderr}")
                raise subprocess.CalledProcessError(process.returncode, freeze_cmd)
            logger.info("Model frozen")
        except subprocess.CalledProcessError as e:
            logger.error(f"Model freeze failed: {e}")
            # Fall back to creating a simple saved_model directly
            logger.info("Trying the fallback saved_model path...")
            self._create_saved_model_fallback()

    def _create_saved_model_fallback(self):
        """Fallback: build a saved_model directly from the checkpoint."""
        try:
            import tensorflow as tf
            logger.info("Creating saved_model via the fallback path...")
            # Add the speech_commands directory to sys.path
            speech_commands_path = (Path(self.directories['tensorflow_repo']) / 'tensorflow' /
                                    'examples' / 'speech_commands')
            if str(speech_commands_path) not in sys.path:
                sys.path.insert(0, str(speech_commands_path))
            import models
            import input_data
            # Prepare the model settings
            model_settings = models.prepare_model_settings(
                len(input_data.prepare_words_list(self.config['wanted_words'].split(','))),
                self.config['sample_rate'], self.config['clip_duration_ms'],
                self.config['window_size_ms'], self.config['window_stride'],
                self.config['feature_bin_count'], self.config['preprocess'])
            # Reset the default graph
            tf.compat.v1.reset_default_graph()
            with tf.compat.v1.Session() as sess:
                # Create the input placeholder
                fingerprint_size = model_settings['fingerprint_size']
                fingerprint_input = tf.compat.v1.placeholder(
                    tf.float32, [None, fingerprint_size], name='fingerprint_input')
                # Build the model (with is_training=False only the logits tensor is returned)
                logits = models.create_model(
                    fingerprint_input, model_settings,
                    self.config['model_architecture'], is_training=False)
                # Add and name the softmax output
                labels_softmax = tf.nn.softmax(logits, name='labels_softmax')
                # Restore the weights; the fallback also prefers the latest checkpoint
                latest_ckpt = self._find_latest_checkpoint()
                if latest_ckpt:
                    checkpoint_path = Path(latest_ckpt)
                else:
                    checkpoint_path = (Path(self.directories['train']) /
                                       f"{self.config['model_architecture']}.ckpt-{self.config['total_steps']}")
                saver = tf.compat.v1.train.Saver()
                saver.restore(sess, str(checkpoint_path))
                # Save the model using the TensorFlow 1.x SavedModelBuilder
                saved_model_path = Path(self.model_paths['saved_model'])
                builder = tf.compat.v1.saved_model.builder.SavedModelBuilder(str(saved_model_path))
                # Define the signature
                inputs = {'fingerprint_input':
                          tf.compat.v1.saved_model.utils.build_tensor_info(fingerprint_input)}
                outputs = {'labels_softmax':
                           tf.compat.v1.saved_model.utils.build_tensor_info(labels_softmax)}
                signature = tf.compat.v1.saved_model.signature_def_utils.build_signature_def(
                    inputs=inputs, outputs=outputs,
                    method_name=tf.compat.v1.saved_model.signature_constants.PREDICT_METHOD_NAME)
                builder.add_meta_graph_and_variables(
                    sess,
                    [tf.compat.v1.saved_model.tag_constants.SERVING],
                    signature_def_map={
                        tf.compat.v1.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature
                    })
                builder.save()
                logger.info(f"Fallback: saved_model created at {saved_model_path}")
        except Exception as e:
            logger.error(f"The fallback also failed: {e}")
            logger.error("Suggestions:")
            logger.error("1. Downgrade TensorFlow to 2.13 or earlier")
            logger.error("2. Use --skip_training to download the pretrained model")
            raise

    def convert_to_tflite(self):
        """Convert to TensorFlow Lite models."""
        logger.info("Converting to TensorFlow Lite...")
        # Add the speech_commands path
        speech_commands_path = (Path(self.directories['tensorflow_repo']) / 'tensorflow' /
                                'examples' / 'speech_commands')
        if str(speech_commands_path) not in sys.path:
            sys.path.insert(0, str(speech_commands_path))
        try:
            import input_data
            import models
            import numpy as np
            import tensorflow as tf
            # Prepare the model settings
            model_settings = models.prepare_model_settings(
                len(input_data.prepare_words_list(self.config['wanted_words'].split(','))),
                self.config['sample_rate'], self.config['clip_duration_ms'],
                self.config['window_size_ms'], self.config['window_stride'],
                self.config['feature_bin_count'], self.config['preprocess'])
            # Create the audio processor
            data_url = 'https://storage.googleapis.com/download.tensorflow.org/data/speech_commands_v0.02.tar.gz'
            audio_processor = input_data.AudioProcessor(
                data_url, self.directories['dataset'],
                self.config['silent_percentage'], self.config['unknown_percentage'],
                self.config['wanted_words'].split(','), self.config['validation_percentage'],
                self.config['testing_percentage'], model_settings, self.directories['logs'])
            with tf.compat.v1.Session() as sess:
                # Float model
                logger.info("Generating the float TensorFlow Lite model...")
                float_converter = tf.lite.TFLiteConverter.from_saved_model(
                    str(Path(self.model_paths['saved_model'])))
                float_tflite_model = float_converter.convert()
                with open(self.model_paths['float_model_tflite'], 'wb') as f:
                    float_model_size = f.write(float_tflite_model)
                logger.info(f"Float model size: {float_model_size} bytes")
                # Quantized model
                logger.info("Generating the quantized TensorFlow Lite model...")
                converter = tf.lite.TFLiteConverter.from_saved_model(
                    str(Path(self.model_paths['saved_model'])))
                converter.optimizations = [tf.lite.Optimize.DEFAULT]
                converter.inference_input_type = tf.int8
                converter.inference_output_type = tf.int8

                # Representative dataset generator
                def representative_dataset_gen():
                    for i in range(100):
                        data, _ = audio_processor.get_data(
                            1, i * 1, model_settings,
                            self.config['background_frequency'],
                            self.config['background_volume_range'],
                            self.config['time_shift_ms'],
                            'testing', sess)
                        flattened_data = np.array(data.flatten(), dtype=np.float32).reshape(1, 1960)
                        yield [flattened_data]

                converter.representative_dataset = representative_dataset_gen
                tflite_model = converter.convert()
                with open(self.model_paths['model_tflite'], 'wb') as f:
                    quantized_model_size = f.write(tflite_model)
                logger.info(f"Quantized model size: {quantized_model_size} bytes")
            return audio_processor, model_settings
        except Exception as e:
            logger.error(f"TensorFlow Lite conversion failed: {e}")
            raise

    def test_tflite_accuracy(self, audio_processor, model_settings):
        """Measure the accuracy of the TensorFlow Lite models."""
        logger.info("Testing model accuracy...")
        import numpy as np
        import tensorflow as tf

        def run_tflite_inference(tflite_model_path, model_type="Float"):
            # Load the test data
            np.random.seed(0)
            with tf.compat.v1.Session() as sess:
                test_data, test_labels = audio_processor.get_data(
                    -1, 0, model_settings,
                    self.config['background_frequency'],
                    self.config['background_volume_range'],
                    self.config['time_shift_ms'], 'testing', sess)
            test_data = np.expand_dims(test_data, axis=1).astype(np.float32)
            # Initialize the interpreter
            interpreter = tf.lite.Interpreter(
                tflite_model_path,
                experimental_op_resolver_type=tf.lite.experimental.OpResolverType.BUILTIN_REF)
            interpreter.allocate_tensors()
            input_details = interpreter.get_input_details()[0]
            output_details = interpreter.get_output_details()[0]
            # For the quantized model, convert the float inputs to integers manually
            if model_type == "Quantized":
                input_scale, input_zero_point = input_details["quantization"]
                test_data = test_data / input_scale + input_zero_point
                test_data = test_data.astype(input_details["dtype"])
            correct_predictions = 0
            for i in range(len(test_data)):
                interpreter.set_tensor(input_details["index"], test_data[i])
                interpreter.invoke()
                output = interpreter.get_tensor(output_details["index"])[0]
                top_prediction = output.argmax()
                correct_predictions += (top_prediction == test_labels[i])
            accuracy = (correct_predictions * 100) / len(test_data)
            logger.info(f'{model_type} model accuracy: {accuracy:.2f}% (test samples={len(test_data)})')
            return accuracy

        # Evaluate both models
        try:
            float_accuracy = run_tflite_inference(self.model_paths['float_model_tflite'])
        except Exception as e:
            logger.error(f"Float model test failed: {e}")
            float_accuracy = None
        try:
            quantized_accuracy = run_tflite_inference(self.model_paths['model_tflite'], model_type='Quantized')
        except Exception as e:
            logger.error(f"Quantized model test failed: {e}")
            quantized_accuracy = None
        return float_accuracy, quantized_accuracy

    def generate_micro_model(self):
        """Generate the microcontroller C source file."""
        logger.info("Generating the microcontroller C source file...")
        try:
            # Implemented directly in Python
            self._generate_c_file_python()
            logger.info(f"C source file generated: {self.model_paths['model_tflite_micro']}")
        except Exception as e:
            logger.error(f"Failed to generate the C source file: {e}")
            raise

    def _generate_c_file_python(self):
        """Generate the C file with Python."""
        with open(self.model_paths['model_tflite'], 'rb') as f:
            model_data = f.read()
        # Build the C array
        c_content = []
        c_content.append('/* Automatically generated by speech_trainer.py */')
        c_content.append('#include "micro_speech_quantized_model_data.h"')
        c_content.append('')
        c_content.append('const unsigned char micro_speech_quantized_tflite[] = {')
        # Convert the bytes to C array text
        hex_values = []
        for i, byte in enumerate(model_data):
            if i % 16 == 0:
                hex_values.append('\n ')
            hex_values.append(f'0x{byte:02x}')
            if i < len(model_data) - 1:
                hex_values.append(',')
            if (i + 1) % 16 != 0 and i < len(model_data) - 1:
                hex_values.append(' ')
        c_content.append(''.join(hex_values))
        c_content.append('\n};')
        c_content.append(f'const unsigned int micro_speech_quantized_tflite_len = {len(model_data)};')
        c_content.append('')
        # Write the file
        with open(self.model_paths['model_tflite_micro'], 'w', encoding='utf-8') as f:
            f.write('\n'.join(c_content))

    def download_pretrained_model(self):
        """Download the pretrained model."""
        logger.info("Downloading the pretrained model...")
        try:
            import urllib.request
            import tarfile
            model_url = "https://storage.googleapis.com/download.tensorflow.org/models/tflite/speech_micro_train_2020_05_10.tgz"
            model_file = "speech_micro_train_2020_05_10.tgz"
            logger.info(f"Downloading model from {model_url}...")
            urllib.request.urlretrieve(model_url, model_file)
            logger.info("Extracting the model archive...")
            with tarfile.open(model_file, 'r:gz') as tar:
                tar.extractall('.')
            os.remove(model_file)
            logger.info("Pretrained model downloaded")
            return True
        except Exception as e:
            logger.error(f"Failed to download the pretrained model: {e}")
            return False

    def print_model_info(self):
        """Print model information."""
        logger.info("\n" + "=" * 50)
        logger.info("Model training complete!")
        logger.info("=" * 50)
        for name, path in self.model_paths.items():
            if os.path.exists(path):
                if os.path.isfile(path):
                    size = os.path.getsize(path)
                    logger.info(f"{name}: {path} ({size} bytes)")
                else:
                    logger.info(f"{name}: {path} (directory)")
            else:
                logger.info(f"{name}: {path} (not generated)")
        logger.info("\nDeploying to a microcontroller:")
        logger.info("1. See the TensorFlow Lite Micro documentation")
        logger.info("2. Update kCategoryCount and kCategoryLabels in micro_model_settings.h")
        logger.info("3. Replace the old model file with the generated micro_speech_quantized_model_data.c")

    def run_full_pipeline(self, skip_training=False):
        """Run the full training pipeline."""
        try:
            logger.info("Starting the full training pipeline...")
            # Make sure the dataset is ready
            self.ensure_dataset()
            # Clean previous data
            self.clean_previous_data()
            if skip_training:
                success = self.download_pretrained_model()
                if not success:
                    logger.error("Could not download the pretrained model; running full training instead")
                    skip_training = False
            if not skip_training:
                # Set up the TensorFlow repository
                self.setup_tensorflow_repo()
                # Train the model
                self.train_model()
                # Freeze the model
                self.freeze_model()
            else:
                logger.info("Using the pretrained model; skipping training")
            # Convert to TensorFlow Lite
            audio_processor, model_settings = self.convert_to_tflite()
            # Measure model accuracy
            self.test_tflite_accuracy(audio_processor, model_settings)
            # Generate the microcontroller model
            self.generate_micro_model()
            # Print model info
            self.print_model_info()
            logger.info("Training pipeline complete!")
        except KeyboardInterrupt:
            logger.info("Training interrupted by user")
            raise
        except Exception as e:
            logger.error(f"Training pipeline failed: {e}")
            logger.error("Possible fixes:")
            logger.error("1. Check the network connection (needed to download the dataset)")
            logger.error("2. Make sure there is enough disk space")
            logger.error("3. Check the Python environment and dependencies")
            logger.error("4. Try --skip_training to download the pretrained model")
            raise


def main():
    """Entry point."""
    parser = argparse.ArgumentParser(description='Speech recognition model trainer')
    parser.add_argument('--wanted_words', default='yes,no',
                        help='comma-separated words to train on (default: yes,no)')
    parser.add_argument('--training_steps', default='12000,3000',
                        help='comma-separated training steps (default: 12000,3000)')
    parser.add_argument('--learning_rate', default='0.001,0.0001',
                        help='comma-separated learning rates (default: 0.001,0.0001)')
    parser.add_argument('--model_architecture', default='tiny_conv',
                        choices=['single_fc', 'conv', 'low_latency_conv', 'low_latency_svdf',
                                 'tiny_embedding_conv', 'tiny_conv'],
                        help='model architecture (default: tiny_conv)')
    parser.add_argument('--skip_training', action='store_true',
                        help='skip training and use the pretrained model')
    parser.add_argument('--resume', action='store_true',
                        help='resume the previous run: keep the train directory and restore the latest checkpoint')
    parser.add_argument('--test_env', action='store_true',
                        help='only test the environment; do not train')
    args = parser.parse_args()
    if args.test_env:
        # Environment test only
        logger.info("Testing the environment...")
        try:
            trainer = SpeechRecognitionTrainer()
            logger.info("✓ Environment test passed")
            logger.info("✓ All dependencies are installed")
            logger.info("Ready to start training")
        except Exception as e:
            logger.error(f"✗ Environment test failed: {e}")
            sys.exit(1)
        return
    # Training configuration
    config = {
        'wanted_words': args.wanted_words,
        'training_steps': args.training_steps,
        'learning_rate': args.learning_rate,
        'model_architecture': args.model_architecture,
        'resume': args.resume,
    }
    # Create the trainer and run the full pipeline
    trainer = SpeechRecognitionTrainer(config)
    trainer.run_full_pipeline(skip_training=args.skip_training)


if __name__ == '__main__':
    main()