农妇山泉一亩田 发表于 2024-9-5 18:08:17

Sentence-BERT实现文本匹配【对比丧失函数】

引言

照旧基于Sentence-BERT架构,大概说Bi-Encoder架构,但是本文使用的是参考2中提出的对比丧失函数。
架构

https://img-blog.csdnimg.cn/img_convert/daae0f81b2b919b0bd43258ca840723d.png
如上图,盘算两个句嵌入                                             u                                       \pmb u                  u和                                             v                                       \pmb v                  v​之间的间隔(1-余弦相似度),然后使用参考2中提出的对比丧失函数作为目标函数:
                                       L                            =                            y                            ×                                       1                               2                                    (                            distance                            (                                       u                                    ,                                       v                                    )                                       )                               2                                    +                            (                            1                            −                            y                            )                            ×                                       1                               2                                    {                            max                            ⁡                            (                            0                            ,                            m                            −                            distance                            (                                       u                                    ,                                       v                                    )                            )                                       }                               2                                                   \mathcal L= y\times \frac{1}{2} (\text{distance}(\pmb u,\pmb v))^2 + (1-y)\times \frac{1}{2} \{ \max(0, m - \text{distance}(\pmb u,\pmb v)) \}^2\\                     L=y×21​(distance(u,v))2+(1−y)×21​{max(0,m−distance(u,v))}2
这里的                                 y                              y                  y是真实标签,相似为1,不相似为0;                                 m                              m                  m​表示margin(隔断值),默以为0.5。
这里                                 m                              m                  m的意思是,如果                                             u                                       \pmb u                  u和                                             v                                       \pmb v                  v不相似(                                 y                         =                         0                              y=0                  y=0),那么它们之间的间隔只要充足大,大于即是隔断值0.5就好了。假设间隔为0.6,那么                                 max                         ⁡                         (                         0                         ,                         0.5                         −                         0.6                         )                         =                         0                              \max(0,0.5-0.6)=0                  max(0,0.5−0.6)=0,如果间隔不敷大(                                 0.2                              0.2                  0.2),那么                                 max                         ⁡                         (                         0                         ,                         0.5                         −                         0.2                         )                         =                         0.3                              \max(0,0.5-0.2)=0.3                  max(0,0.5−0.2)=0.3,就会产生丧失值。
整个公式的目的是拉近相似的文本对,推远不相似的文本对到一定程度就可以了。实现的时间                                 max                         ⁡                              \max                  max可以用relu来表示。
实现

实现采用雷同Huggingface的情势,每个文件夹下面有一种模型。分为modeling、arguments、trainer等差别的文件。差别的架构放置在差别的文件夹内。
modeling.py:
from dataclasses import dataclass

import torch
from torch import Tensor, nn

from transformers.file_utils import ModelOutput

from transformers import (
    AutoModel,
    AutoTokenizer,
)

import numpy as np
from tqdm.autonotebook import trange
from typing import Optional

from enum import Enum
import torch.nn.functional as F

# 定义了三种距离函数
# 余弦相似度值越小表示越不相似,1减去它就变成了距离函数,越小(余弦越接近1)表示越相似。
class SiameseDistanceMetric(Enum):
    """The metric for the contrastive loss"""

    EUCLIDEAN = lambda x, y: F.pairwise_distance(x, y, p=2)
    MANHATTAN = lambda x, y: F.pairwise_distance(x, y, p=1)
    COSINE_DISTANCE = lambda x, y: 1 - F.cosine_similarity(x, y)


@dataclass
class BiOutput(ModelOutput):
    loss: Optional = None
    scores: Optional = None


