transformers 阅读:BERT 模子

打印 上一主题 下一主题

主题 513|帖子 513|积分 1539

媒介

想深入理解 BERT 模子,在阅读 transformers 库同时纪录一下。
笔者小白,错误的地方请不吝指出。
Embedding

为了使 BERT 能处理大量卑鄙任务,它的输入可以明确表示单一句子或句子对,例如<标题,答案>。
   To make BERT handle a variety of down-stream tasks, our input representation is able to unambiguously represent both a single sentence and a pair of sentences (e.g., h Question, Answeri) in one token sequence
  因此 BERT 的 Embedding 分为三个部门:


  • Token Embeddings:对于分词效果进行嵌入。
  • Segement Embeddings:用于表示每个词地点句子,例如区分某个词是属于标题句子照旧属于答案句子。
  • Position Embeddings:位置嵌入。
在 transfoerms 中定义如下:
  1. class BertEmbeddings(nn.Module):  
  2.     """Construct the embeddings from word, position and token_type embeddings."""  
  3.     def __init__(self, config):  
  4.         super().__init__()  
  5.         self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)  
  6.         self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)  
  7.         self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)  
  8.         # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load  
  9.         # any TensorFlow checkpoint file  
  10.         self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)  
  11.         self.dropout = nn.Dropout(config.hidden_dropout_prob)  
  12.         # position_ids (1, len position emb) is contiguous in memory and exported when serialized  
  13.         self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")  
  14.         self.register_buffer(  
  15.         "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False  
  16.         )  
  17.         self.register_buffer(  
  18.         "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False  
  19. )
复制代码
值得注意的是,BERT 中采取可学习的位置嵌入,而不是 Transformer 中的计算编码。
下面是前向计算代码:
  1. def forward(
  2.     self,
  3.     input_ids: Optional[torch.LongTensor] = None,
  4.     token_type_ids: Optional[torch.LongTensor] = None,
  5.     position_ids: Optional[torch.LongTensor] = None,
  6.     inputs_embeds: Optional[torch.FloatTensor] = None,
  7.     past_key_values_length: int = 0,
  8. ) -> torch.Tensor:
  9.     if input_ids is not None:
  10.         input_shape = input_ids.size()
  11.     else:
  12.         input_shape = inputs_embeds.size()[:-1]
  13.     seq_length = input_shape[1]
  14.     if position_ids is None:
  15.         position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
  16.     # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
  17.     # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
  18.     # issue #5664
  19.     if token_type_ids is None:
  20.         if hasattr(self, "token_type_ids"):
  21.             buffered_token_type_ids = self.token_type_ids[:, :seq_length]
  22.             buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
  23.             token_type_ids = buffered_token_type_ids_expanded
  24.         else:
  25.             token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
  26.     if inputs_embeds is None:
  27.         inputs_embeds = self.word_embeddings(input_ids)
  28.     token_type_embeddings = self.token_type_embeddings(token_type_ids)
  29.     embeddings = inputs_embeds + token_type_embeddings
  30.     if self.position_embedding_type == "absolute":
  31.         position_embeddings = self.position_embeddings(position_ids)
  32.         embeddings += position_embeddings
  33.     embeddings = self.LayerNorm(embeddings)
  34.     embeddings = self.dropout(embeddings)
  35.     return embeddings
复制代码
下面介绍各个参数的含义和作用:


  • input_ids:当前词在词表的位置构成的列表。
  • token_type_ids:当前词对应的句子,所属第一句/第二句/Padding。
  • position_ids:当前词在句子中的位置构成的列表。
  • inputs_embeds:对 input_ids 进行嵌入的效果。
  • past_key_values_length:假如没有传入 position_ids 则从过去计算的地方向后自动取 seq_len 长度作为 position_ids
前向计算逻辑如下:

  • 根据 input_ids 计算 input_embeddings,假如提供 input_embeds 则不消计算。
  • 根据 token_type_ids 计算 token_type_embeddings。
  • 根据 position_ids 计算 position_embeddings。
  • 上面三个步骤的效果求和。
  • 对步骤4效果做一次 LayerNorm 和 Dropout 后输出。
