ToB企服应用市场:ToB评测及商务社交产业平台

标题: llama-factory源码详解——以DPO为例 [打印本页]

作者: 小小小幸运    时间: 2024-8-31 02:06
标题: llama-factory源码详解——以DPO为例
本文记载了我在学习 llama-factory过程中对代码运行过程的梳理
代码入口——src/train.py

  1. from llamafactory.train.tuner import run_exp
  2. def main():
  3.     run_exp()
  4. def _mp_fn(index):
  5.     # For xla_spawn (TPUs)
  6.     run_exp()
  7. if __name__ == "__main__":
  8.     main()
复制代码
run_exp() 

该函数位于src/llamafactory/train/tuner.py
  1. def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: List["TrainerCallback"] = []) -> None:
  2.     callbacks.append(LogCallback())
  3.     model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args)
  4.     if finetuning_args.stage == "pt":
  5.         run_pt(model_args, data_args, training_args, finetuning_args, callbacks)
  6.     elif finetuning_args.stage == "sft":
  7.         run_sft(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
  8.     elif finetuning_args.stage == "rm":
  9.         run_rm(model_args, data_args, training_args, finetuning_args, callbacks)
  10.     elif finetuning_args.stage == "ppo":
  11.         run_ppo(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
  12.     elif finetuning_args.stage == "dpo":
  13.         run_dpo(model_args, data_args, training_args, finetuning_args, callbacks)
  14.     elif finetuning_args.stage == "kto":
  15.         run_kto(model_args, data_args, training_args, finetuning_args, callbacks)
  16.     else:
  17.         raise ValueError("Unknown task: {}.".format(finetuning_args.stage))
复制代码
这段代码首先获取训练参数,也就是 get_train_args函数,位置在src/llamafactory/hparams/parser.py,hparams这个包里还有llama-factory的参数体系。
这个函数是用来剖析和验证训练参数的。它接受一个可选的字典参数args,如果没有提供,则利用默认值。函数的重要使命是:
最后,函数返回五个参数对象,分别代表模子、数据、训练、微调和生成的相关参数。这些参数将被用来配置和实行训练过程。
run_dpo()

以dpo为例,run_dpo()函数在src/llamafactory/train/dpo/workflow.py。
导入模块

  1. from typing import TYPE_CHECKING, List, Optional
  2. from ...data import PairwiseDataCollatorWithPadding, get_dataset
  3. from ...extras.constants import IGNORE_INDEX
  4. from ...extras.ploting import plot_loss
  5. from ...hparams import ModelArguments
  6. from ...model import load_model, load_tokenizer
  7. from ..trainer_utils import create_modelcard_and_push, create_ref_model
  8. from .trainer import CustomDPOTrainer
复制代码
这些导入语句引入了所需的模块和函数:

范例检查

  1. if TYPE_CHECKING:
  2.     from transformers import Seq2SeqTrainingArguments, TrainerCallback
  3.     from ...hparams import DataArguments, FinetuningArguments
复制代码
这些导入语句仅在范例检查时利用,用于定义范例提示。
run_dpo 函数

  1. def run_dpo(
  2.     model_args: "ModelArguments",
  3.     data_args: "DataArguments",
  4.     training_args: "Seq2SeqTrainingArguments",
  5.     finetuning_args: "FinetuningArguments",
  6.     callbacks: Optional[List["TrainerCallback"]] = None,
  7. ):
复制代码
这个函数 run_dpo 接受五个参数:

加载分词器和数据集

  1.     tokenizer_module = load_tokenizer(model_args)
  2.     tokenizer = tokenizer_module["tokenizer"]
  3.     dataset_module = get_dataset(model_args, data_args, training_args, stage="rm", **tokenizer_module)
  4.     model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
复制代码
这几行代码实行以下利用:
创建数据整理器

  1.     data_collator = PairwiseDataCollatorWithPadding(
  2.         tokenizer=tokenizer,
  3.         pad_to_multiple_of=8,
  4.         label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
  5.     )
复制代码
这段代码创建了一个 PairwiseDataCollatorWithPadding 对象,用于数据整理。它利用分词器,并设置填充到8的倍数。如果 data_args.ignore_pad_token_for_loss 为真,则利用 IGNORE_INDEX 作为标签填充标记,否则利用分词器的填充标记。
创建参考模子

  1.     if finetuning_args.use_ref_model:
  2.         if finetuning_args.ref_model is None and (not training_args.do_train):  # use the model itself
  3.             ref_model = model
  4.         else:
  5.             ref_model = create_ref_model(model_args, finetuning_args)
  6.     else:
  7.         ref_model = None
复制代码
这段代码创建一个参考模子:

更新训练参数

  1.     training_args.remove_unused_columns = False  # important for pairwise dataset
复制代码
这行代码更新训练参数,设置 remove_unused_columns 为 False,这对于成对数据集非常重要。
初始化训练器

  1.     trainer = CustomDPOTrainer(
  2.         model=model,
  3.         ref_model=ref_model,
  4.         args=training_args,
  5.         finetuning_args=finetuning_args,
  6.         data_collator=data_collator,
  7.         callbacks=callbacks,
  8.         **dataset_module,
  9.         **tokenizer_module,
  10.     )
复制代码
这段代码初始化自定义的训练器 CustomDPOTrainer,传入模子、参考模子、训练参数、微调参数、数据整理器、回调函数以及数据集和分词器模块中的其他参数。
训练模子

  1.     if training_args.do_train:
  2.         train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
  3.         trainer.save_model()
  4.         trainer.log_metrics("train", train_result.metrics)
  5.         trainer.save_metrics("train", train_result.metrics)
  6.         trainer.save_state()
  7.         if trainer.is_world_process_zero() and finetuning_args.plot_loss:
  8.             plot_loss(training_args.output_dir, keys=["loss", "eval_loss", "rewards/accuracies"])
复制代码
如果 do_train 为真,则实行以下利用:
评估模子

  1.     if training_args.do_eval:
  2.         metrics = trainer.evaluate(metric_key_prefix="eval")
  3.         if id(model) == id(ref_model):  # unable to compute rewards if reference model is the model itself
  4.             remove_keys = [key for key in metrics.keys() if "rewards" in key]
  5.             for key in remove_keys:
  6.                 metrics.pop(key)
  7.         trainer.log_metrics("eval", metrics)
  8.         trainer.save_metrics("eval", metrics)
复制代码
如果 do_eval 为真,则实行以下利用:
创建模子卡并推送

  1.     create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)
复制代码
最后,调用 create_modelcard_and_push 函数,创建模子卡并推送到指定位置。
总结

这个代码片段定义了一个 run_dpo 函数,用于加载和准备模子、数据集和相关的配置参数,初始化自定义训练器 CustomDPOTrainer,并根据必要进行训练和评估。它还包括创建模子卡并推送的步调。
CustomDPOTrainer类

  1. class CustomDPOTrainer(DPOTrainer):
复制代码
这个类 CustomDPOTrainer 继承自 DPOTrainer,它是一个自定义的训练器类。
  1.     def __init__(
  2.         self,
  3.         model: Union["PreTrainedModel", torch.nn.Module],
  4.         ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]],
  5.         finetuning_args: "FinetuningArguments",
  6.         processor: Optional["ProcessorMixin"],
  7.         disable_dropout: bool = True,
  8.         **kwargs,
  9.     ):
