ToB企服应用市场:ToB评测及商务社交产业平台

标题: Sentence-BERT实现文本匹配【对比丧失函数】 [打印本页]

作者: 农妇山泉一亩田    时间: 2024-9-5 18:08
标题: Sentence-BERT实现文本匹配【对比丧失函数】
引言

照旧基于Sentence-BERT架构,大概说Bi-Encoder架构,但是本文使用的是参考2中提出的对比丧失函数。
架构


如上图,盘算两个句嵌入                                             u                                       \pmb u                  u和                                             v                                       \pmb v                  v​之间的间隔(1-余弦相似度),然后使用参考2中提出的对比丧失函数作为目标函数:
                                         L                            =                            y                            ×                                       1                               2                                      (                            distance                            (                                       u                                      ,                                       v                                      )                                       )                               2                                      +                            (                            1                            −                            y                            )                            ×                                       1                               2                                      {                            max                            ⁡                            (                            0                            ,                            m                            −                            distance                            (                                       u                                      ,                                       v                                      )                            )                                       }                               2                                                     \mathcal L= y\times \frac{1}{2} (\text{distance}(\pmb u,\pmb v))^2 + (1-y)\times \frac{1}{2} \{ \max(0, m - \text{distance}(\pmb u,\pmb v)) \}^2\\                     L=y×21​(distance(u,v))2+(1−y)×21​{max(0,m−distance(u,v))}2
这里的                                   y                              y                  y是真实标签,相似为1,不相似为0;                                   m                              m                  m​表示margin(隔断值),默以为0.5。
这里                                   m                              m                  m的意思是,如果                                             u                                       \pmb u                  u和                                             v                                       \pmb v                  v不相似(                                   y                         =                         0                              y=0                  y=0),那么它们之间的间隔只要充足大,大于即是隔断值0.5就好了。假设间隔为0.6,那么                                   max                         ⁡                         (                         0                         ,                         0.5                         −                         0.6                         )                         =                         0                              \max(0,0.5-0.6)=0                  max(0,0.5−0.6)=0,如果间隔不敷大(                                   0.2                              0.2                  0.2),那么                                   max                         ⁡                         (                         0                         ,                         0.5                         −                         0.2                         )                         =                         0.3                              \max(0,0.5-0.2)=0.3                  max(0,0.5−0.2)=0.3,就会产生丧失值。
整个公式的目的是拉近相似的文本对,推远不相似的文本对到一定程度就可以了。实现的时间                                   max                         ⁡                              \max                  max可以用relu来表示。
实现