BertSelfAttention

自注意力是 BERT 中的核心模块,其初始化代码如下:
  1. class BertSelfAttention(nn.Module):
  2.     def __init__(self, config, position_embedding_type=None):
  3.         super().__init__()
  4.         if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
  5.             raise ValueError(
  6.                 f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
  7.                 f"heads ({config.num_attention_heads})"
  8.             )
  9.         self.num_attention_heads = config.num_attention_heads
  10.         self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
  11.         self.all_head_size = self.num_attention_heads * self.attention_head_size
  12.         self.query = nn.Linear(config.hidden_size, self.all_head_size)
  13.         self.key = nn.Linear(config.hidden_size, self.all_head_size)
  14.         self.value = nn.Linear(config.hidden_size, self.all_head_size)
  15.         self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
  16.         self.position_embedding_type = position_embedding_type or getattr(
  17.             config, "position_embedding_type", "absolute"
  18.         )
  19.         if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
  20.             self.max_position_embeddings = config.max_position_embeddings
  21.             self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
  22.         self.is_decoder = config.is_decoder
复制代码
与 Transformer 基本一致,都会将 embedding 分为多块计算。虽然判断了 hidden_size 能否整除 num_attention_heads ,但是由于后面的设置,理论上仍然大概导致 hidden_size 与 all_head_size 大小不同。
在 BERT 中对应参数如下:
typehidden_sizenum_attention_headsbase76812large102416 在 base 和 large 中每个 head 的大小为 64。
下面是对张量进行维度变更,为了后面的计算。
  1. def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
  2.     new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
  3.     x = x.view(new_x_shape)
  4.     return x.permute(0, 2, 1, 3)
复制代码
输入 x 的维度为 [bsz, seq_len, hidden_size],new_x_shape 变为 [bsz, seq_len, heads, head_size],然后交换 1 2 维度,变为 [bsz, heads, seq_len, head_size]。
前向计算代码如下:
  1. def forward(
  2.     self,
  3.     hidden_states: torch.Tensor,
  4.     attention_mask: Optional[torch.FloatTensor] = None,
  5.     head_mask: Optional[torch.FloatTensor] = None,
  6.     encoder_hidden_states: Optional[torch.FloatTensor] = None,
  7.     encoder_attention_mask: Optional[torch.FloatTensor] = None,
  8.     past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
  9.     output_attentions: Optional[bool] = False,
  10. ) -> Tuple[torch.Tensor]:
  11.     mixed_query_layer = self.query(hidden_states)
  12.     # If this is instantiated as a cross-attention module, the keys
  13.     # and values come from an encoder; the attention mask needs to be
  14.     # such that the encoder's padding tokens are not attended to.
  15.     is_cross_attention = encoder_hidden_states is not None
  16.     if is_cross_attention and past_key_value is not None:
  17.         # reuse k,v, cross_attentions
  18.         key_layer = past_key_value[0]
  19.         value_layer = past_key_value[1]
  20.         attention_mask = encoder_attention_mask
  21.     elif is_cross_attention:
  22.         key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
  23.         value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
  24.         attention_mask = encoder_attention_mask
  25.     elif past_key_value is not None:
  26.         key_layer = self.transpose_for_scores(self.key(hidden_states))
  27.         value_layer = self.transpose_for_scores(self.value(hidden_states))
  28.         key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
  29.         value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
  30.     else:
  31.         key_layer = self.transpose_for_scores(self.key(hidden_states))
  32.         value_layer = self.transpose_for_scores(self.value(hidden_states))
  33.     query_layer = self.transpose_for_scores(mixed_query_layer)
  34.     use_cache = past_key_value is not None
  35.     if self.is_decoder:
  36.         # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
  37.         # Further calls to cross_attention layer can then reuse all cross-attention
  38.         # key/value_states (first "if" case)
  39.         # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
  40.         # all previous decoder key/value_states. Further calls to uni-directional self-attention
  41.         # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
  42.         # if encoder bi-directional self-attention `past_key_value` is always `None`
  43.         past_key_value = (key_layer, value_layer)
  44.     # Take the dot product between "query" and "key" to get the raw attention scores.
  45.     attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
  46.     if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
  47.         query_length, key_length = query_layer.shape[2], key_layer.shape[2]
  48.         if use_cache:
  49.             position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
  50.                 -1, 1
  51.             )
  52.         else:
  53.             position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
  54.         position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
  55.         distance = position_ids_l - position_ids_r
  56.         positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
  57.         positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility
  58.         if self.position_embedding_type == "relative_key":
  59.             relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
  60.             attention_scores = attention_scores + relative_position_scores
  61.         elif self.position_embedding_type == "relative_key_query":
  62.             relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
  63.             relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
  64.             attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
  65.     attention_scores = attention_scores / math.sqrt(self.attention_head_size)
  66.     if attention_mask is not None:
  67.         # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
  68.         attention_scores = attention_scores + attention_mask
  69.     # Normalize the attention scores to probabilities.
  70.     attention_probs = nn.functional.softmax(attention_scores, dim=-1)
  71.     # This is actually dropping out entire tokens to attend to, which might
  72.     # seem a bit unusual, but is taken from the original Transformer paper.
  73.     attention_probs = self.dropout(attention_probs)
  74.     # Mask heads if we want to
  75.     if head_mask is not None:
  76.         attention_probs = attention_probs * head_mask
  77.     context_layer = torch.matmul(attention_probs, value_layer)
  78.     context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
  79.     new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
  80.     context_layer = context_layer.view(new_context_layer_shape)
  81.     outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
  82.     if self.is_decoder:
  83.         outputs = outputs + (past_key_value,)
  84.     return outputs
