Alex_McAvoy

想要成为渔夫的猎手

NVIDIA 7th SkyHackathon(四)Nemo ASR 模型训练与评估

模型加载

导入 NeMo

1
2
3
4
5
6
7
8
9
10
11
12
import nemo
import nemo.collection.asr as nemo_asr
import torch

# 检查 nemo 版本 '1.4.0'
print(nemo.__version__)

# 检查 torch 版本 '1.12.1+cu113'
print(torch.__version__)

# 检查 GPU 是否被 torch 调用 True
print(torch.cuda.is_available())

加载自动语音识别模型

NeMo 的每个集合 ASR、NLP、TTS 中,都提供了许多预训练模型,使用 list_available_models() 可以查看 ASR 所提供的所有预训练模型

1
2
3
4
5
6
7
8
9
nemo_asr.models.EncDecCTCModel.list_available_models()
'''
[PretrainedModelInfo(
pretrained_model_name=QuartzNet15x5Base-En,
description=QuartzNet15x5 model trained on six datasets: LibriSpeech, Mozilla Common Voice (validated clips from en_1488h_2019-12-10), WSJ, Fisher, Switchboard, and NSC Singapore English. It was trained with Apex/Amp optimization level O1 for 600 epochs. The model achieves a WER of 3.79% on LibriSpeech dev-clean, and a WER of 10.05% on dev-other. Please visit https://ngc.nvidia.com/catalog/models/nvidia:nemospeechmodels for further details.,
location=https://api.ngc.nvidia.com/v2/models/nvidia/nemospeechmodels/versions/1.0.0a5/files/QuartzNet15x5Base-En.nemo
),
...]
'''

本次比赛 NVIDIA 在相关资料中提供了中文预训练模型 stt_zh_quartznet15x5.nemo,使用 restore_form() 进行加载

1
2
# 加载中文预训练模型并实例化
quartznet = nemo_asr.models.EncDecCTCModel.restore_from("stt_zh_quartznet15x5.nemo")

加载 quartznet 配置文件

使用 YAML 读取 quartznet 模型配置文件

1
2
3
4
5
6
7
8
9
10
11
try:
from ruamel.yaml import YAML
except ModuleNotFoundError:
from ruamel_yaml import YAML
config_path ="quartznet_15x5_zh.yaml"

yaml = YAML(typ='safe')
with open(config_path) as f:
params = yaml.load(f)

print(params)

加载数据清单并传给配置文件

将之前所制作的数据清单进行加载,并传给配置文件

1
2
3
4
5
6
7
# 加载数据清单
train_manifest = "/root/data/train.json"
test_manifest = "/root/data/val.json"

# 传递给配置文件
params['model']['train_ds']['manifest_filepath']=train_manifest
params['model']['validation_ds']['manifest_filepath']=test_manifest

模型训练

训练

使用迁移学习的方法训练模型

1
2
3
4
5
6
7
8
9
# 设置训练集
quartznet.setup_training_data(train_data_config=params['model']['train_ds'])
# 设置测试集
quartznet.setup_validation_data(val_data_config=params['model']['validation_ds'])

# 开始训练
import pytorch_lightning as pl
trainer = pl.Trainer(gpus=1,max_epochs=200)
trainer.fit(quartznet)

保存并重载

1
2
3
4
5
# 将训练好的模型保存为.nemo格式
quartznet.save_to("7th_asr_model_1.nemo")

# 重新加载模型
try_model_1 = nemo_asr.models.EncDecCTCModel.restore_from("7th_asr_model_1.nemo")

模型评估

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
from ASR_metrics import utils as metrics

# 加载测试数据
asr_result = try_model_1.transcribe(paths2audio_files=["/root/data/test/1/1.wav"])
print(asr_result)

#指定正确答案
s1 = "请检测出果皮"
#识别结果
s2 = " ".join(asr_result)

#计算字错率cer
print("字错率:{}".format(metrics.calculate_cer(s1,s2)))
#计算准确率
print("准确率:{}".format(1-metrics.calculate_cer(s1,s2)))
感谢您对我的支持,让我继续努力分享有用的技术与知识点!