Whisper大模型学习记录:自己写代码微调大模型

打印 上一主题 下一主题

主题 862|帖子 862|积分 2586

一、总体序言

1.whisper处置惩罚环节先容

颠末上一次的example模型跑通,我们大致了解到whisper模型调用过程和语音识别具体环节。whisper模型语音识别的具体环节和whisper模型的调用可以拆解为以下图片:


首先我们来先容左手边灰色对话框里的内容:这一部分先容了whisper可以输入的音频类型和根据输入类型可以完成的任务,例如:语音识别笔墨、语种互相翻译、噪音识别等

接下来看右边的流程图,流程图有两大部分组成:encoder部分(编码部分)和decoder部分(解码部分)。

在编码和解码部分分别有两个输入,在解码的最后为输出部分。在编码部分的输入,重要输入将频率标度转换为梅尔标度的频谱和位置编码的音频文件,;而解码部分重要输入多任务训练情势中的tokens。(这一部分将在报告tokenizer时详细讲到)

在encoder和decoder部分是由一层层的处置惩罚层叠起来的;每一个小的处置惩罚层都由(MLP和atteintion)组合而成。MLP指的是多层感知机制,用于担当特性和整合前一层带来的信息;atteion是注意力机制,用于动态衡量各个输入信息之间的关联,调整加权。¶

最后的输出是对下一个单位的预测

在whisper的模仿运行中,我们重要打仗到了三个重要代码文件:训练代码文件(train.py)、运行参数代码文件(run.sh)、whisper源代码文件(src/transformers/models/whisper)下面举行详细解说

二、代码详细解说

(一) run.sh 参数解析

  1. python train.py \
  2.         --model_name_or_path="/mnt/e/王嘟嘟/wsl/asr_large_model/whisper_model/whisper-tiny" \ #调用的模型在哪个路径下?
  3.         --dataset_name="mozilla-foundation/common_voice_11_0" \ #调用的数据库叫什么?
  4.         --dataset_config_name="hi" \  #dataset config 指的是数据集配置。
  5. 它通常涉及到对用于训练或处理的数据集的各种参数和设置的描述。这可能包括数据集的位置、数据的格式、数据的分割方式(如训练集、验证集、测试集的划分比例)、数据的预处理步骤(如数据清洗、转换)等相关信息的定义和配置。
  6. 通过明确数据集配置,可以确保模型或算法能够正确地读取、处理和利用给定的数据集。
  7.         --language="hindi" \  #语言选择
  8.         --train_split_name="test" \  #训练集名称
  9.         --eval_split_name="test" \  #验证集名称
  10.         --max_steps="5000" \ #最大迭代步数
  11.         --output_dir="./whisper-small-hi" \  #输出路径
  12.         --per_device_train_batch_size="16" \ #每次训练批次大小,即每次训练几个数据
  13.         --gradient_accumulation_steps="2" \ #梯度累积步长
  14.         --per_device_eval_batch_size="16" \ #每次验证批次大小,即每次验证几个数据
  15.         --logging_steps="25" \ #日志步数
  16.         --learning_rate="1e-5" \ #
  17.         --warmup_steps="500" \ #预热步骤调整参数
  18.         --evaluation_strategy="steps" \  #验证策略和步数
  19.         --eval_steps="1000" \
  20.         --save_strategy="steps" \ #保存策略和步数
  21.         --save_steps="1000" \
  22.         --generation_max_length="225" \
  23.         --preprocessing_num_workers="16" \#预处理的工作线程数量
  24.         --length_column_name="input_length" \
  25.         --max_duration_in_seconds="30" \
  26.         --text_column_name="sentence" \ #文本名称
  27.         --freeze_feature_encoder="False" \ # 冻结编码器
  28.         --gradient_checkpointing \
  29.         --group_by_length \
  30.         --overwrite_output_dir \
  31.         --do_train \
  32.         --do_eval \
  33.         --predict_with_generate \
  34.         --num_train_epochs="1"  #epoch的含义是跑几遍数据的意思
  35. "run.sh" 31L, 965B                                                                                                                                              1,17          Top