复制代码
首先判断是否要做交织注意力,判断原则就是是否输入 encoder_hidden_states。再联合是否输入 past_key_value 就形成了四种情况。

  • 计算交织注意力,传入过去的键值。则 K、V 均采取过去的键值,Mask 采取 encoder_attention_mask。
  • 计算交织注意力,没有传入过去的键值。则 K、V 通过 hidden_size 线性变更得到,Mask 采取 encoder_attention_mask。
  • 不计算交织注意力,传入过去的键值。则 K、V 由 hidden_size 线性变更之后,与过去的键值拼接而成,拼接维度 dim=2。
  • 不计算交织注意力,没有传入过去的键值。则 K、V 通过 hidden_size 线性变更得到。
无论哪种情况,Q 的都是由 hidden_size 线性变更得到。
然后就是计算 QKTQK^TQKT 得到 raw_attention_score。但是在进行 scale 之前,对位置编码种类进行特别操纵。
absolute
不进行任何操纵。
relative
分为 relative_key 和 relative_key_query。
这两者都会先计算 Q 和 K 的距离,然后对距离进行 embedding。
不同的是 relative_key 会将 Q 与上述 embedding 进行计算后与 raw_attention_score 相加。
relative_key_query 会将 QV 都与上述 embedding 进行计算,然后两者与 raw_attention_score 累加。
经过上面的操纵后,对 raw_attention_score 进行 scale 操纵。
操纵之后,与 Mask 计算采取加法而不是乘法,这是因为 Mask 的值是很大的负数而不是零,这种方式 Mask 更加“严实”。
然后就是正常的后续计算。
BertSelfOutput

多头注意力后的操纵:
  1. class BertSelfOutput(nn.Module):
  2.     def __init__(self, config):
  3.         super().__init__()
  4.         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  5.         self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  6.         self.dropout = nn.Dropout(config.hidden_dropout_prob)
  7.     def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  8.         hidden_states = self.dense(hidden_states)
  9.         hidden_states = self.dropout(hidden_states)
  10.         hidden_states = self.LayerNorm(hidden_states + input_tensor)
  11.         return hidden_states