实现采用雷同Huggingface的情势,每个文件夹下面有一种模型。分为modeling、arguments、trainer等差别的文件。差别的架构放置在差别的文件夹内。
modeling.py:
  1. from dataclasses import dataclass
  2. import torch
  3. from torch import Tensor, nn
  4. from transformers.file_utils import ModelOutput
  5. from transformers import (
  6.     AutoModel,
  7.     AutoTokenizer,
  8. )
  9. import numpy as np
  10. from tqdm.autonotebook import trange
  11. from typing import Optional
  12. from enum import Enum
  13. import torch.nn.functional as F
  14. # 定义了三种距离函数
  15. # 余弦相似度值越小表示越不相似,1减去它就变成了距离函数,越小(余弦越接近1)表示越相似。
  16. class SiameseDistanceMetric(Enum):
  17.     """The metric for the contrastive loss"""
  18.     EUCLIDEAN = lambda x, y: F.pairwise_distance(x, y, p=2)
  19.     MANHATTAN = lambda x, y: F.pairwise_distance(x, y, p=1)
  20.     COSINE_DISTANCE = lambda x, y: 1 - F.cosine_similarity(x, y)
  21. @dataclass
  22. class BiOutput(ModelOutput):
  23.     loss: Optional[Tensor] = None
  24.     scores: Optional[Tensor] = None
  25. class SentenceBert(nn.Module):
  26.     def __init__(
  27.         self,
  28.         model_name: str,
  29.         trust_remote_code: bool = True,
  30.         max_length: int = None,
  31.         margin: float = 0.5,
  32.         distance_metric=SiameseDistanceMetric.COSINE_DISTANCE,
  33.         pooling_mode: str = "mean",
  34.         normalize_embeddings: bool = False,
  35.     ) -> None:
  36.         super().__init__()
  37.         self.model_name = model_name
  38.         self.normalize_embeddings = normalize_embeddings
  39.         self.device = "cuda" if torch.cuda.is_available() else "cpu"
  40.         self.tokenizer = AutoTokenizer.from_pretrained(
  41.             model_name, trust_remote_code=trust_remote_code
  42.         )
  43.         self.model = AutoModel.from_pretrained(
  44.             model_name, trust_remote_code=trust_remote_code
  45.         ).to(self.device)
  46.         self.max_length = max_length
  47.         self.pooling_mode = pooling_mode
  48.         self.distance_metric = distance_metric
  49.         self.margin = margin
  50.     def sentence_embedding(self, last_hidden_state, attention_mask):
  51.         if self.pooling_mode == "mean":
  52.             attention_mask = attention_mask.unsqueeze(-1).float()
  53.             return torch.sum(last_hidden_state * attention_mask, dim=1) / torch.clamp(
  54.                 attention_mask.sum(1), min=1e-9
  55.             )
  56.         else:
  57.             # cls
  58.             return last_hidden_state[:, 0]
  59.     def encode(
  60.         self,
  61.         sentences: str | list[str],
  62.         batch_size: int = 64,
  63.         convert_to_tensor: bool = True,
  64.         show_progress_bar: bool = False,
  65.     ):
  66.         if isinstance(sentences, str):
  67.             sentences = [sentences]
  68.         all_embeddings = []
  69.         for start_index in trange(
  70.             0, len(sentences), batch_size, desc="Batches", disable=not show_progress_bar
  71.         ):
  72.             batch = sentences[start_index : start_index + batch_size]
  73.             features = self.tokenizer(
  74.                 batch,
  75.                 padding=True,
  76.                 truncation=True,
  77.                 return_tensors="pt",
  78.                 return_attention_mask=True,
  79.                 max_length=self.max_length,
  80.             ).to(self.device)
  81.             out_features = self.model(**features, return_dict=True)
  82.             embeddings = self.sentence_embedding(
  83.                 out_features.last_hidden_state, features["attention_mask"]
  84.             )
  85.             if not self.training:
  86.                 embeddings = embeddings.detach()
  87.             if self.normalize_embeddings:
  88.                 embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
  89.             if not convert_to_tensor:
  90.                 embeddings = embeddings.cpu()
  91.             all_embeddings.extend(embeddings)
  92.         if convert_to_tensor:
  93.             all_embeddings = torch.stack(all_embeddings)
  94.         else:
  95.             all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])
  96.         return all_embeddings
  97.     def compute_loss(self, source_embed, target_embed, labels):
  98.         labels = torch.tensor(labels).float().to(self.device)
  99.                 # 计算距离
  100.         distances = self.distance_metric(source_embed, target_embed)
  101.                 # 实现损失函数
  102.         loss = 0.5 * (
  103.             labels * distances.pow(2)
  104.             + (1 - labels) * F.relu(self.margin - distances).pow(2)
  105.         )
  106.         return loss.mean()
  107.     def forward(self, source, target, labels) -> BiOutput:
  108.         """
  109.         Args:
  110.             source :
  111.             target :
  112.         """
  113.         source_embed = self.encode(source)
  114.         target_embed = self.encode(target)
  115.         loss = self.compute_loss(source_embed, target_embed, labels)
  116.         return BiOutput(loss, None)
  117.     def save_pretrained(self, output_dir: str):
  118.         state_dict = self.model.state_dict()
  119.         state_dict = type(state_dict)(
  120.             {k: v.clone().cpu().contiguous() for k, v in state_dict.items()}
  121.         )
  122.         self.model.save_pretrained(output_dir, state_dict=state_dict)