复制代码
初始化方法,接受以下参数:

  1.         if disable_dropout:
  2.             disable_dropout_in_model(model)
  3.             if ref_model is not None:
  4.                 disable_dropout_in_model(ref_model)
复制代码
如果 disable_dropout 为真,则禁用模子和参考模子中的 dropout。
  1.         self.finetuning_args = finetuning_args
  2.         self.f_divergence_type = "reverse_kl"
  3.         self.reference_free = False
  4.         self.use_dpo_data_collator = True  # hack to avoid warning
  5.         self.generate_during_eval = False  # disable at evaluation
  6.         self.label_pad_token_id = IGNORE_INDEX
  7.         self.padding_value = 0
  8.         self.is_encoder_decoder = model.config.is_encoder_decoder
  9.         self.precompute_ref_log_probs = False
  10.         self._precomputed_train_ref_log_probs = False
  11.         self._precomputed_eval_ref_log_probs = False
  12.         self._peft_has_been_casted_to_bf16 = False
复制代码
初始化一些实例变量,包括微调参数、散度范例、是否利用参考模子、数据整理器、评估期间是否生成、标签填充标记、填充值、是否是编码器-解码器模子等。
  1.         self.ref_model = ref_model
  2.         self._stored_metrics = defaultdict(lambda: defaultdict(list))