复制代码
必要注意的是,hidden_size 经过线性变更之后,先经过 dropoutdropoutdropout,然后与 input_tensor 进行残差毗连,之后进行 LayerNormLayerNormLayerNorm。
BertAttention

上面报告了 BERT 中的多头注意力层和注意力层之后的输出,这里就是对这两块进行一次封装。
  1. class BertAttention(nn.Module):
  2.     def __init__(self, config, position_embedding_type=None):
  3.         super().__init__()
  4.         self.self = BertSelfAttention(config, position_embedding_type=position_embedding_type)
  5.         self.output = BertSelfOutput(config)
  6.         self.pruned_heads = set()
  7.     def prune_heads(self, heads):
  8.         if len(heads) == 0:
  9.             return
  10.         heads, index = find_pruneable_heads_and_indices(
  11.             heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
  12.         )
  13.         # Prune linear layers
  14.         self.self.query = prune_linear_layer(self.self.query, index)
  15.         self.self.key = prune_linear_layer(self.self.key, index)
  16.         self.self.value = prune_linear_layer(self.self.value, index)
  17.         self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
  18.         # Update hyper params and store pruned heads
  19.         self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
  20.         self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
  21.         self.pruned_heads = self.pruned_heads.union(heads)
复制代码
里面定义 prune_heads() 函数用于注意力头的剪枝。
此中 find_pruneable_heads_and_indices 用于找到可以剪枝的注意力头。返回必要剪掉的 heads 和保存的维度下标。
  1. def find_pruneable_heads_and_indices(
  2.     heads: List[int], n_heads: int, head_size: int, already_pruned_heads: Set[int]
  3. ) -> Tuple[Set[int], torch.LongTensor]:
  4.     """
  5.     Finds the heads and their indices taking `already_pruned_heads` into account.
  6.     Args:
  7.         heads (`List[int]`): List of the indices of heads to prune.
  8.         n_heads (`int`): The number of heads in the model.
  9.         head_size (`int`): The size of each head.
  10.         already_pruned_heads (`Set[int]`): A set of already pruned heads.
  11.     Returns:
  12.         `Tuple[Set[int], torch.LongTensor]`: A tuple with the indices of heads to prune taking `already_pruned_heads`
  13.         into account and the indices of rows/columns to keep in the layer weight.
  14.     """
复制代码
prune_linear_layer 函数用于具体剪枝注意力头。会按照 index 保存维度。
  1. def prune_linear_layer(layer: nn.Linear, index: torch.LongTensor, dim: int = 0) -> nn.Linear:
  2.     """
  3.     Prune a linear layer to keep only entries in index.
  4.     Used to remove heads.
  5.     Args:
  6.         layer (`torch.nn.Linear`): The layer to prune.
  7.         index (`torch.LongTensor`): The indices to keep in the layer.
  8.         dim (`int`, *optional*, defaults to 0): The dimension on which to keep the indices.
  9.     Returns:
  10.         `torch.nn.Linear`: The pruned layer as a new layer with `requires_grad=True`.
  11.     """
复制代码
前向计算代码如下:
  1. def forward(
  2.     self,
  3.     hidden_states: torch.Tensor,
  4.     attention_mask: Optional[torch.FloatTensor] = None,
  5.     head_mask: Optional[torch.FloatTensor] = None,
  6.     encoder_hidden_states: Optional[torch.FloatTensor] = None,
  7.     encoder_attention_mask: Optional[torch.FloatTensor] = None,
  8.     past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
  9.     output_attentions: Optional[bool] = False,
  10. ) -> Tuple[torch.Tensor]:
  11.     self_outputs = self.self(
  12.         hidden_states,
  13.         attention_mask,
  14.         head_mask,
  15.         encoder_hidden_states,
  16.         encoder_attention_mask,
  17.         past_key_value,
  18.         output_attentions,
  19.     )
  20.     attention_output = self.output(self_outputs[0], hidden_states)
  21.     outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
  22.     return outputs
复制代码
outputs 大概包含(attention, all_attention, past_value_key)
BertIntermediate