复制代码
(二) train.py 总流程:读取参数➡️设置日志➡️检测中断位置➡️下载数据➡️下载预训练模型、特性提取、tohenizer➡️语音数据重新采样(同一语音数据的格式)➡️预处置惩罚数据库(audio files as arrays and tokenize the targets)➡️下载评估标准(同一评判标准)➡️ 定义数据收集器(同一数据输入格式)➡️创建单独的语音吸收器➡️初始化训练器➡️训练➡️测试➡️记录训练状态

  1. #!/usr/bin/env python
  2. # coding=utf-8
  3. # Copyright 2021 The HuggingFace Team. All rights reserved.
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. #     http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. """
  17. Fine-tuning the library models for sequence to sequence speech recognition.
  18. """
  19. # You can also adapt this script on your own sequence to sequence speech
  20. # recognition task. Pointers for this are left as comments.
  21. import logging
  22. import os
  23. import sys
  24. import warnings
  25. from dataclasses import dataclass, field
  26. from typing import Any, Dict, List, Optional, Union
  27. import datasets
  28. import evaluate
  29. import torch
  30. from datasets import DatasetDict, load_dataset
  31. import transformers
  32. from transformers import (
  33.     AutoConfig,
  34.     AutoFeatureExtractor,
  35.     AutoModelForSpeechSeq2Seq,
  36.     AutoProcessor,
  37.     AutoTokenizer,
  38.     HfArgumentParser,
  39.     Seq2SeqTrainer,
  40.     Seq2SeqTrainingArguments,
  41.     set_seed,
  42. )
  43. from transformers.trainer_utils import get_last_checkpoint, is_main_process
  44. from transformers.utils import check_min_version, send_example_telemetry
  45. from transformers.utils.versions import require_version
  46. # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
  47. check_min_version("4.32.0")
  48. require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")
  49. logger = logging.getLogger(__name__)
  50. @dataclass
  51. class ModelArguments:
  52.     """
  53.     Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
  54.     """
  55.     model_name_or_path: str = field(
  56.         metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
  57.     )
  58.     config_name: Optional[str] = field(
  59.         default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
  60.     )
  61.     tokenizer_name: Optional[str] = field(
  62.         default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
  63.     )
  64.     feature_extractor_name: Optional[str] = field(
  65.         default=None, metadata={"help": "feature extractor name or path if not the same as model_name"}
  66.     )
  67.     cache_dir: Optional[str] = field(
  68.         default=None,
  69.         metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
  70.     )
  71.     use_fast_tokenizer: bool = field(
  72.         default=True,
  73.         metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
  74.     )
  75.     model_revision: str = field(
  76.         default="main",
  77.         metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
  78.     )
  79.     token: str = field(
  80.         default=None,
  81.         metadata={
  82.             "help": (
  83.                 "The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
  84.                 "generated when running `huggingface-cli login` (stored in `~/.huggingface`).
  85.                 "如果在进行远程文件访问时没有指定用于HTTP Bearer授权的令牌,"
  86.                 "那么系统将使用在运行huggingface-cli login命令时生成的令牌。"
  87.                 "这个令牌会被存储在用户的主目录下的.huggingface文件夹中。简单来说,它说明了如何进行身份验证以及令牌的存储位置。"
  88.             )
  89.         },
  90.     )
  91.     use_auth_token: bool = field(
  92.         default=None,
  93.         metadata={
  94.             "help": "The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token`."
  95.         },
  96.     )
  97.     trust_remote_code: bool = field(
  98.         default=False,
  99.         metadata={
  100.             "help": (
  101.                 "Whether or not to allow for custom models defined on the Hub in their own modeling files. This option"
  102.                 "should only be set to `True` for repositories you trust and in which you have read the code, as it will"
  103.                 "execute code present on the Hub on your local machine."
  104.             )
  105.         },
  106.     )
  107.     freeze_feature_encoder: bool = field(
  108.         default=True, metadata={"help": "Whether to freeze the feature encoder layers of the model."}
  109.     )
  110.     freeze_encoder: bool = field(
  111.         default=False, metadata={"help": "Whether to freeze the entire encoder of the seq2seq model."}
  112.     )
  113.     forced_decoder_ids: List[List[int]] = field(
  114.         default=None,
  115.         metadata={
  116.             "help": (
  117.                 "A list of pairs of integers which indicates a mapping from generation indices to token indices "
  118.                 "that will be forced before sampling. For example, [[0, 123]] means the first generated token "
  119.                 "will always be a token of index 123."
  120.                 "[[0, 123]]表示生成的第一个令牌将始终是索引为123的令牌。这种机制可以确保在生成文本时,特定位置的令牌是预先定义的,从而影响生成的内容和结构。"
  121.             )
  122.         },
  123.     )
  124.     suppress_tokens: List[int] = field(
  125.         default=None, metadata={"help": "A list of tokens that will be suppressed at generation."}
  126.     )
  127.     apply_spec_augment: bool = field(
  128.         default=False,
  129.         metadata={
  130.             "help": "Whether to apply *SpecAugment* data augmentation to the input features. This is currently only relevant for Wav2Vec2, HuBERT, WavLM and Whisper models."
  131.         },
  132.     )
  133. @dataclass
  134. class DataTrainingArguments:
  135.     """
  136.     Arguments pertaining to what data we are going to input our model for training and eval.
  137.     """
  138.     dataset_name: str = field(
  139.         default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
  140.     )
  141.     dataset_config_name: Optional[str] = field(
  142.         default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
  143.     )
  144.     overwrite_cache: bool = field(
  145.         default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
  146.     )
  147.     preprocessing_num_workers: Optional[int] = field(
  148.         default=None,
  149.         metadata={"help": "The number of processes to use for the preprocessing."},
  150.     )
  151.     max_train_samples: Optional[int] = field(
  152.         default=None,
  153.         metadata={
  154.             "help": (
  155.                 "For debugging purposes or quicker training, truncate the number of training examples to this "
  156.                 "value if set."#在训练机器学习模型时,可能需要对训练示例的数量进行截断,以便控制数据集的大小或提高训练效率。以下是一些关于如何截断训练示例的要点:
  157.                         #截断的目的:降低计算资源的需求。加快模型训练速度。避免过拟合,特别是在数据量过大时。
  158.                         #实现方法:
  159.                         #选择特定数量的示例:可以直接选择前N个示例进行训练。
  160.                         #随机抽样:从整个数据集中随机选择一定数量的示例,以确保样本的多样性。
  161.                         #基于条件的筛选:根据特定条件(如标签、特征等)筛选出符合条件的示例。
  162.                         #在训练过程中使用截断:
  163.                         #在使用某些库(如Hugging Face的Tokenizers)时,可以通过设置参数来控制输入数据的截断。例如,可以使用return_overflowing_tokens和stride参数来管理截断的方式和效果[2]。
  164.                         #注意事项:
  165.                         #确保截断后的数据集仍然具有代表性,以避免模型学习到偏差。
  166.                         #在截断过程中,考虑到数据的分布和特征,以保持训练的有效性。
  167.                         #通过合理地截断训练示例,可以有效地管理训练过程,提高模型的性能和训练效率。
  168.         },
  169.     )
  170.     max_eval_samples: Optional[int] = field(
  171.         default=None,
  172.         metadata={
  173.             "help": (
  174.                 "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
  175.                 "value if set."
  176.             )
  177.         },
  178.     )
  179.     audio_column_name: str = field(
  180.         default="audio",
  181.         metadata={"help": "The name of the dataset column containing the audio data. Defaults to 'audio'"},
  182.     )
  183.     text_column_name: str = field(
  184.         default="text",
  185.         metadata={"help": "The name of the dataset column containing the text data. Defaults to 'text'"},
  186.     )
  187.     max_duration_in_seconds: float = field(
  188.         default=20.0,
  189.         metadata={
  190.             "help": (
  191.                 "Truncate audio files that are longer than `max_duration_in_seconds` seconds to"
  192.                 " 'max_duration_in_seconds`"
  193.             )
  194.         },
  195.     )
  196.     min_duration_in_seconds: float = field(
  197.         default=0.0, metadata={"help": "Filter audio files that are shorter than `min_duration_in_seconds` seconds"}
  198.     )
  199.     preprocessing_only: bool = field(
  200.         default=False,
  201.         metadata={
  202.             "help": (
  203.                 "Whether to only do data preprocessing and skip training. This is especially useful when data"
  204.                 " preprocessing errors out in distributed training due to timeout. In this case, one should run the"
  205.                 " preprocessing in a non-distributed setup with `preprocessing_only=True` so that the cached datasets"
  206.                 " can consequently be loaded in distributed training" #用于预训练报错导致的训练时间超时
  207.             )
  208.         },
  209.     )
  210.     train_split_name: str = field(
  211.         default="train",
  212.         metadata={
  213.             "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'"
  214.         },
  215.     )
  216.     eval_split_name: str = field(
  217.         default="test",
  218.         metadata={
  219.             "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'"
  220.         },
  221.     )
  222.     do_lower_case: bool = field(
  223.         default=True,
  224.         metadata={"help": "Whether the target text should be lower cased."},
  225.     )
  226.     language: str = field(
  227.         default=None,
  228.         metadata={
  229.             "help": (
  230.                 "Language for multilingual fine-tuning. This argument should be set for multilingual fine-tuning "
  231.                 "only. For English speech recognition, it should be set to `None`."
  232.             )
  233.         },
  234.     )
  235.     task: str = field(
  236.         default="transcribe",
  237.         metadata={"help": "Task, either `transcribe` for speech recognition or `translate` for speech translation."},
  238.     )
  239. @dataclass
  240. class DataCollatorSpeechSeq2SeqWithPadding: #数据收集器 用于动态编码输入的数据
  241.     """
  242.     Data collator that will dynamically pad the inputs received.
  243.     Args:
  244.         processor ([`WhisperProcessor`])
  245.             The processor used for processing the data.
  246.         decoder_start_token_id (`int`)  #从哪开始编码或解码
  247.             The begin-of-sentence of the decoder.
  248.         forward_attention_mask (`bool`)  #前向掩碼是否开启
  249.             Whether to return attention_mask.
  250.     """
  251.     processor: Any
  252.     decoder_start_token_id: int
  253.     forward_attention_mask: bool
  254.     def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
  255.         # split inputs and labels since they have to be of different lengths and need
  256.         # different padding methods
  257.         model_input_name = self.processor.model_input_names[0]
  258.         input_features = [{model_input_name: feature[model_input_name]} for feature in features]
  259.         label_features = [{"input_ids": feature["labels"]} for feature in features]
  260.         batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")
  261.         if self.forward_attention_mask:
  262.             batch["attention_mask"] = torch.LongTensor([feature["attention_mask"] for feature in features])
  263.         labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")
  264.         # replace padding with -100 to ignore loss correctly
  265.         labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
  266.         # if bos token is appended in previous tokenization step,
  267.         # cut bos token here as it's append later anyways
  268.         if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
  269.             labels = labels[:, 1:]
  270.         batch["labels"] = labels
  271.         return batch
  272. def main():
  273.     # 1. Parse input arguments     
  274.     # See all possible arguments in src/transformers/training_args.py
  275.     # or by passing the --help flag to this script.
  276.     # We now keep distinct sets of args, for a cleaner separation of concerns.
  277.     parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
  278. #json的读取和其他格式的读取
  279.     if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
  280.         # If we pass only one argument to the script and it's the path to a json file,
  281.         # let's parse it to get our arguments.
  282.         model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
  283.     else:
  284.         model_args, data_args, training_args = parser.parse_args_into_dataclasses()
  285.     if model_args.use_auth_token is not None:
  286.         warnings.warn("The `use_auth_token` argument is deprecated and will be removed in v4.34.", FutureWarning)
  287.         if model_args.token is not None:
  288.             raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
  289.         model_args.token = model_args.use_auth_token
  290.     # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
  291.     # information sent is the one passed as arguments along with your Python/PyTorch versions.
  292.     #send_example_telemetry("run_speech_recognition_seq2seq", model_args, data_args)通过跟踪示例的使用情况,开发团队可以了解哪些功能被频繁使用,从而优化资源分配和维护工作
  293.     send_example_telemetry("train", model_args, data_args)
  294.     # 2. Setup logging
  295.     logging.basicConfig(
  296.         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
  297.         datefmt="%m/%d/%Y %H:%M:%S",
  298.         handlers=[logging.StreamHandler(sys.stdout)],
  299.     )
  300.     log_level = training_args.get_process_log_level()
  301.     logger.setLevel(log_level)
  302.     datasets.utils.logging.set_verbosity(log_level)
  303.     transformers.utils.logging.set_verbosity(log_level)
  304.     transformers.utils.logging.enable_default_handler()
  305.     transformers.utils.logging.enable_explicit_format()
  306.     logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN)
  307.     # Log on each process the small summary:
  308.     logger.warning(
  309.         f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
  310.         f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}"
  311.     )
  312.     logger.info(f"Training/evaluation parameters {training_args}")
  313.     # Set the verbosity to info of the Transformers logger (on main process only):
  314.     if is_main_process(training_args.local_rank):
  315.         transformers.utils.logging.set_verbosity_info()
  316.     logger.info("Training/evaluation parameters %s", training_args)
  317.     # 3. Detecting last checkpoint and eventually continue from last checkpoint
  318.     last_checkpoint = None
  319.     if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:#输出文件夹为空 做了训练 且输出路径存在
  320.         last_checkpoint = get_last_checkpoint(training_args.output_dir)
  321.         if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
  322.             raise ValueError(
  323.                 f"Output directory ({training_args.output_dir}) already exists and is not empty. "
  324.                 "Use --overwrite_output_dir to overcome."
  325.             )
  326.         elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
  327.             logger.info(
  328.                 f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
  329.                 "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
  330.             )
  331.     ####---###Detecting last checkpoint and eventually continue from last checkpoint" 意味着在深度学习或机器学习的训练过程中,自动识别最近一次保存的模型检查点(checkpoint),并能够从这个检查点开始继续训练。这是一套机制,用于确保训练的连续性和效率。当训练因为任何原因(如计划中断、系统错误、电力故障等)被中断时,通过检测到的最后一个保存的模型状态,训练可以不必从头开始,而是从中断的地方恢复,这样可以节省大量的时间和计算资源。
  332. ### 在实际操作中,这通常涉及到使用特定的库函数(如PyTorch中的torch.load)来加载之前保存的模型参数和状态,然后继续执行训练过程。"
  333.     # Set seed before initializing model.
  334.     set_seed(training_args.seed)
  335. ###这行代码强调了在深度学习或机器学习模型初始化之前设置随机种子的重要性。设置随机种子(seed)是为了确保训练过程中的随机性是可复现的。
  336. ###这意味着每次运行代码时,如果使用相同的种子值,将会得到相同初始化的权重和偏置,进而使得实验结果可重复。
  337. ###这对于调试、比较不同模型设置的效果以及发表研究结果时保持一致性至关重要。在PyTorch中,这通常通过调用torch.manual_seed(training_args.seed)或类似的函数来实现,确保从数据预处理到模型训练的整个流程中生成的随机数序列是一致的。
  338. ###这样,即使在分布式训练或不同时间运行实验时,也能得到一致的初始模型状态和实验环境。
  339.    
  340.     # 4. Load dataset
  341.     raw_datasets = DatasetDict()
  342.     if training_args.do_train:
  343.         raw_datasets["train"] = load_dataset(
  344.             data_args.dataset_name,
  345.             data_args.dataset_config_name,
  346.             split=data_args.train_split_name,
  347.             cache_dir=model_args.cache_dir,
  348.             token=model_args.token,
  349.         )
  350.     if training_args.do_eval:
  351.         raw_datasets["eval"] = load_dataset(
  352.             data_args.dataset_name,
  353.             data_args.dataset_config_name,
  354.             split=data_args.eval_split_name,
  355.             cache_dir=model_args.cache_dir,
  356.             token=model_args.token,
  357.         )
  358. #拿到数据当中 音频的路径(path)音频的采样点(array)sentence(文本)以及一些数据——可以理解为表格
  359.     if data_args.audio_column_name not in next(iter(raw_datasets.values())).column_names:
  360.         raise ValueError(
  361.             f"--audio_column_name '{data_args.audio_column_name}' not found in dataset '{data_args.dataset_name}'. "
  362.             "Make sure to set `--audio_column_name` to the correct audio column - one of "
  363.             f"{', '.join(next(iter(raw_datasets.values())).column_names)}."
  364.         )
  365.     if data_args.text_column_name not in next(iter(raw_datasets.values())).column_names:
  366.         raise ValueError(
  367.             f"--text_column_name {data_args.text_column_name} not found in dataset '{data_args.dataset_name}'. "
  368.             "Make sure to set `--text_column_name` to the correct text column - one of "
  369.             f"{', '.join(next(iter(raw_datasets.values())).column_names)}."
  370.         )
  371.     # 5. Load pretrained model, tokenizer, and feature extractor(特征提取器)
  372.     #从Auto中下载这三个
  373.     # Distributed training:
  374.     # The .from_pretrained methods guarantee that only one local process can concurrently
  375.     config = AutoConfig.from_pretrained(
  376.         model_args.config_name if model_args.config_name else model_args.model_name_or_path,
  377.         cache_dir=model_args.cache_dir,
  378.         revision=model_args.model_revision,
  379.         token=model_args.token,
  380.         trust_remote_code=model_args.trust_remote_code,
  381.     )
  382.     config.update({"forced_decoder_ids": model_args.forced_decoder_ids, "suppress_tokens": model_args.suppress_tokens})
  383.     # SpecAugment for whisper models
  384.     if getattr(config, "model_type", None) == "whisper":
  385.         config.update({"apply_spec_augment": model_args.apply_spec_augment})
  386.     feature_extractor = AutoFeatureExtractor.from_pretrained(
  387.         model_args.feature_extractor_name if model_args.feature_extractor_name else model_args.model_name_or_path,
  388.         cache_dir=model_args.cache_dir,
  389.         revision=model_args.model_revision,
  390.         token=model_args.token,
  391.         trust_remote_code=model_args.trust_remote_code,
  392.     )
  393.     tokenizer = AutoTokenizer.from_pretrained(
  394.         model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
  395.         cache_dir=model_args.cache_dir,
  396.         use_fast=model_args.use_fast_tokenizer,
  397.         revision=model_args.model_revision,
  398.         token=model_args.token,
  399.         trust_remote_code=model_args.trust_remote_code,
  400.     )
  401. ###TOKENIZER是怎么支持96种语言的?
  402.     model = AutoModelForSpeechSeq2Seq.from_pretrained(
  403.         model_args.model_name_or_path,
  404.         config=config,
  405.         cache_dir=model_args.cache_dir,
  406.         revision=model_args.model_revision,
  407.         token=model_args.token,
  408.         trust_remote_code=model_args.trust_remote_code,
  409.     )
  410.     if model.config.decoder_start_token_id is None:
  411.         raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
  412.     if model_args.freeze_feature_encoder:
  413.         model.freeze_feature_encoder()
  414.     if model_args.freeze_encoder:
  415.         model.freeze_encoder()
  416.         model.model.encoder.gradient_checkpointing = False
  417.     if data_args.language is not None:
  418.         # We only need to set the task id when the language is specified (i.e. in a multilingual setting)
  419.         tokenizer.set_prefix_tokens(language=data_args.language, task=data_args.task)
  420.     # 6. Resample speech dataset if necessary #对语音数据集进行重新采样。重新采样通常是指将音频数据的采样率调整为模型所需的特定采样率。
  421.     dataset_sampling_rate = next(iter(raw_datasets.values())).features[data_args.audio_column_name].sampling_rate
  422.     if dataset_sampling_rate != feature_extractor.sampling_rate:
  423.         raw_datasets = raw_datasets.cast_column(
  424.             data_args.audio_column_name, datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate)
  425.         )
  426.     # 7. Preprocessing the datasets.
  427.     # We need to read the audio files as arrays and tokenize the targets.
  428.     max_input_length = data_args.max_duration_in_seconds * feature_extractor.sampling_rate
  429.     min_input_length = data_args.min_duration_in_seconds * feature_extractor.sampling_rate
  430.     audio_column_name = data_args.audio_column_name
  431.     num_workers = data_args.preprocessing_num_workers
  432.     text_column_name = data_args.text_column_name
  433.     model_input_name = feature_extractor.model_input_names[0]
  434.     do_lower_case = data_args.do_lower_case
  435.     # if SpecAugment is used for whisper models, return attention_mask to guide the mask along time axis
  436.     forward_attention_mask = (
  437.         getattr(config, "model_type", None) == "whisper"
  438.         and getattr(config, "apply_spec_augment", False)
  439.         and getattr(config, "mask_time_prob", 0) > 0
  440.     )
  441.     if data_args.max_train_samples is not None:
  442.         raw_datasets["train"] = raw_datasets["train"].select(range(data_args.max_train_samples))
  443.     if data_args.max_eval_samples is not None:
  444.         raw_datasets["eval"] = raw_datasets["eval"].select(range(data_args.max_eval_samples))
  445.     def prepare_dataset(batch):
  446.         # process audio
  447.         sample = batch[audio_column_name]
  448.         inputs = feature_extractor(
  449.             sample["array"], sampling_rate=sample["sampling_rate"], return_attention_mask=forward_attention_mask
  450.         )
  451.         # process audio length
  452.         batch[model_input_name] = inputs.get(model_input_name)[0]
  453.         batch["input_length"] = len(sample["array"])
  454.         if forward_attention_mask:
  455.             batch["attention_mask"] = inputs.get("attention_mask")[0]
  456.         # process targets
  457.         input_str = batch[text_column_name].lower() if do_lower_case else batch[text_column_name]
  458.         batch["labels"] = tokenizer(input_str).input_ids
  459.         return batch
  460.     with training_args.main_process_first(desc="dataset map pre-processing"):
  461.         vectorized_datasets = raw_datasets.map(
  462.             prepare_dataset,
  463.             remove_columns=next(iter(raw_datasets.values())).column_names,
  464.             num_proc=data_args.preprocessing_num_workers,
  465.             desc="preprocess train dataset",
  466.         )
  467.     # filter data that is shorter than min_input_length or longer than
  468.     # max_input_length
  469.     def is_audio_in_length_range(length):
  470.         return length > min_input_length and length < max_input_length
  471.     vectorized_datasets = vectorized_datasets.filter(
  472.         is_audio_in_length_range,
  473.         num_proc=num_workers,
  474.         input_columns=["input_length"],
  475.     )
  476.     # for large datasets it is advised to run the preprocessing on a
  477.     # single machine first with `args.preprocessing_only` since there will mostly likely
  478.     # be a timeout when running the script in distributed mode.
  479.     # In a second step `args.preprocessing_only` can then be set to `False` to load the
  480.     # cached dataset
  481.     if data_args.preprocessing_only:
  482.         cache = {k: v.cache_files for k, v in vectorized_datasets.items()}
  483.         logger.info(f"Data preprocessing finished. Files cached at {cache}.")
  484.         return
  485.         
  486.     # 8. Load metric:加载特定的评估指标
  487.     metric = evaluate.load("wer")
  488. #词错误率 WER,即 Word Error Rate)
  489.     def compute_metrics(pred):
  490.         pred_ids = pred.predictions
  491.         pred.label_ids[pred.label_ids == -100] = tokenizer.pad_token_id
  492.         pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
  493.         # we do not want to group tokens when computing the metrics
  494.         label_str = tokenizer.batch_decode(pred.label_ids, skip_special_tokens=True)
  495.         wer = metric.compute(predictions=pred_str, references=label_str)
  496.         return {"wer": wer}
  497.     # 9. Create a single speech processor
  498.     # make sure all processes wait until data is saved
  499.     with training_args.main_process_first():
  500.         # only the main process saves them
  501.         if is_main_process(training_args.local_rank):
  502.             # save feature extractor, tokenizer and config
  503.             feature_extractor.save_pretrained(training_args.output_dir)
  504.             tokenizer.save_pretrained(training_args.output_dir)
  505.             config.save_pretrained(training_args.output_dir)
  506.     processor = AutoProcessor.from_pretrained(training_args.output_dir)
  507. #这个语音处理器后续可以用于对输入的语音或者文本数据进行预处理、特征提取等一系列符合模型要求的操作,
  508. #例如将语音信号转换为适合模型输入的特征表示,或者对文本进行分词、编码等处理,从而为模型的输入做好准备工作。
  509.     # 10. Define data collator:分散的、单个的数据样本整理成适合批量输入到模型中的格式,并且处理不同样本之间长度不一致等问题,通过填充操作使一个批次内的数据在维度等方面达到统一。
  510.     data_collator = DataCollatorSpeechSeq2SeqWithPadding(
  511.         processor=processor,
  512.         decoder_start_token_id=model.config.decoder_start_token_id,
  513.         forward_attention_mask=forward_attention_mask,
  514.     )
  515.     # 11. Initialize Trainer
  516.     trainer = Seq2SeqTrainer(
  517.         model=model,
  518.         args=training_args,
  519.         train_dataset=vectorized_datasets["train"] if training_args.do_train else None,
  520.         eval_dataset=vectorized_datasets["eval"] if training_args.do_eval else None,
  521.         tokenizer=feature_extractor,
  522.         data_collator=data_collator,
  523.         compute_metrics=compute_metrics if training_args.predict_with_generate else None,
  524.     )
  525.     # 12. Training
  526.     if training_args.do_train:
  527.         checkpoint = None
  528.         if training_args.resume_from_checkpoint is not None:
  529.             checkpoint = training_args.resume_from_checkpoint
  530.         elif last_checkpoint is not None:
  531.             checkpoint = last_checkpoint
  532.         train_result = trainer.train(resume_from_checkpoint=checkpoint)
  533.         trainer.save_model()  # Saves the feature extractor too for easy upload
  534.         metrics = train_result.metrics
  535.         max_train_samples = (
  536.             data_args.max_train_samples
  537.             if data_args.max_train_samples is not None
  538.             else len(vectorized_datasets["train"])
  539.         )
  540.         metrics["train_samples"] = min(max_train_samples, len(vectorized_datasets["train"]))
  541.         trainer.log_metrics("train", metrics)
  542.         trainer.save_metrics("train", metrics)
  543.         trainer.save_state()
  544.     # 13. Evaluation
  545.     results = {}
  546.     if training_args.do_eval:
  547.         logger.info("*** Evaluate ***")
  548.         metrics = trainer.evaluate(
  549.             metric_key_prefix="eval",
  550.             max_length=training_args.generation_max_length,
  551.             num_beams=training_args.generation_num_beams,
  552.         )
  553.         max_eval_samples = (
  554.             data_args.max_eval_samples if data_args.max_eval_samples is not None else len(vectorized_datasets["eval"])
  555.         )
  556.         metrics["eval_samples"] = min(max_eval_samples, len(vectorized_datasets["eval"]))
  557.         trainer.log_metrics("eval", metrics)
  558.         trainer.save_metrics("eval", metrics)
  559.     # 14. Write Training Stats
  560.     kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "automatic-speech-recognition"}
  561.     if data_args.dataset_name is not None:
  562.         kwargs["dataset_tags"] = data_args.dataset_name
  563.         if data_args.dataset_config_name is not None:
  564.             kwargs["dataset_args"] = data_args.dataset_config_name
  565.             kwargs["dataset"] = f"{data_args.dataset_name} {data_args.dataset_config_name}"
  566.         else:
  567.             kwargs["dataset"] = data_args.dataset_name
  568.     if training_args.push_to_hub:
  569.         trainer.push_to_hub(**kwargs)
  570.     else:
  571.         trainer.create_model_card(**kwargs)
  572.     return results
  573. if __name__ == "__main__":
  574.     main()
复制代码
whisper中的tokennizer是从Autotokenizer内里下载过来的“whisper tokenizer”

那么这一部分有什么特别之处呢?又为什么得从autotokenizer内里下载呢?

我们先来回复第二个问题:Autotokenizer是什么?

可以理解为transformer的一个类,这个类里都是用来处置惩罚token的方法,只不过他们针对的任务有所不同,whispertokenizer是用来专门处置惩罚语音的。

【知识点】transformer“预训练语言模型”

在当前的自然语言处置惩罚研究中, 为了解决语言数据贫乏 (language data poverty) 的问题, 学者们开始探究小规模语言数据资源下自然语言处置惩罚的可行性问题, 因 提出了 “预训练语言模型” (pre-trained language models)。 如许的语言模型利用规 模的文本语料库数据举行 “预训练” (pre-training), 创建 “预训练语言模型” 然 后利用面向特定任务的小规模语言数据集, 根据迁徙学习的原理举行 调” (fine-tuning), 形成 “卑鄙任务————————引自:冯志伟,丁晓梅.人工智能的发展与大语言模型的对齐[J].语言治理学刊,2024,(01):108-12模型”
1.whispertokenizer是怎么支持96种语言的

原先的wenet等其他识别都是基于BPE(例如字建模、phone建模),但是whisper基于的是字节(BBPE),因此对于电脑来说什么语言都是一样的(Tiktoken)

举例来说:“今每天气真好”➡️“今,天,天,气,真,好”(字建模) “今每天气真好”➡️XXXXXX➡️“今每天气真好”(字节建模)

下面写一个步伐来检察tokenizer

  1. from transformers import (
  2.     AutoConfig,
  3.     AutoFeatureExtractor,
  4.     AutoModelForSpeechSeq2Seq,
  5.     AutoProcessor,
  6.     AutoTokenizer,
  7.     HfArgumentParser,
  8.     Seq2SeqTrainer,
  9.     Seq2SeqTrainingArguments,
  10.     set_seed,
  11. )
  12. tokenizer = AutoTokenizer.from_pretrained("/mnt/e/王嘟嘟/wsl/asr_large_model/whisper_model/whisper-tiny")  #输入你模型的位置
  13. print(tokenizer)
复制代码

赤色框框中是whisper的special token即功能类型:翻译、转录等等;黄色框框是语言的token


2.tokenizer的工作次序:载入tokenizer➡️编码具体语言➡️解码具体语言¶

(1)token是怎么编码的? 将任务和笔墨都转换为对应的数字串

  1. from transformers import (
  2.     AutoConfig,
  3.     AutoFeatureExtractor,
  4.     AutoModelForSpeechSeq2Seq,
  5.     AutoProcessor,
  6.     AutoTokenizer,
  7.     HfArgumentParser,
  8.     Seq2SeqTrainer,
  9.     Seq2SeqTrainingArguments,
  10.     set_seed,
  11. )
  12. #准备文本
  13. in_str="今天天气真好"
  14. #加载tokenizer
  15. tokenizer = AutoTokenizer.from_pretrained("/mnt/e/王嘟嘟/wsl/asr_large_model/whisper_model/whisper-tiny")  #输入你模型的位置
  16. tokenizer.set_prefix_tokens(language=data_args.language, task=data_args.task)
  17. #编码
  18. res = tokenizer(input_str).input_ids
复制代码
得出结果:今每天气真好 --> [50258, 50363, 12074, 6135, 42204, 6303, 2131, 50257]

可以看到这句话已经加载成了一系列表现字节的数字

1️⃣检察input_id到底对应的是token内里的什么?

tokenizer内里有一个函数《_convert_id_to_token》: 

  1. def _convert_id_to_token(self, index):  
  2.         ""  
  3.         "
  4.         Converts an index (integer) in a token (str) using the vocab. Whisper's base tokenizer always decodes O  
  5.         OV
  6.         tokens as "", thus we do not use the `unk_token` he  
  7.         re.
  8.          
  9.          """
  10.         return self.decoder.get(i""), "")
