通过微调预练习模子得到本身的模子

[复制链接]
发表于 2026-1-29 17:46:44 | 显示全部楼层 |阅读模式

马上注册,结交更多好友,享用更多功能,让你轻松玩转社区。

您需要 登录 才可以下载或查看,没有账号?立即注册

×
通过微调预练习模子得到本身的模子

目次


  • 简介
  • 环境准备
  • 数据准备
  • 加载预练习模子和Tokenizer
  • 数据预处理惩罚
  • 设置练习参数
  • 初始化Trainer并开始练习
  • 评估和生存模子
  • 总结
简介

在这篇博客中,我们将先容怎样通过预练习模子举行微调来得到本身的模子。我们将使用Hugging Face的Transformers库和一个BART模子举行示例演示。整个过程包罗环境准备、数据准备、模子加载、数据预处理惩罚、练习参数设置、练习、评估和生存模子。
环境准备

起首,我们须要安装须要的Python库:
  1. pip install transformers datasets torch
复制代码
数据准备

假设我们有三个数据集:练习集、验证集和测试集,分别存储在JSON文件中。我们将这些数据集加载到内存中。
  1. import os
  2. from datasets import load_dataset
  3. train_data_name = 'train_data'
  4. valid_data_name = 'valid_data'
  5. test_data_name = 'test_data'
  6. # 顶级数据目录
  7. top_data_dir = '../../data/sql'
  8. raw_data_dir = os.path.join(top_data_dir, 'raw_data/')
  9. train_raw_data_path = os.path.join(raw_data_dir, f'{train_data_name}.json')
  10. valid_raw_data_path = os.path.join(raw_data_dir, f'{valid_data_name}.json')
  11. test_raw_data_path = os.path.join(raw_data_dir, f'{test_data_name}.json')
  12. # 加载JSON数据集,忽略无法解码的字符
  13. dataset = load_dataset('json', data_files={
  14.     'train': train_raw_data_path,
  15.     'validation': valid_raw_data_path,
  16.     'test': test_raw_data_path
  17. })
复制代码
加载预练习模子和Tokenizer

我们将使用Hugging Face的Transformers库加载预练习的BART模子和对应的Tokenizer。
  1. from transformers import AutoTokenizer, BartForConditionalGeneration
  2. tokenizer = AutoTokenizer.from_pretrained("./bart-base")
  3. model = BartForConditionalGeneration.from_pretrained("./bart-base").to(device)
复制代码
数据预处理惩罚

界说数据预处理惩罚函数,将输入和目的文本举行tokenize,并确保长度划一。
  1. def preprocess_function(examples):
  2.     inputs = examples['code']
  3.     targets = examples['text']
  4.     # 使用 `max_length` 和 `padding` 确保一致的长度
  5.     model_inputs = tokenizer(inputs, max_length=512, truncation=True, padding='max_length')
  6.     labels = tokenizer(text_target=targets, max_length=512, truncation=True, padding='max_length')
  7.     model_inputs['labels'] = labels['input_ids']
  8.     return model_inputs
  9. # 应用预处理函数到训练集和验证集
  10. tokenized_datasets = dataset.map(preprocess_function, batched=True)
复制代码
设置练习参数

设置练习参数,包罗输出目次、批量巨细、练习轮数等。
  1. from transformers import TrainingArguments
  2. training_args = TrainingArguments(
  3.     output_dir='./results',          # 输出结果的目录
  4.     evaluation_strategy="epoch",     # 每个epoch进行一次评估
  5.     per_device_train_batch_size=4,   # 每个设备的训练批量大小
  6.     per_device_eval_batch_size=4,    # 每个设备的评估批量大小
  7.     num_train_epochs=3,              # 训练的epoch数量
  8.     save_strategy="epoch",           # 保存策略
  9.     logging_dir='./logs',            # 日志日志目录
  10.     logging_steps=10,                # 日志日志记录的步数
  11.     no_cuda=False,                   # 强制使用CPU
  12.     learning_rate=5e-5,              # 调整学习率
  13.     gradient_accumulation_steps=8,   # 梯度累
复制代码
初始化trainer并开始练习

  1. trainer = Trainer(
  2.     model=model,
  3.     args=training_args,
  4.     train_dataset=tokenized_datasets['train'],
  5.     eval_dataset=tokenized_datasets['validation'],
  6. )
  7. trainer.train()
复制代码
评估生存模子

  1. results = trainer.evaluate(eval_dataset=tokenized_datasets['validation'])
  2. print(f"Validation Results: {results}")
  3. model.save_pretrained('./trained_model')
  4. tokenizer.save_pretrained('./trained_model')
复制代码
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!qidao123.com:ToB企服之家,中国第一个企服评测及软件市场,开放入驻,技术点评得现金
回复

使用道具 举报

登录后关闭弹窗

登录参与点评抽奖  加入IT实名职场社区
去登录
快速回复 返回顶部 返回列表