Attention 之后加入全毗连层和激活函数,比较简朴。
  1. class BertIntermediate(nn.Module):
  2.     def __init__(self, config):
  3.         super().__init__()
  4.         self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
  5.         if isinstance(config.hidden_act, str):
  6.             self.intermediate_act_fn = ACT2FN[config.hidden_act]
  7.         else:
  8.             self.intermediate_act_fn = config.hidden_act
  9.     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  10.         hidden_states = self.dense(hidden_states)
  11.         hidden_states = self.intermediate_act_fn(hidden_states)
  12.         return hidden_states
复制代码
BertOutput

经过中间层后,又是一个全毗连层 + dropout + 残差 + 层归一化。和 BertSelfOutput 架构相同。
  1. class BertOutput(nn.Module):
  2.     def __init__(self, config):
  3.         super().__init__()
  4.         self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
  5.         self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  6.         self.dropout = nn.Dropout(config.hidden_dropout_prob)
  7.     def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  8.         hidden_states = self.dense(hidden_states)
  9.         hidden_states = self.dropout(hidden_states)
  10.         hidden_states = self.LayerNorm(hidden_states + input_tensor)
  11.         return hidden_states
复制代码
BertLayer

BertLayer 是将 BertAttention、BertIntermediate 和 BertOutput 封装起来。
  1. class BertLayer(nn.Module):
  2.     def __init__(self, config):
  3.         super().__init__()
  4.         self.chunk_size_feed_forward = config.chunk_size_feed_forward
  5.         self.seq_len_dim = 1
  6.         self.attention = BertAttention(config)
  7.         self.is_decoder = config.is_decoder
  8.         self.add_cross_attention = config.add_cross_attention
  9.         if self.add_cross_attention:
  10.             if not self.is_decoder:
  11.                 raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
  12.             self.crossattention = BertAttention(config, position_embedding_type="absolute")
  13.         self.intermediate = BertIntermediate(config)
  14.         self.output = BertOutput(config)
复制代码
这里注意的是,假如加入交织注意力,必须作为 decoder。
前向计算代码如下:
  1. def forward(
  2.         self,
  3.         hidden_states: torch.Tensor,
  4.         attention_mask: Optional[torch.FloatTensor] = None,
  5.         head_mask: Optional[torch.FloatTensor] = None,
  6.         encoder_hidden_states: Optional[torch.FloatTensor] = None,
  7.         encoder_attention_mask: Optional[torch.FloatTensor] = None,
  8.         past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
  9.         output_attentions: Optional[bool] = False,
  10.     ) -> Tuple[torch.Tensor]:
  11.         # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
  12.         self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
  13.         self_attention_outputs = self.attention(
  14.             hidden_states,
  15.             attention_mask,
  16.             head_mask,
  17.             output_attentions=output_attentions,
  18.             past_key_value=self_attn_past_key_value,
  19.         )
  20.         attention_output = self_attention_outputs[0]
  21.         # if decoder, the last output is tuple of self-attn cache
  22.         if self.is_decoder:
  23.             outputs = self_attention_outputs[1:-1]
  24.             present_key_value = self_attention_outputs[-1]
  25.         else:
  26.             outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights
  27.         cross_attn_present_key_value = None
  28.         if self.is_decoder and encoder_hidden_states is not None:
  29.             if not hasattr(self, "crossattention"):
  30.                 raise ValueError(
  31.                     f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
  32.                     " by setting `config.add_cross_attention=True`"
  33.                 )
  34.             # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
  35.             cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
  36.             cross_attention_outputs = self.crossattention(
  37.                 attention_output,
  38.                 attention_mask,
  39.                 head_mask,
  40.                 encoder_hidden_states,
  41.                 encoder_attention_mask,
  42.                 cross_attn_past_key_value,
  43.                 output_attentions,
  44.             )
  45.             attention_output = cross_attention_outputs[0]
  46.             outputs = outputs + cross_attention_outputs[1:-1]  # add cross attentions if we output attention weights
  47.             # add cross-attn cache to positions 3,4 of present_key_value tuple
  48.             cross_attn_present_key_value = cross_attention_outputs[-1]
  49.             present_key_value = present_key_value + cross_attn_present_key_value
  50.         layer_output = apply_chunking_to_forward(
  51.             self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
  52.         )
  53.         outputs = (layer_output,) + outputs
  54.         # if decoder, return the attn key/values as the last output
  55.         if self.is_decoder:
  56.             outputs = outputs + (present_key_value,)
  57.         return outputs
  58.     def feed_forward_chunk(self, attention_output):
  59.         intermediate_output = self.intermediate(attention_output)
  60.         layer_output = self.output(intermediate_output, attention_output)
  61.         return layer_output