复制代码
2️⃣还有许多特殊的token我们怎么理解呢?

我们可以理解为这些特殊的token是用来给模型提示,告诉他们现在处置惩罚的是什么语言(语言 任务 时间戳等)
def set_prefix_tokens(self, language: str = None, task: str = None, predict_timestamps: bool = None)
(2)tokenizer的解码

  1. def decode(
  2.         self,
  3.         token_ids,
  4.         skip_special_tokens: bool = False,  #是否保留special token
  5.         clean_up_tokenization_spaces: bool = None,
  6.         output_offsets: bool = False,
  7.         time_precision: float = 0.02,
  8.         decode_with_timestamps: bool = False,
  9.         normalize: bool = False,
  10.         basic_normalize: bool = False,
  11.         remove_diacritics: bool = False,
  12.         **kwargs,
  13.     ) -> str:
  14. res_jie= tokenizer.decode(res)
复制代码

四、手敲代码

我们自己手敲代码的过程没有官方给出的训练文档这么复杂,我们的代码可以简化为以下几个环节:加载数据➡️初始化训练器➡️训练➡️测试

(一)加载数据

加载数据又可以拆分为加载数据库、加载特性提取器、加载tokenizer、加载processor,这里我们没有像train.py那样拆开步骤,而是继续了一个巨大的IterableDataset来讲这几样东西一次性准备好。

  1. import transformers
  2. from torch.utils.data import IterableDataset
  3. from tqdm import tqdm
  4. from transformers import WhisperFeatureExtractor
  5. import torchaudio
  6. import torch
  7. class IterWhisperDataset(IterableDataset):
  8.     def __init__(wave_scp,text,whisper_feature_extractor):
  9.         pass
  10.     def __len___(self):
  11.         pass
  12.     def __iter__(self): #传入模型以遍历
  13.         pass