class SentenceBert(nn.Module):
    def __init__(
      self,
      model_name: str,
      trust_remote_code: bool = True,
      max_length: int = None,
      margin: float = 0.5,
      distance_metric=SiameseDistanceMetric.COSINE_DISTANCE,
      pooling_mode: str = "mean",
      normalize_embeddings: bool = False,
    ) -> None:
      super().__init__()
      self.model_name = model_name
      self.normalize_embeddings = normalize_embeddings

      self.device = "cuda" if torch.cuda.is_available() else "cpu"

      self.tokenizer = AutoTokenizer.from_pretrained(
            model_name, trust_remote_code=trust_remote_code
      )
      self.model = AutoModel.from_pretrained(
            model_name, trust_remote_code=trust_remote_code
      ).to(self.device)

      self.max_length = max_length
      self.pooling_mode = pooling_mode

      self.distance_metric = distance_metric
      self.margin = margin

    def sentence_embedding(self, last_hidden_state, attention_mask):
      if self.pooling_mode == "mean":
            attention_mask = attention_mask.unsqueeze(-1).float()
            return torch.sum(last_hidden_state * attention_mask, dim=1) / torch.clamp(
                attention_mask.sum(1), min=1e-9
            )
      else:
            # cls
            return last_hidden_state[:, 0]

    def encode(
      self,
      sentences: str | list,
      batch_size: int = 64,
      convert_to_tensor: bool = True,
      show_progress_bar: bool = False,
    ):
      if isinstance(sentences, str):
            sentences =

      all_embeddings = []

      for start_index in trange(
            0, len(sentences), batch_size, desc="Batches", disable=not show_progress_bar
      ):
            batch = sentences

            features = self.tokenizer(
                batch,
                padding=True,
                truncation=True,
                return_tensors="pt",
                return_attention_mask=True,
                max_length=self.max_length,
            ).to(self.device)

            out_features = self.model(**features, return_dict=True)
            embeddings = self.sentence_embedding(
                out_features.last_hidden_state, features["attention_mask"]
            )
            if not self.training:
                embeddings = embeddings.detach()

            if self.normalize_embeddings:
                embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)

            if not convert_to_tensor:
                embeddings = embeddings.cpu()

            all_embeddings.extend(embeddings)

      if convert_to_tensor:
            all_embeddings = torch.stack(all_embeddings)
      else:
            all_embeddings = np.asarray()

      return all_embeddings

    def compute_loss(self, source_embed, target_embed, labels):
      labels = torch.tensor(labels).float().to(self.device)
                # 计算距离
      distances = self.distance_metric(source_embed, target_embed)
                # 实现损失函数
      loss = 0.5 * (
            labels * distances.pow(2)
            + (1 - labels) * F.relu(self.margin - distances).pow(2)
      )
      return loss.mean()

    def forward(self, source, target, labels) -> BiOutput:
      """
      Args:
            source :
            target :
      """
      source_embed = self.encode(source)
      target_embed = self.encode(target)

      loss = self.compute_loss(source_embed, target_embed, labels)
      return BiOutput(loss, None)

    def save_pretrained(self, output_dir: str):
      state_dict = self.model.state_dict()
      state_dict = type(state_dict)(
            {k: v.clone().cpu().contiguous() for k, v in state_dict.items()}
      )
      self.model.save_pretrained(output_dir, state_dict=state_dict)

整个模型的实现放到modeling.py文件中。
arguments.py:
from dataclasses import dataclass, field
from typing import Optional

import os


@dataclass
class ModelArguments:
    model_name_or_path: str = field(
      metadata={
            "help": "Path to pretrained model"
      }
    )
    config_name: Optional = field(
      default=None,
      metadata={
            "help": "Pretrained config name or path if not the same as model_name"
      },
    )
    tokenizer_name: Optional = field(
      default=None,
      metadata={
            "help": "Pretrained tokenizer name or path if not the same as model_name"
      },
    )


@dataclass
class DataArguments:
    train_data_path: str = field(
      default=None, metadata={"help": "Path to train corpus"}
    )
    eval_data_path: str = field(default=None, metadata={"help": "Path to eval corpus"})
    max_length: int = field(
      default=512,
      metadata={
            "help": "The maximum total input sequence length after tokenization for input text."
      },
    )

    def __post_init__(self):
      if not os.path.exists(self.train_data_path):
            raise FileNotFoundError(
                f"cannot find file: {self.train_data_path}, please set a true path"
            )
      
      if not os.path.exists(self.eval_data_path):
            raise FileNotFoundError(
                f"cannot find file: {self.eval_data_path}, please set a true path"
            )

定义了模型和数据相关参数。
dataset.py:
from torch.utils.data import Dataset
from datasets import Dataset as dt
import pandas as pd

from utils import build_dataframe_from_csv


class PairDataset(Dataset):
    def __init__(self, data_path: str) -> None:

      df = build_dataframe_from_csv(data_path)
      self.dataset = dt.from_pandas(df, split="train")

      self.total_len = len(self.dataset)

    def __len__(self):
      return self.total_len

    def __getitem__(self, index) -> dict:
      query1 = self.dataset["query1"]
      query2 = self.dataset["query2"]
      label = self.dataset["label"]
      return {"query1": query1, "query2": query2, "label": label}


class PairCollator:
    def __call__(self, features) -> dict]:
      queries1 = []
      queries2 = []
      labels = []

      for feature in features:
            queries1.append(feature["query1"])
            queries2.append(feature["query2"])
            labels.append(feature["label"])

      return {"source": queries1, "target": queries2, "labels": labels}

数据集类思量了LCQMC数据集的格式,即成对的语句和一个数值标签。雷同:
Hello.        Hi.        1
Nice to see you.        Nice        0
trainer.py:
import torch
from transformers.trainer import Trainer

from typing import Optional
import os
import logging

from modeling import SentenceBert

TRAINING_ARGS_NAME = "training_args.bin"
logger = logging.getLogger(__name__)


class BiTrainer(Trainer):

    def compute_loss(self, model: SentenceBert, inputs, return_outputs=False):
      outputs = model(**inputs)
      loss = outputs.loss

      return (loss, outputs) if return_outputs else loss

    def _save(self, output_dir: Optional = None, state_dict=None):
      # If we are executing this function, we are the process zero, so we don't check for that.
      output_dir = output_dir if output_dir is not None else self.args.output_dir
      os.makedirs(output_dir, exist_ok=True)
      logger.info(f"Saving model checkpoint to {output_dir}")

      self.model.save_pretrained(output_dir)

      if self.tokenizer is not None:
            self.tokenizer.save_pretrained(output_dir)

      # Good practice: save your training arguments together with the trained model
      torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))

继承
页: [1]
查看完整版本: Sentence-BERT实现文本匹配【对比丧失函数】