复制代码
基本逻辑:

  • 对 hidden_states 进行一次 Attention。
  • 假如是 decoder,将 attention_outputs 进行一次 CrossAttention。
  • 经过中间层和 Output 层。
必要注意的是,对于第三步,这里采取分块操纵来节省内存。
  1. layer_output = apply_chunking_to_forward(
  2.     self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
  3. )
  4. def feed_forward_chunk(self, attention_output):
  5.     intermediate_output = self.intermediate(attention_output)
  6.     layer_output = self.output(intermediate_output, attention_output)
  7.     return layer_output
  8. def apply_chunking_to_forward(
  9.     forward_fn: Callable[..., torch.Tensor], chunk_size: int, chunk_dim: int, *input_tensors
  10. ) -> torch.Tensor:
  11.     """
  12.     This function chunks the `input_tensors` into smaller input tensor parts of size `chunk_size` over the dimension
  13.     `chunk_dim`. It then applies a layer `forward_fn` to each chunk independently to save memory.
  14.     If the `forward_fn` is independent across the `chunk_dim` this function will yield the same result as directly
  15.     applying `forward_fn` to `input_tensors`.
  16.     Args:
  17.         forward_fn (`Callable[..., torch.Tensor]`):
  18.             The forward function of the model.
  19.         chunk_size (`int`):
  20.             The chunk size of a chunked tensor: `num_chunks = len(input_tensors[0]) / chunk_size`.
  21.         chunk_dim (`int`):
  22.             The dimension over which the `input_tensors` should be chunked.
  23.         input_tensors (`Tuple[torch.Tensor]`):
  24.             The input tensors of `forward_fn` which will be chunked
  25.     Returns:
  26.         `torch.Tensor`: A tensor with the same shape as the `forward_fn` would have given if applied`.
复制代码
apply_chunking_to_forward 函数就是将输入的 tensor 在指定的 chunk_dim 上分为若干个大小为 chunk_size 的块,然后对每块进行前向计算。
BERT 中指定的维度是 seq_len_dim,也就是将一个句子分为若干指定大小的块,分别进行计算。
BertEncoder

有了前面的 BertLayer,BertEncoder 就是若干 BertLayer 堆叠而成。
  1. class BertEncoder(nn.Module):
  2.     def __init__(self, config):
  3.         super().__init__()
  4.         self.config = config
  5.         self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
  6.         self.gradient_checkpointing = False
复制代码
前向计算也就是输入经过若干 BertLayer。假如设置了梯度查抄点,而且处于训练状态,BERT 会采取 torch.utils.checkpoint.checkpoint() 来节省内存。
  1. if self.gradient_checkpointing and self.training:
  2.     def create_custom_forward(module):
  3.         def custom_forward(*inputs):
  4.             return module(*inputs, past_key_value, output_attentions)
  5.         return custom_forward
  6.     layer_outputs = torch.utils.checkpoint.checkpoint(
  7.         create_custom_forward(layer_module),
  8.         hidden_states,
  9.         attention_mask,
  10.         layer_head_mask,
  11.         encoder_hidden_states,
  12.         encoder_attention_mask,
  13.     )
  14. def checkpoint(function, *args, use_reentrant: bool = True, **kwargs):
  15.     r"""Checkpoint a model or part of the model
  16.     Checkpointing works by trading compute for memory. Rather than storing all
  17.     intermediate activations of the entire computation graph for computing
  18.     backward, the checkpointed part does **not** save intermediate activations,
  19.     and instead recomputes them in backward pass. It can be applied on any part
  20.     of a model.