复制代码
这一类分为三大方法,第一大方法:__init__初始化板块,负责担当需要加载的对象做为参数,而且举行开端加工,将对象变为whisper便于识别的字典模式。第二大方法__len__获取长度板块。第三大板块__iter__遍历参数出场利用板块,这一板块重要运用初始化完成的数据和函数,对所有的对象举行遍历并运用函数举行处置惩罚。

(二)初始化训练器

1.__init__方法:在这一方法中我们重要对担当的参数和函数举行初始化。

(1)对传入进来的wav.scp和对应的text文本举行处置惩罚,让其变为whisper更轻易担当的字典情势.{id:[wav_path,text],id:[wav_path,text],id:[wav_path,text]}

【Tips】部分小伙伴只有wav格式的音频而没有wav.scp,这里分享一下师弟写的步伐,帮助大家转换为wav.scp格式

  1. import os
  2. def save_data(data,filename):
  3.     #保存文件函数
  4.     with open(os.path.join(save_path,filename),'w',encoding='utf-8') as f:
  5.         for i in data:
  6.             f.writelines(i[0]+' '+i[1]+'\n')
  7.         print("%s Saving succeeded!" % filename)
  8.    
  9. def get_wav_scp():
  10.     #用于生成wav.scp
  11.     wav_scp=[]
  12.     #遍历音频
  13.     for file_name in os.listdir(data_path):
  14.         #判断后缀是否为wav
  15.         if file_name[-3:] == 'wav':
  16.             wav_scp.append([file_name.split(".")[0],os.path.join(data_path,file_name)])
  17.         
  18.     save_data(wav_scp,"wav.scp")
  19. if __name__ == "__main__":
  20.     #数据文件存放路径
  21.     data_path ="/mnt/e/王嘟嘟/wsl/asr_large_model/train_model/diyproject/data/audio"
  22.     #保存的路径
  23.     save_path = "/mnt/e/王嘟嘟/wsl/asr_large_model/train_model/diyproject/data"
  24.     #生成wav.scp
  25.     get_wav_scp()