复制代码
设置参考模子,并初始化一个存储指标的字典。
  1.         # dpo hyperparams
  2.         self.beta = finetuning_args.pref_beta
  3.         self.loss_type = finetuning_args.pref_loss
  4.         self.ftx_gamma = finetuning_args.pref_ftx
  5.         self.label_smoothing = finetuning_args.dpo_label_smoothing
  6.         self.simpo_gamma = finetuning_args.simpo_gamma
复制代码
初始化一些 DPO(偏好优化)超参数。
  1.         Trainer.__init__(self, model=model, **kwargs)
  2.         if not hasattr(self, "accelerator"):
  3.             raise AttributeError("Please update `transformers`.")
复制代码
调用父类 Trainer 的初始化方法,并检查是否存在 accelerator 属性。
  1.         warnings.simplefilter("ignore")  # remove gc warnings on ref model
复制代码
忽略一些警告信息。
  1.         if ref_model is not None:
  2.             if self.is_deepspeed_enabled:
  3.                 if not (
  4.                     getattr(ref_model, "is_loaded_in_8bit", False) or getattr(ref_model, "is_loaded_in_4bit", False)
  5.                 ):  # quantized models are already set on the correct device
  6.                     self.ref_model = self._prepare_deepspeed(self.ref_model)
  7.             else:
  8.                 self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
  9.                 self.ref_model.eval()
复制代码
如果参考模子不为空,且启用了 DeepSpeed,则准备 DeepSpeed 模子;否则,利用加速器准备参考模子,并将其设置为评估模式。
  1.         if processor is not None:
  2.             self.add_callback(SaveProcessorCallback(processor))
复制代码
如果处置惩罚器不为空,则添加 SaveProcessorCallback 回调。
  1.         if finetuning_args.pissa_convert:
  2.             self.callback_handler.add_callback(PissaConvertCallback)
复制代码
如果启用了 pissa_convert,则添加 PissaConvertCallback 回调。
  1.         if finetuning_args.use_badam:
  2.             from badam import BAdamCallback, clip_grad_norm_old_version
  3.             self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
  4.             self.add_callback(BAdamCallback)
复制代码
如果启用了 use_badam,则导入 BAdamCallback 和 clip_grad_norm_old_version,并添加 BAdamCallback 回调。
  1.     def create_optimizer(self) -> "torch.optim.Optimizer":
  2.         if self.optimizer is None:
  3.             self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args)
  4.         return super().create_optimizer()
复制代码
创建优化器,如果优化器为空,则调用 create_custom_optimzer 创建自定义优化器。
  1.     def create_scheduler(
  2.         self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None
  3.     ) -> "torch.optim.lr_scheduler.LRScheduler":
  4.         create_custom_scheduler(self.args, num_training_steps, optimizer)
  5.         return super().create_scheduler(num_training_steps, optimizer)
复制代码
创建学习率调治器,调用 create_custom_scheduler 创建自定义调治器。
  1.     def odds_ratio_loss(self, chosen_logps: "torch.Tensor", rejected_logps: "torch.Tensor") -> "torch.Tensor":
  2.         r"""
  3.         Computes ORPO's odds ratio (OR) loss for batched log probabilities of the policy model.
  4.         """
  5.         log_odds = (chosen_logps - rejected_logps) - (
  6.             torch.log1p(-torch.exp(chosen_logps)) - torch.log1p(-torch.exp(rejected_logps))
  7.         )
  8.         sft_loss = -chosen_logps
  9.         odds_ratio_loss = -F.logsigmoid(log_odds)
  10.         orpo_loss = sft_loss + self.beta * odds_ratio_loss
  11.         return orpo_loss
复制代码
这是一个用于计算策略模子的批量对数概率的Odds Ratio Policy Optimization(ORPO)损失的PyTorch函数。log_odds 是选择和拒绝的对数概率之差,sft_loss 是负的选择对数概率,odds_ratio_loss 是负的 logsigmoid,终极的 orpo_loss 是 sft_loss 和 odds_ratio_loss 的加权和。ORPO损失旨在鼓励策略选择具有更高期望回报的动作,同时惩罚它选择不太可能是最优的动作。通过均衡这两个目标,ORPO旨在改善强化学习署理在复杂情况中的性能。
  1.     def simpo_loss(self, chosen_logps: "torch.Tensor", rejected_logps: "torch.Tensor") -> "torch.Tensor":
  2.         r"""
  3.         Computes SimPO loss for batched log probabilities of the policy model.
  4.         """
  5.         pi_logratios = chosen_logps - rejected_logps
  6.         gamma_logratios = self.simpo_gamma / self.beta
  7.         logits = pi_logratios - gamma_logratios
  8.         simpo_loss = -F.logsigmoid(self.beta * logits)
  9.         return simpo_loss
