(2) Overall flow of train.py: read arguments ➡️ set up logging ➡️ detect where a previous run was interrupted (last checkpoint) ➡️ download the data ➡️ download the pretrained model, feature extractor, and tokenizer ➡️ resample the audio (unify the audio format) ➡️ preprocess the dataset (load audio files as arrays and tokenize the targets) ➡️ load the evaluation metric (unify the scoring criterion) ➡️ define the data collator (unify the batch input format) ➡️ create a single speech processor ➡️ initialize the trainer ➡️ train ➡️ evaluate ➡️ save the training state
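Before diving into the source, it helps to see how such a script is typically launched. A hypothetical invocation, assuming the standard HfArgumentParser-generated flag names (only --dataset_name, --dataset_config_name, and the training-arguments flags are confirmed by the argument classes excerpted below; the model and dataset names are illustrative):

python run_speech_recognition_seq2seq.py \
    --model_name_or_path openai/whisper-tiny \
    --dataset_name mozilla-foundation/common_voice_11_0 \
    --dataset_config_name zh-CN \
    --output_dir ./whisper-finetuned \
    --do_train \
    --do_eval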
#!/usr/bin/env python
# coding=utf-8
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Fine-tuning the library models for sequence to sequence speech recognition.
"""
# You can also adapt this script on your own sequence to sequence speech
# recognition task. Pointers for this are left as comments.
import logging
import os
import sys
import warnings
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union
import datasets
import evaluate
import torch
from datasets import DatasetDict, load_dataset
import transformers
from transformers import (
AutoConfig,
AutoFeatureExtractor,
AutoModelForSpeechSeq2Seq,
AutoProcessor,
AutoTokenizer,
HfArgumentParser,
Seq2SeqTrainer,
Seq2SeqTrainingArguments,
set_seed,
)
from transformers.trainer_utils import get_last_checkpoint, is_main_process
from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
# ... (the check_min_version call and the opening of the ModelArguments dataclass are elided in this excerpt) ...
    suppress_tokens: List[int] = field(
        default=None, metadata={"help": "A list of tokens that will be suppressed at generation."}
    )
apply_spec_augment: bool = field(
default=False,
metadata={
"help": "Whether to apply *SpecAugment* data augmentation to the input features. This is currently only relevant for Wav2Vec2, HuBERT, WavLM and Whisper models."
},
)
@dataclass
class DataTrainingArguments:
"""
Arguments pertaining to what data we are going to input our model for training and eval.
"""
dataset_name: str = field(
default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
)
dataset_config_name: Optional[str] = field(
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
)
overwrite_cache: bool = field(
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
)
preprocessing_num_workers: Optional[int] = field(
default=None,
metadata={"help": "The number of processes to use for the preprocessing."},
)
max_train_samples: Optional[int] = field(
default=None,
metadata={
"help": (
"For debugging purposes or quicker training, truncate the number of training examples to this "
"value if set."#在训练机器学习模型时,可能需要对训练示例的数量进行截断,以便控制数据集的大小或提高训练效率。以下是一些关于如何截断训练示例的要点:
last_checkpoint = None
if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
    last_checkpoint = get_last_checkpoint(training_args.output_dir)
    if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty. "
            "Use --overwrite_output_dir to overcome."
        )
    elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
        logger.info(
            f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
            "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
        )
"Detecting last checkpoint and eventually continue from last checkpoint" means that, during deep-learning or machine-learning training, the script automatically identifies the most recently saved model checkpoint and can resume training from it. This mechanism ensures continuity and efficiency: if training is interrupted for any reason (a planned stop, a system error, a power failure, and so on), it can pick up from the last saved model state instead of starting over from scratch, which saves a great deal of time and compute. The condensed sketch below shows how the detected checkpoint is actually consumed.
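A condensed paraphrase of how the HF example feeds the detected checkpoint into training (an explicit --resume_from_checkpoint takes precedence over the auto-detected one):

checkpoint = None
if training_args.resume_from_checkpoint is not None:
    checkpoint = training_args.resume_from_checkpoint
elif last_checkpoint is not None:
    checkpoint = last_checkpoint
train_result = trainer.train(resume_from_checkpoint=checkpoint)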
(3) Fine-tuning the model

import transformers
from torch.utils.data import IterableDataset
from tqdm import tqdm
from transformers import (
    WhisperFeatureExtractor,
    WhisperTokenizer,
    AutoProcessor,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    WhisperForConditionalGeneration,
)
import torchaudio
import torch
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union


@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    """
    Data collator that will dynamically pad the inputs received.
    Args:
        processor ([`WhisperProcessor`])
            The processor used for processing the data.
        decoder_start_token_id (`int`)
            The begin-of-sentence token of the decoder.
        forward_attention_mask (`bool`)
            Whether to return attention_mask.
    """

    processor: Any
    decoder_start_token_id: int
    forward_attention_mask: bool

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need
        # different padding methods
        model_input_name = self.processor.model_input_names[0]
        input_features = [{model_input_name: feature[model_input_name]} for feature in features]
        label_features = [{"input_ids": feature["labels"]} for feature in features]

        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        if self.forward_attention_mask:
            batch["attention_mask"] = torch.LongTensor([feature["attention_mask"] for feature in features])

        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        # if bos token is appended in previous tokenization step,
        # cut bos token here as it's appended later anyways
        if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels
        return batch


# Prepare the data
class IterWhisperDataset(IterableDataset):
    def __init__(self, wav_scp, text, whisper_feature_extractor, whisper_tokenizer):
        # store everything in a dict: utterance id -> [wav_path, transcript]
        self.data_list = {}
        # audio paths (Kaldi-style wav.scp: "<id> <path>")
        with open(wav_scp, "r", encoding="utf-8") as file:
            for line in tqdm(file.readlines()):
                line = line.strip()
                idx = line.split(" ")[0]
                wav_path = " ".join(line.split(" ")[1:])
                self.data_list[idx] = []
                self.data_list[idx].append(wav_path)
        # transcripts (Kaldi-style text: "<id> <transcript>")
        with open(text, "r", encoding="utf-8") as file:
            for line in tqdm(file.readlines()):
                line = line.strip()
                idx = line.split(" ")[0]
                text = " ".join(line.split(" ")[1:])
                self.data_list[idx].append(text)
        self.whisper_feature_extractor = whisper_feature_extractor
        self.whisper_tokenizer = whisper_tokenizer
        print("Number of utterances:", len(self.data_list))

    # number of utterances
    def __len__(self):
        return len(self.data_list)
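The excerpt cuts off before the dataset's `__iter__` method, which is what actually feeds examples to the trainer. A minimal sketch of what it plausibly looks like, assuming torchaudio loading, resampling to the 16 kHz that Whisper's feature extractor expects, and the `input_features`/`labels` keys that the collator above consumes; none of this method's body appears in the original post:

    # Hypothetical __iter__ (reconstructed, not from the original post): yields one
    # dict per utterance in the format DataCollatorSpeechSeq2SeqWithPadding expects.
    def __iter__(self):
        for wav_path, transcript in self.data_list.values():
            waveform, sample_rate = torchaudio.load(wav_path)
            # Whisper's feature extractor expects 16 kHz mono audio
            if sample_rate != 16000:
                waveform = torchaudio.transforms.Resample(sample_rate, 16000)(waveform)
            inputs = self.whisper_feature_extractor(
                waveform.squeeze(0).numpy(), sampling_rate=16000
            )
            yield {
                "input_features": inputs.input_features[0],
                "labels": self.whisper_tokenizer(transcript).input_ids,
            }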
# base paths
whisper_model = "/mnt/e/王嘟嘟/wsl/asr_large_model/whisper_model/whisper-tiny"
train_wav_scp = "/mnt/e/王嘟嘟/wsl/asr_large_model/train_model/diyproject/data/wav.scp"
train_text = "/mnt/e/王嘟嘟/wsl/asr_large_model/train_model/diyproject/data/text.txt"
# feature extractor
whisper_feature_extractor = WhisperFeatureExtractor.from_pretrained(whisper_model)
# tokenizer
whisper_tokenizer = WhisperTokenizer.from_pretrained(whisper_model)
whisper_tokenizer.set_prefix_tokens(language="chinese", task="transcribe")
# the data pipeline is ready
train_data_list = IterWhisperDataset(
    train_wav_scp,
    train_text,
    whisper_feature_extractor,
    whisper_tokenizer,
)
# load the model and processor
model = WhisperForConditionalGeneration.from_pretrained(whisper_model)
processor = AutoProcessor.from_pretrained(whisper_model)
# initialize the trainer
data_collator = DataCollatorSpeechSeq2SeqWithPadding(
    processor=processor,
    decoder_start_token_id=model.config.decoder_start_token_id,
    forward_attention_mask=False,
)
# def compute_metrics(pred):
#     pred_ids = pred.predictions
#     pred.label_ids[pred.label_ids == -100] = whisper_tokenizer.pad_token_id
#     pred_str = whisper_tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
#     # we do not want to group tokens when computing the metrics
#     label_str = whisper_tokenizer.batch_decode(pred.label_ids, skip_special_tokens=True)
#     wer = metric.compute(predictions=pred_str, references=label_str)
#     return {"wer": wer}
training_args = Seq2SeqTrainingArguments(
    output_dir="model/v1",  # change to a repo name of your choice
    per_device_train_batch_size=1,
    gradient_accumulation_steps=1,  # increase by 2x for every 2x decrease in batch size
    learning_rate=0.001,
    warmup_steps=50,
    num_train_epochs=1,
    evaluation_strategy="epoch",
    fp16=False,
    per_device_eval_batch_size=2,
    generation_max_length=128,
    logging_steps=4,  # how often (in steps) to write a log line
    remove_unused_columns=False,  # required as the PeftModel forward doesn't have the signature of the wrapped model's forward
    label_names=["labels"],  # same reason as above
)
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_data_list,
    eval_dataset=train_data_list,  # the post reuses the training set for evaluation
    tokenizer=whisper_feature_extractor,  # passed so the feature extractor is saved alongside the model
    data_collator=data_collator,
    # compute_metrics=compute_metrics if training_args.predict_with_generate else None,
)
# train
train_result = trainer.train()
trainer.save_model()  # save the fine-tuned model
# test
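For reference, the loader expects Kaldi-style manifests: each line of wav.scp is an utterance id followed by an audio path, and each line of text.txt is the same id followed by its transcript. Hypothetical contents (the ids and paths are made up for illustration):

# wav.scp
utt001 /data/audio/utt001.wav
utt002 /data/audio/utt002.wav

# text.txt
utt001 今天天气很好
utt002 我们去吃饭吧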
(4) Model evaluation
import numpy as np
from torch.utils.data import IterableDataset
from tqdm import tqdm
# classes needed to evaluate the model
from transformers import (
    WhisperFeatureExtractor,
    WhisperTokenizer,
    WhisperForConditionalGeneration,
)
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union
import torch
import torchaudio


class IterWhisperDataset(IterableDataset):
    def __init__(self, wav_scp, text, whisper_feature_extractor, whisper_tokenizer):
        # store everything in a dict: utterance id -> [wav_path, transcript]
        self.data_list = {}
        # audio paths
        with open(wav_scp, "r", encoding="utf-8") as file:
            for line in tqdm(file.readlines()):
                line = line.strip()
                idx = line.split(" ")[0]
                wav_path = " ".join(line.split(" ")[1:])
                self.data_list[idx] = []
                self.data_list[idx].append(wav_path)
        # transcripts
        with open(text, "r", encoding="utf-8") as file:
            for line in tqdm(file.readlines()):
                line = line.strip()
                idx = line.split(" ")[0]
                text = " ".join(line.split(" ")[1:])
                self.data_list[idx].append(text)
        self.whisper_feature_extractor = whisper_feature_extractor
        self.whisper_tokenizer = whisper_tokenizer
        print("Number of utterances:", len(self.data_list))

    # number of utterances
    def __len__(self):
        return len(self.data_list)
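The post is truncated here, before the evaluation code itself. As a rough sketch of where it is headed, given the imports above: load the fine-tuned checkpoint, transcribe each utterance with model.generate, and score against the references. Everything below is reconstructed rather than taken from the original; the paths and the choice of evaluate.load("wer") mirror the commented-out compute_metrics in section (3) (for Chinese, a character error rate via evaluate.load("cer") is often preferred):

import evaluate

base_model = "/mnt/e/王嘟嘟/wsl/asr_large_model/whisper_model/whisper-tiny"
finetuned = "model/v1"  # the output_dir used during training

feature_extractor = WhisperFeatureExtractor.from_pretrained(base_model)
tokenizer = WhisperTokenizer.from_pretrained(base_model)
model = WhisperForConditionalGeneration.from_pretrained(finetuned)
model.eval()

wer_metric = evaluate.load("wer")
eval_data = IterWhisperDataset(  # reuse the manifests from training
    "/mnt/e/王嘟嘟/wsl/asr_large_model/train_model/diyproject/data/wav.scp",
    "/mnt/e/王嘟嘟/wsl/asr_large_model/train_model/diyproject/data/text.txt",
    feature_extractor,
    tokenizer,
)

predictions, references = [], []
for wav_path, transcript in eval_data.data_list.values():
    waveform, sr = torchaudio.load(wav_path)
    if sr != 16000:
        waveform = torchaudio.transforms.Resample(sr, 16000)(waveform)
    feats = feature_extractor(
        waveform.squeeze(0).numpy(), sampling_rate=16000, return_tensors="pt"
    ).input_features
    with torch.no_grad():
        pred_ids = model.generate(feats)
    predictions.append(tokenizer.batch_decode(pred_ids, skip_special_tokens=True)[0])
    references.append(transcript)

print("WER:", wer_metric.compute(predictions=predictions, references=references))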