复制代码
【提取字典】+【传入函数】:传入的函数都是需要在前面加上self.举行绑定,以保证可以在每个方法调用到

  1. class IterWhisperDataset(IterableDataset):
  2.     def __init__(wave_scp,text,whisper_feature_extractor):
  3.         #处理为字典
  4.         self.data_list={}
  5.         #拿到wave的id和路径
  6.         with open(wave_scp,"r",encoding="utf-8") as file:
  7.             for line in tqdm(file.readlines()):
  8.                 line=line.strip()
  9.                 idx=line.split(" ")[0]
  10.                 wav_path=" ".join(line.split(" ")[1:])
  11.                 self.data_list[idx]=[]
  12.                 self.data_list[idx].append(wav_path)
  13.                 pass
  14.             pass
  15.         #拿到text
  16.         with open(text,"r",encoding="utf-8") as file:
  17.             for line in tqdm(file.readlines()):
  18.                 line=line.strip()
  19.                 idx=line.split(" ")[0]
  20.                 text=" ".join(line.split(" ")[1:])
  21.                 self.data_list[idx].append(text)
  22.                 pass
  23.             pass
  24.         self.whisper_feature_extractor=whisper_feature_extractor  #传入特征提取器
  25.         print("文本全部个数为:",len(self.data_list))
  26.         pass