复制代码
整个模型的实现放到modeling.py文件中。
arguments.py:
  1. from dataclasses import dataclass, field
  2. from typing import Optional
  3. import os
  4. @dataclass
  5. class ModelArguments:
  6.     model_name_or_path: str = field(
  7.         metadata={
  8.             "help": "Path to pretrained model"
  9.         }
  10.     )
  11.     config_name: Optional[str] = field(
  12.         default=None,
  13.         metadata={
  14.             "help": "Pretrained config name or path if not the same as model_name"
  15.         },
  16.     )
  17.     tokenizer_name: Optional[str] = field(
  18.         default=None,
  19.         metadata={
  20.             "help": "Pretrained tokenizer name or path if not the same as model_name"
  21.         },
  22.     )
  23. @dataclass
  24. class DataArguments:
  25.     train_data_path: str = field(
  26.         default=None, metadata={"help": "Path to train corpus"}
  27.     )
  28.     eval_data_path: str = field(default=None, metadata={"help": "Path to eval corpus"})
  29.     max_length: int = field(
  30.         default=512,
  31.         metadata={
  32.             "help": "The maximum total input sequence length after tokenization for input text."
  33.         },
  34.     )
  35.     def __post_init__(self):
  36.         if not os.path.exists(self.train_data_path):
  37.             raise FileNotFoundError(
  38.                 f"cannot find file: {self.train_data_path}, please set a true path"
  39.             )
  40.         
  41.         if not os.path.exists(self.eval_data_path):
  42.             raise FileNotFoundError(
  43.                 f"cannot find file: {self.eval_data_path}, please set a true path"
  44.             )
复制代码
定义了模型和数据相关参数。
dataset.py:
  1. from torch.utils.data import Dataset
  2. from datasets import Dataset as dt
  3. import pandas as pd
  4. from utils import build_dataframe_from_csv
  5. class PairDataset(Dataset):
  6.     def __init__(self, data_path: str) -> None:
  7.         df = build_dataframe_from_csv(data_path)
  8.         self.dataset = dt.from_pandas(df, split="train")
  9.         self.total_len = len(self.dataset)
  10.     def __len__(self):
  11.         return self.total_len
  12.     def __getitem__(self, index) -> dict[str, str]:
  13.         query1 = self.dataset[index]["query1"]
  14.         query2 = self.dataset[index]["query2"]
  15.         label = self.dataset[index]["label"]
  16.         return {"query1": query1, "query2": query2, "label": label}
  17. class PairCollator:
  18.     def __call__(self, features) -> dict[str, list[str]]:
  19.         queries1 = []
  20.         queries2 = []
  21.         labels = []
  22.         for feature in features:
  23.             queries1.append(feature["query1"])
  24.             queries2.append(feature["query2"])
  25.             labels.append(feature["label"])
  26.         return {"source": queries1, "target": queries2, "labels": labels}
复制代码
数据集类思量了LCQMC数据集的格式,即成对的语句和一个数值标签。雷同:
  1. Hello.        Hi.        1
  2. Nice to see you.        Nice        0
复制代码
trainer.py:
  1. import torch
  2. from transformers.trainer import Trainer
  3. from typing import Optional
  4. import os
  5. import logging
  6. from modeling import SentenceBert
  7. TRAINING_ARGS_NAME = "training_args.bin"
  8. logger = logging.getLogger(__name__)
  9. class BiTrainer(Trainer):
  10.     def compute_loss(self, model: SentenceBert, inputs, return_outputs=False):
  11.         outputs = model(**inputs)
  12.         loss = outputs.loss
  13.         return (loss, outputs) if return_outputs else loss
  14.     def _save(self, output_dir: Optional[str] = None, state_dict=None):
  15.         # If we are executing this function, we are the process zero, so we don't check for that.
  16.         output_dir = output_dir if output_dir is not None else self.args.output_dir
  17.         os.makedirs(output_dir, exist_ok=True)
  18.         logger.info(f"Saving model checkpoint to {output_dir}")
  19.         self.model.save_pretrained(output_dir)
  20.         if self.tokenizer is not None:
  21.             self.tokenizer.save_pretrained(output_dir)
  22.         # Good practice: save your training arguments together with the trained model
  23.         torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
复制代码
继承




欢迎光临 ToB企服应用市场:ToB评测及商务社交产业平台 (https://dis.qidao123.com/) Powered by Discuz! X3.4