复制代码
这个函数用于计算批量对数概率的策略模子的SimPO(简朴政策优化)损失。pi_logratios 是选择和拒绝的对数概率之差,gamma_logratios 是 simpo_gamma 和 beta 的比值,logits 是 pi_logratios 和 gamma_logratios 之差,终极的 simpo_loss 是负的 logsigmoid。
  1.     def compute_preference_loss(
  2.         self,
  3.         policy_chosen_logps: "torch.Tensor",
  4.         policy_rejected_logps: "torch.Tensor",
  5.         reference_chosen_logps: Optional["torch.Tensor"],
  6.         reference_rejected_logps: Optional["torch.Tensor"],
  7.     ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:
  8.         r"""
  9.         Computes loss for preference learning.
  10.         """
  11.         if not self.finetuning_args.use_ref_model:
  12.             if self.loss_type == "orpo":
  13.                 losses = self.odds_ratio_loss(policy_chosen_logps, policy_rejected_logps)
  14.             elif self.loss_type == "simpo":
  15.                 losses = self.simpo_loss(policy_chosen_logps, policy_rejected_logps)
  16.             else:
  17.                 raise NotImplementedError("Unknown loss type: {}.".format(self.loss_type))
  18.             chosen_rewards = self.beta * policy_chosen_logps.to(self.accelerator.device).detach()
  19.             rejected_rewards = self.beta * policy_rejected_logps.to(self.accelerator.device).detach()
  20.         else:
  21.             losses, chosen_rewards, rejected_rewards = self.dpo_loss(
  22.                 policy_chosen_logps, policy_rejected_logps, reference_chosen_logps, reference_rejected_logps
  23.             )
  24.         return losses, chosen_rewards, rejected_rewards
复制代码
计算偏好学习的损失。如果不利用参考模子,根据参数中的损失范例计算 ORPO 或 SimPO 损失,并计算选择和拒绝的嘉奖。如果利用参考模子,调用 dpo_loss 计算损失和嘉奖。对于是否利用参考模子,也就是use_ref_model参数,可以到src/llamafactory/hparams/finetuning_args.py中检察它的默认值:
  1. self.use_ref_model = self.stage == "dpo" and self.pref_loss not in ["orpo", "simpo"]
复制代码
  1. pref_loss: Literal["sigmoid", "hinge", "ipo", "kto_pair", "orpo", "simpo"] = field(
  2.         default="sigmoid",
  3.         metadata={"help": "The type of DPO loss to use."},
  4.     )
复制代码
 可以看到pref_loss的默认值是sigmod,也就是use_ref_model在dpo阶段默认是True。
  1.     def concatenated_forward(
  2.         self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
  3.     ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
  4.         r"""
  5.         Computes the sum log probabilities of the labels under given logits if loss_type is not IPO, ORPO or SimPO.
  6.         Otherwise the average log probabilities.
  7.         """
  8.         if self.finetuning_args.use_ref_model:
  9.             batch = {k: v.detach().clone() for k, v in batch.items()}  # avoid error
  10.         all_logits: "torch.Tensor" = model(**batch, return_dict=True, use_cache=False).logits.to(torch.float32)
  11.         all_logps, valid_length = get_batch_logps(logits=all_logits, labels=batch["labels"])
  12.         if self.loss_type in ["ipo", "orpo", "simpo"]:
  13.             all_logps = all_logps / valid_length
  14.         batch_size = batch["input_ids"].size(0) // 2
  15.         chosen_logps, rejected_logps = all_logps.split(batch_size, dim=0)
  16.         chosen_logits, rejected_logits = all_logits.split(batch_size, dim=0)
  17.         chosen_length, _ = valid_length.split(batch_size, dim=0)
  18.         return chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_logps / chosen_length
复制代码
计算给定 logits 下标签的对数概率之和(如果损失范例不是 IPO、ORPO 或 SimPO),否则计算平均对数概率。返回选择和拒绝的对数概率、logits 和选择的对数概率的平均值。
  1. def compute_reference_log_probs(
  2.         self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
  3.     ) -> Tuple[Optional["torch.Tensor"], Optional["torch.Tensor"]]:
  4.         r"""
  5.         Computes log probabilities of the reference model.
  6.         """
  7.         if not self.finetuning_args.use_ref_model:
  8.             return None, None
  9.         if self.ref_model is None:
  10.             ref_model = model
  11.             ref_context = self.accelerator.unwrap_model(model).disable_adapter()
  12.         else:
  13.             ref_model = self.ref_model
  14.             ref_context = nullcontext()
  15.         with torch.no_grad(), ref_context:
  16.             reference_chosen_logps, reference_rejected_logps, *_ = self.concatenated_forward(ref_model, batch)
  17.         return reference_chosen_logps, reference_rejected_logps
复制代码
计算参考模子的对数概率。如果不利用参考模子,返回 None。否则,计算参考模子的选择和拒绝对数概率。
  1. def get_batch_loss_metrics(
  2.         self,
  3.         model: "PreTrainedModel",
  4.         batch: Dict[str, "torch.Tensor"],
  5.         train_eval: Literal["train", "eval"] = "train",
  6.     ) -> Tuple["torch.Tensor", Dict[str, "torch.Tensor"]]:
  7.         r"""
  8.         Computes the DPO loss and other metrics for the given batch of inputs for train or test.
  9.         """
  10.         metrics = {}
  11.         (
  12.             policy_chosen_logps,
  13.             policy_rejected_logps,
  14.             policy_chosen_logits,
  15.             policy_rejected_logits,
  16.             policy_chosen_logps_avg,
  17.         ) = self.concatenated_forward(model, batch)
  18.         reference_chosen_logps, reference_rejected_logps = self.compute_reference_log_probs(model, batch)
  19.         losses, chosen_rewards, rejected_rewards = self.compute_preference_loss(
  20.             policy_chosen_logps,
  21.             policy_rejected_logps,
  22.             reference_chosen_logps,
  23.             reference_rejected_logps,
  24.         )
  25.         sft_loss = -policy_chosen_logps_avg
  26.         if self.ftx_gamma > 1e-6:
  27.             losses += self.ftx_gamma * sft_loss
  28.         reward_accuracies = (chosen_rewards > rejected_rewards).float()
  29.         prefix = "eval_" if train_eval == "eval" else ""
  30.         metrics["{}rewards/chosen".format(prefix)] = chosen_rewards.mean().cpu()
  31.         metrics["{}rewards/rejected".format(prefix)] = rejected_rewards.mean().cpu()
  32.         metrics["{}rewards/accuracies".format(prefix)] = reward_accuracies.mean().cpu()
  33.         metrics["{}rewards/margins".format(prefix)] = (chosen_rewards - rejected_rewards).mean().cpu()
  34.         metrics["{}logps/rejected".format(prefix)] = policy_rejected_logps.detach().mean().cpu()
  35.         metrics["{}logps/chosen".format(prefix)] = policy_chosen_logps.detach().mean().cpu()
  36.         metrics["{}logits/rejected".format(prefix)] = policy_rejected_logits.detach().mean().cpu()
  37.         metrics["{}logits/chosen".format(prefix)] = policy_chosen_logits.detach().mean().cpu()
  38.         if self.loss_type == "orpo":
  39.             metrics["{}sft_loss".format(prefix)] = sft_loss.detach().mean().cpu()
  40.             metrics["{}odds_ratio_loss".format(prefix)] = ((losses - sft_loss) / self.beta).detach().mean().cpu()
  41.         return losses.mean(), metrics
复制代码
这个方法计算给定输入批次的 DPO 损失和其他指标,用于训练或测试。具体步调如下:
总结: CustomDPOTrainer 类扩展了 DPOTrainer,添加了自定义的初始化方法、优化器和调治器创建方法,以及计算偏好学习损失和其他指标的方法。它还包括计算 ORPO 和 SimPO 损失的方法,以及计算参考模子对数概率的方法。这个类重要用于在训练过程中处置惩罚偏好优化相关的使命。

免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。




欢迎光临 ToB企服应用市场:ToB评测及商务社交产业平台 (https://dis.qidao123.com/) Powered by Discuz! X3.4