复制代码
2.__len__模块

  1. def __len__(self):
  2.         return len(self.data_list)
复制代码
3.__iter___模块:在这一方法中,我们对函数中的对象举行遍历,然后开始举行操纵,比如:预处置惩罚

  1. def __iter__(self):
  2.         #遍历我们的所有数据
  3.         for idx in self.data_list:
  4.             #音频的路径
  5.             wav_path = self.data_list[idx][0]
  6.             #音频的文本
  7.             text = self.data_list[idx][1]
  8.             
  9.             example = {}
  10.             #提取特征
  11.             data_audio = torchaudio.load(wav_path)
  12.             example['input_features'] = self.whisper_feature_extractor(data_audio[0].numpy(),sampling_rate=16000).input_features[0]
  13.             #token
  14.             example['labels'] = self.whisper_tokenizer(text).input_ids[1:]
  15.             # res_jie=self.whisper_tokenizer.decode(example['labels'])
  16.             # print("----解码---->",res_jie)
  17.             yield example
  18.          
  19.          
  20.             pass
  21.         pass
  22.     pass
复制代码
【知识点】特性提取器的工作原理:获取采样率(torchaudio.load)➡️一维数据阵➡️调用特性提取

  1. whisper_feature_extractor=WhisperFeatureExtractor()
