NLP:微调BERT进行文本分类

打印 上一主题 下一主题

主题 212|帖子 212|积分 636

本篇博客的重点在于BERT的使用。
transformers包版本:4.44.2
  1. 微调BERT进行文本分类

  这里我们使用stanford大学的SST2数据集来演示BERT模型的微调过程。SST-2数据集(Stanford Sentiment Treebank 2)是一个用于情绪分类的经典数据集,常用于自然语言处置惩罚(NLP)领域的情绪分析使命。


  • 第1步: 下载数据。其代码如下:
  1. import pandas as pd
  2. from transformers import BertTokenizer
  3. from datasets import DatasetDict, Dataset
  4. from torch.utils.data import DataLoader
  5. from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments
  6. splits = {'train': 'data/train-00000-of-00001.parquet',
  7.           'validation': 'data/validation-00000-of-00001.parquet',
  8.           'test': 'data/test-00000-of-00001.parquet'}
  9. train = pd.read_parquet("hf://datasets/stanfordnlp/sst2/" + splits["train"])
  10. validation = pd.read_parquet("hf://datasets/stanfordnlp/sst2/" + splits["validation"])
  11. test = pd.read_parquet("hf://datasets/stanfordnlp/sst2/" + splits["test"])
  12. dataset = DatasetDict({'train': Dataset.from_pandas(train),
  13.                        'validation': Dataset.from_pandas(validation),
  14.                        'test': Dataset.from_pandas(test)})
复制代码
要注意一下,这里并没有使用datasets包从hugging face上直接下载数据集的方式来获取数据,这是由于使用load_datesets方法获取数据时仍旧会提示:NotImplementedError: Loading a dataset cached in a LocalFileSystem is not supported


  • 第2步: 构造训练集、验证集和测试集
SST2数据会集训练集(train)共67349条,验证集(validation)共872条,而测试集(test)共1821条。由于训练集数目较大微调会比较耗时,所以从这三个数据集分别抽取出了1000条、200条、200条进行后续的使命。具体代码如下:
  1. dataset['train'] = dataset['train'].shuffle(seed=42).select(range(1000))
  2. dataset['validation'] = dataset['validation'].shuffle(seed=42).select(range(200))
  3. dataset['test'] = dataset['test'].shuffle(seed=42).select(range(200))
  4. print(dataset)
复制代码
其输出效果如下:
  1. Dataset({
  2.     features: ['idx', 'sentence', 'label'],
  3.     num_rows: 1000
  4. })
  5. Dataset({
  6.     features: ['idx', 'sentence', 'label'],
  7.     num_rows: 200
  8. })
  9. Dataset({
  10.     features: ['idx', 'sentence', 'label'],
  11.     num_rows: 200
  12. })
复制代码


  • 第3步:从bert中提取嵌入
训练集、验证集及测试集天生后,接着需要将这些语料全都转化成embedding向量。具体代码如下:
  1. tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
  2. def tokenize_function(examples):
  3.     return tokenizer(examples['sentence'], padding='max_length', truncation=True)
  4. dataset =dataset.map(tokenize_function, batched=True)
  5. dataset=dataset.remove_columns(['sentence',"idx"])
  6. dataset=dataset.rename_column("label","labels")
  7. dataset.set_format("torch")
  8. train_dataset=dataset['train']
  9. eval_dataset=dataset['validation']
  10. test_dataset=dataset['test']
复制代码


  • 第4步:模型训练。 具体代码如下:
  1. model=BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)
  2. training_args = TrainingArguments(
  3.     output_dir='results',         
  4.     per_device_train_batch_size=8,  
  5.     per_device_eval_batch_size=8,      
  6.     num_train_epochs=1,
  7. )
  8. trainer = Trainer(
  9.     model=model,                        
  10.     args=training_args,                 
  11.     train_dataset=train_dataset,        
  12.     eval_dataset=eval_dataset,
  13. )
  14. trainer.train()
  15. trainer.evaluate()
  16. trainer.save_model("results")
复制代码
关于上述代码,有以下几点需要说明:


  • 训练模型的选择: tranformers库中有多个分类模型,其中BertForSequenceClassification类适用于序列分类使命,好比情绪分析和文本分类;而BertForTokenClassification类适用于token级的分类使命,好比命名实体识别。
  • TrainingArguments方法中的主要参数及其作用如下表所示:
参数名作用output_dir指定模型和训练日志生存的记录;num_train_epochs设置训练的周期数(即遍历整个训练数据集的次数,指的是整个训练集将被遍历多少次以进行训练);per_device_train_batch_size设置每个装备(如GPU)上的训练批次大小,训练批次是指在一次训练迭代中,模型同时处置惩罚的数据样本数量;per_device_eval_batch_size设置每个装备上的评估批次大小;logging_dir指定训练日志的生存目录;evaluation_strategy设置评估策略。可以是 ‘no’(不评估)、‘steps’(每隔肯定步数评估)或 ‘epoch’(每个周期评估);save_total_limit设置生存模型检查点的总数限制,超过限制的检查点会被删除;fp16启用半精度浮点数(FP16)训练,以淘汰显存使用并加速训练(需要支持 FP16 的硬件); 参考资料



  • BERT根本教程:Transformer大模型实战
  • https://blog.csdn.net/zoe9698/article/details/124579973

免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。
回复

使用道具 举报

0 个回复

倒序浏览

快速回复

您需要登录后才可以回帖 登录 or 立即注册

本版积分规则

张裕

高级会员
这个人很懒什么都没写!

标签云

快速回复 返回顶部 返回列表