复制代码
这是一种时间换空间的思想。
BertPooler

  1. class BertPooler(nn.Module):
  2.     def __init__(self, config):
  3.         super().__init__()
  4.         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  5.         self.activation = nn.Tanh()
  6.     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  7.         # We "pool" the model by simply taking the hidden state corresponding
  8.         # to the first token.
  9.         first_token_tensor = hidden_states[:, 0]
  10.         pooled_output = self.dense(first_token_tensor)
  11.         pooled_output = self.activation(pooled_output)
  12.         return pooled_output
复制代码
BertPooler 仅仅获取 hidden_states 在 seq_len 维度上的第一个向量,然后经过线性变更后传入激活函数。
Summary

BERT 是由若干 BertLayer 堆叠而成,可以在最后一层加入不同的线性层,以适应不同的卑鄙任务。
BertLayer 是由 BertAttention、CrossAttention(可选)、BertIntermediate 和 BertOutput 堆叠而成。


  • BertAttention:由自注意力层和输出层构成

    • 自注意力层:Mask 采取加法,被遮罩的地方为较大负数
    • 输出层:依次经过 线性层 dropout 残差 LayerNorm

  • BertIntermediate:依次经过 线性层 激活函数
  • BertOutput:依次经过 线性层 dropout 残差 LayerNorm
那么,我们该怎样学习大模子?

作为一名热心肠的互联网老兵,我决定把宝贵的AI知识分享给大家。 至于能学习到多少就看你的学习毅力和能力了 。我已将重要的AI大模子资料包罗AI大模子入门学习头脑导图、精品AI大模子学习书籍手册、视频教程、实战学习等录播视频免费分享出来。
一、大模子全套的学习门路

学习大型人工智能模子,如GPT-3、BERT或任何其他先进的神经网络模子,必要体系的方法和持续的努力。既然要体系的学习大模子,那么学习门路是必不可少的,下面的这份门路能资助你快速梳理知识,形成本身的体系。
L1级别:AI大模子期间的华丽登场

L2级别:AI大模子API应用开辟工程

L3级别:大模子应用架构进阶实践

L4级别:大模子微调与私有化部署

一样寻常掌握到第四个级别,市场上大多数岗位都是可以胜任,但要还不是天花板,天花板级别要求更加严格,对于算法和实战是非常苛刻的。发起普通人掌握到L4级别即可。
以上的AI大模子学习门路,不知道为什么发出来就有点糊,高清版可以微信扫描下方CSDN官方认证二维码免费领取【包管100%免费】
二、640套AI大模子报告合集

这套包含640份报告的合集,涵盖了AI大模子的理论研究、技能实现、行业应用等多个方面。无论您是科研人员、工程师,照旧对AI大模子感兴趣的爱好者,这套报告合集都将为您提供宝贵的信息和启示。

三、大模子经典PDF籍

随着人工智能技能的飞速发展,AI大模子已经成为了当今科技领域的一大热点。这些大型预训练模子,如GPT-3、BERT、XLNet等,以其强盛的语言理解和天生能力,正在改变我们对人工智能的认识。 那以下这些PDF籍就是非常不错的学习资源。

四、AI大模子商业化落地方案


作为普通人,入局大模子期间必要持续学习和实践,不断进步本身的技能和认知水平,同时也必要有责任感和伦理意识,为人工智能的健康发展贡献力量。

免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。

本帖子中包含更多资源

您需要 登录 才可以下载或查看,没有账号?立即注册

x
回复

使用道具 举报

0 个回复

倒序浏览

快速回复

您需要登录后才可以回帖 登录 or 立即注册

本版积分规则

何小豆儿在此

金牌会员
这个人很懒什么都没写!

标签云

快速回复 返回顶部 返回列表