复制代码
在这里传入的特性提取器是泉源于我们之前解说过的Auto中获取的,是一个已经定义好了大类,直接传入实例我们就可以利用了。

获取采样率

  1. data_audio=torchaudio.load("/mnt/e/王嘟嘟/wsl/asr_large_model/train_model/diyproject/data/audio/BAC009S0150W0009.wav")#load(path)
复制代码
表现: 
  1. (whisper) root@小徐的板子:/mnt/e/王嘟嘟/wsl/asr_large_model/train_model/diyproject# python3 audio_train_whisper_small.py
  2. (tensor([[-2.7466e-04, -4.2725e-04, -3.6621e-04,  ...,  3.0518e-05,
  3.           3.0518e-05,  2.1362e-04]]), 16000)
复制代码
提取特性

  1. print(whisper_feature_extractor(data_audio[0].numpy(),sampling_rate=16000))
复制代码
表现:
(whisper) root@小徐的板子:/mnt/e/王嘟嘟/wsl/asr_large_model/train_model/diyproject# python3 audio_train_whisper_small.py
{'input_features': [array([[-0.11046922,  0.2726825 ,  0.1670869 , ..., -1.0561461 ,
        -1.0561461 , -1.0561461 ],
       [-0.17119169,  0.14174527, -0.10604203, ..., -1.0561461 ,
        -1.0561461 , -1.0561461 ],
       [-0.37666786,  0.0177812 , -0.22535133, ..., -1.0561461 ,
        -1.0561461 , -1.0561461 ],
       ...,
       [-0.7373694 , -0.7566987 , -0.7817011 , ..., -1.0561461 ,
        -1.0561461 , -1.0561461 ],
       [-0.7835511 , -0.8563596 , -0.7267041 , ..., -1.0561461 ,
        -1.0561461 , -1.0561461 ],
       [-0.87479997, -1.0272436 , -0.9189577 , ..., -1.0561461 ,
        -1.0561461 , -1.0561461 ]], dtype=float32)]}

 (三)训练模型

  1. import transformersfrom torch.utils.data import IterableDatasetfrom tqdm import tqdmfrom transformers import (    WhisperFeatureExtractor,    WhisperTokenizer,    AutoProcessor,    Seq2SeqTrainer,    Seq2SeqTrainingArguments,    WhisperForConditionalGeneration)import torchaudioimport torchfrom dataclasses import dataclass, fieldfrom typing import Any, Dict, List, Optional, Unionimport torch@dataclassclass DataCollatorSpeechSeq2SeqWithPadding:    """    Data collator that will dynamically pad the inputs received.    Args:        processor ([`WhisperProcessor`])            The processor used for processing the data.        decoder_start_token_id (`int`)            The begin-of-sentence of the decoder.        forward_attention_mask (`bool`)            Whether to return attention_mask.    """    processor: Any    decoder_start_token_id: int    forward_attention_mask: bool    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:        # split inputs and labels since they have to be of different lengths and need        # different padding methods        model_input_name = self.processor.model_input_names[0]        input_features = [{model_input_name: feature[model_input_name]} for feature in features]        label_features = [{"input_ids": feature["labels"]} for feature in features]        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")        if self.forward_attention_mask:            batch["attention_mask"] = torch.LongTensor([feature["attention_mask"] for feature in features])        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")        # replace padding with -100 to ignore loss correctly        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)        # if bos token is appended in previous tokenization step,        # cut bos token here as it's append later anyways        if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():            labels = labels[:, 1:]        batch["labels"] = labels        return batch#准备数据class IterWhisperDataset(IterableDataset):    def __init__(self,wav_scp,text,whisper_feature_extractor,whisper_tokenizer):             #处置惩罚为字典        self.data_list={}             #音频路径        with open(wav_scp,"r",encoding="utf-8") as file:            for line in tqdm(file.readlines()):                line = line.strip()                idx = line.split(" ")[0]                wav_path = " ".join(line.split(" ")[1:])                self.data_list[idx] = []                self.data_list[idx].append(wav_path)                pass            pass        pass               #音频文本         with open(text,"r",encoding="utf-8") as file:            for line in tqdm(file.readlines()):                line = line.strip()                idx = line.split(" ")[0]                text = " ".join(line.split(" ")[1:])                self.data_list[idx].append(text)                pass            pass        self.whisper_feature_extractor = whisper_feature_extractor        self.whisper_tokenizer = whisper_tokenizer        print("文本个数为:",len(self.data_list))        pass       #文本个数    def __len__(self):
  2.         return len(self.data_list)    #传入模型遍历       def __iter__(self):
  3.         #遍历我们的所有数据
  4.         for idx in self.data_list:
  5.             #音频的路径
  6.             wav_path = self.data_list[idx][0]
  7.             #音频的文本
  8.             text = self.data_list[idx][1]
  9.             
  10.             example = {}
  11.             #提取特征
  12.             data_audio = torchaudio.load(wav_path)
  13.             example['input_features'] = self.whisper_feature_extractor(data_audio[0].numpy(),sampling_rate=16000).input_features[0]
  14.             #token
  15.             example['labels'] = self.whisper_tokenizer(text).input_ids[1:]
  16.             # res_jie=self.whisper_tokenizer.decode(example['labels'])
  17.             # print("----解码---->",res_jie)
  18.             yield example
  19.          
  20.          
  21.             pass
  22.         pass
  23.     pass#基本路径  whisper_model="/mnt/e/王嘟嘟/wsl/asr_large_model/whisper_model/whisper-tiny"train_wav_scp="/mnt/e/王嘟嘟/wsl/asr_large_model/train_model/diyproject/data/wav.scp"train_text="/mnt/e/王嘟嘟/wsl/asr_large_model/train_model/diyproject/data/text.txt"#特性提取whisper_feature_extractor=WhisperFeatureExtractor.from_pretrained(whisper_model)#tokenwhisper_tokenizer=WhisperTokenizer.from_pretrained(whisper_model)whisper_tokenizer.set_prefix_tokens(language = "chinese", task = "transcribe")#处置惩罚数据完成train_data_list = IterWhisperDataset(    train_wav_scp,    train_text,    whisper_feature_extractor,    whisper_tokenizer)#加载资源model= WhisperForConditionalGeneration.from_pretrained(whisper_model)processor = AutoProcessor.from_pretrained(whisper_model)#初始化训练器data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor,decoder_start_token_id=model.config.decoder_start_token_id,forward_attention_mask=False,)# def compute_metrics(pred):#         pred_ids = pred.predictions#         pred.label_ids[pred.label_ids == -100] = whisper_tokenizer.pad_token_id#         pred_str = whisper_tokenizer.batch_decode(pred_ids, skip_special_tokens=True)#         # we do not want to group tokens when computing the metrics#         label_str = whisper_tokenizer.batch_decode(pred.label_ids, skip_special_tokens=True)#         wer = metric.compute(predictions=pred_str, references=label_str)#         return {"wer": wer}training_args = Seq2SeqTrainingArguments(    output_dir="model/v1",  # change to a repo name of your choice    per_device_train_batch_size=1,    gradient_accumulation_steps=1,  # increase by 2x for every 2x decrease in batch size    learning_rate=0.001,    warmup_steps=50,    num_train_epochs=1,    evaluation_strategy="epoch",    fp16=False,    per_device_eval_batch_size=2,    generation_max_length=128,    logging_steps=4, #迭代多少轮打一次日志    remove_unused_columns=False,  # required as the PeftModel forward doesn't have the signature of the wrapped model's forward    label_names=["labels"],  # same reason as above)trainer = Seq2SeqTrainer(    model=model,    args=training_args,    train_dataset=train_data_list,    eval_dataset=train_data_list,    tokenizer=whisper_feature_extractor,    data_collator=data_collator,    #compute_metrics=compute_metrics if training_args.predict_with_generate else None,    )#训练train_result = trainer.train()trainer.save_model()  #保存训练#测试
复制代码
(四)模型的测评

  1. import numpy as npfrom torch.utils.data import IterableDatasetfrom tqdm import tqdm#评估模型所需要的函数from transformers import (    WhisperFeatureExtractor,    WhisperTokenizer,    WhisperForConditionalGeneration,)from dataclasses import dataclass,fieldfrom typing import Any,Dict,List,Optional,Unionimport torchimport torchaudioclass IterWhisperDataset(IterableDataset):    def __init__(self,wav_scp,text,whisper_feature_extractor,whisper_tokenizer):             #处置惩罚为字典        self.data_list={}             #音频路径        with open(wav_scp,"r",encoding="utf-8") as file:            for line in tqdm(file.readlines()):                line = line.strip()                idx = line.split(" ")[0]                wav_path = " ".join(line.split(" ")[1:])                self.data_list[idx] = []                self.data_list[idx].append(wav_path)                pass            pass        pass               #音频文本         with open(text,"r",encoding="utf-8") as file:            for line in tqdm(file.readlines()):                line = line.strip()                idx = line.split(" ")[0]                text = " ".join(line.split(" ")[1:])                self.data_list[idx].append(text)                pass            pass        self.whisper_feature_extractor = whisper_feature_extractor        self.whisper_tokenizer = whisper_tokenizer        print("文本个数为:",len(self.data_list))        pass       #文本个数    def __len__(self):
  2.         return len(self.data_list)    #传入模型遍历       def __iter__(self):        #遍历我们的所有数据        for idx in self.data_list:            #音频的路径            wav_path = self.data_list[idx][0]            #音频的文本            text = self.data_list[idx][1]                        example = {}            example['idx'] = idx #给idx赋值,传递到后面的解码过程,看是否对应            #提取特性            data_audio = torchaudio.load(wav_path)            example['input_features'] = self.whisper_feature_extractor(data_audio[0].numpy(),sampling_rate=16000).input_features[0]            #token            example['labels'] = self.whisper_tokenizer(text).input_ids            #res_jie = self.whisper_tokenizer.decode(example["labels"],skip_special_tokens = False)                        #print('解码----->',res_jie)            #print(example["labels"])            #print(self.whisper_tokenizer.batch_decode(example["labels"]))            yield example                              pass        pass    passwhisper_model='model/v1' #传入自己训练好的模型举行评估out_file = open("./result",'w',encoding='utf-8') #设置一个输出路径train_wav_scp = "/mnt/e/王嘟嘟/wsl/asr_large_model/train_model/diyproject/data/wav.scp"train_text = "/mnt/e/王嘟嘟/wsl/asr_large_model/train_model/diyproject/data/text.txt"#特性提取whisper_feature_extractor = WhisperFeatureExtractor.from_pretrained(whisper_model)#token#自己训练好的没有token,因此需要借助官方模型whisper_tokenizer = WhisperTokenizer.from_pretrained("/mnt/e/王嘟嘟/wsl/asr_large_model/whisper_model/whisper-tiny",language='chinese',task='transcribe')#处置惩罚数据train_data_list = IterWhisperDataset(    train_wav_scp,    train_text,    whisper_feature_extractor,    whisper_tokenizer)#引入预训练模型model = WhisperForConditionalGeneration.from_pretrained(whisper_model)#eval_dataloader = DataLoader(common_voice["test"], batch_size=8, collate_fn=data_collator)model.eval() #将模型设置为评估模式for step, batch in enumerate(tqdm(train_data_list)):    #print(step,batch)    with torch.cuda.amp.autocast():        with torch.no_grad():            generated_tokens = (                model.generate(                    #多次迭代模型特性并传入generated_tokens                    input_features=torch.from_numpy(batch["input_features"][np.newaxis,:,:]),                    #解码模型中训练集的id                    decoder_input_ids=torch.from_numpy(np.array([batch["labels"][:4]])),                    #设置长度                    max_new_tokens=255,                )                .cpu()                .numpy()            )            labels = batch["labels"]            labels = np.where(labels != -100, labels, whisper_tokenizer.pad_token_id)            #解码文本            decoded_preds = whisper_tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)            decoded_labels = whisper_tokenizer.batch_decode(labels, skip_special_tokens=True)            out_file.write(batch['idx']+' '+decoded_preds[0]+'\n')            pass        pass    del generated_tokens, labels, batch
复制代码


免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。
回复

使用道具 举报

0 个回复

倒序浏览

快速回复

您需要登录后才可以回帖 登录 or 立即注册

本版积分规则

罪恶克星

金牌会员
这个人很懒什么都没写!
快速回复 返回顶部 返回列表