import argparse
import inspect
import json
import math
import os
import re
import sys
import time
import warnings
from collections import defaultdict
from typing import List, Optional, Tuple, Union, cast

import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from tqdm import tqdm, trange

from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
from peft import LoraConfig, TaskType, get_peft_model  # pylint: disable=C0413
from transformers import (
    AutoConfig,
    AutoModel,
    AutoModelForCausalLM,
    AutoTokenizer,
    PretrainedConfig,
    Qwen2Config,
    Qwen2ForCausalLM,
)
from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_attn_mask_utils import (
    _prepare_4d_attention_mask,
    _prepare_4d_attention_mask_for_sdpa,
    _prepare_4d_causal_attention_mask,
    _prepare_4d_causal_attention_mask_for_sdpa,
)
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
    SequenceClassifierOutputWithPast,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import (
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    is_flash_attn_2_available,
    is_flash_attn_greater_or_equal_2_10,
    logging,
    replace_return_docstrings,
)

# `logger` is used in get_sentence_embedding below; define it via transformers' logging utility.
logger = logging.get_logger(__name__)
# PMA building block, post-norm (Post-LN) variant
class MAB_POST(nn.Module):
    def __init__(self, dim_Q, dim_K, dim_V, num_heads, ln=False):
        super(MAB_POST, self).__init__()
        self.dim_V = dim_V
        self.num_heads = num_heads
        self.fc_q = nn.Linear(dim_Q, dim_V)
        self.fc_k = nn.Linear(dim_K, dim_V)
        self.fc_v = nn.Linear(dim_K, dim_V)
        if ln:
            self.ln0 = nn.LayerNorm(dim_V)
            self.ln1 = nn.LayerNorm(dim_V)
        self.fc_o = nn.Linear(dim_V, dim_V)
        nn.init.xavier_uniform_(self.fc_q.weight)
        nn.init.xavier_uniform_(self.fc_k.weight)
        nn.init.xavier_uniform_(self.fc_v.weight)
        nn.init.xavier_uniform_(self.fc_o.weight)

    # Q: (bs, 1, emb), pad_mask: (bs, seq), Post-LN
    def forward(self, Q, K, pad_mask=None):
        Q_ = self.fc_q(Q)
        K_, V_ = self.fc_k(K), self.fc_v(K)

        dim_split = self.dim_V // self.num_heads
        Q_ = torch.cat(Q_.split(dim_split, 2), 0)  # (bs*num_head, 1, emb)
        K_ = torch.cat(K_.split(dim_split, 2), 0)  # (bs*num_head, seq, emb)
        V_ = torch.cat(V_.split(dim_split, 2), 0)
        pad_mask = pad_mask.unsqueeze(1).repeat(self.num_heads, Q.size(1), 1)  # (bs*num_head, 1, seq)

        score = Q_.bmm(K_.transpose(1, 2)) / math.sqrt(self.dim_V)
        score = score.masked_fill(pad_mask == 0, -1e12)
        A = torch.softmax(score, 2)  # (bs*num_head, 1, seq)
        A = A * pad_mask
        O = torch.cat(A.bmm(V_).split(Q.size(0), 0), 2)  # (bs, 1, emb)
        O = Q + O
        # O = torch.cat((Q_ + A.bmm(V_)).split(Q.size(0), 0), 2)
        O = O if getattr(self, 'ln0', None) is None else self.ln0(O)
        O = O + F.relu(self.fc_o(O))
        O = O if getattr(self, 'ln1', None) is None else self.ln1(O)
        return O


# PMA building block, pre-norm (Pre-LN) variant
class MAB_PRE_NORMAL(nn.Module):
    def __init__(self, dim_Q, dim_K, dim_V, num_heads, ln=False):
        super(MAB_PRE_NORMAL, self).__init__()
        self.dim_V = dim_V
        self.num_heads = num_heads
        self.fc_q = nn.Linear(dim_Q, dim_V)
        self.fc_k = nn.Linear(dim_K, dim_V)
        self.fc_v = nn.Linear(dim_K, dim_V)
        if ln:
            self.ln_q = nn.LayerNorm(dim_V)
            self.ln_kv = nn.LayerNorm(dim_V)
            self.ln_o = nn.LayerNorm(dim_V)
            self.ln_final = nn.LayerNorm(dim_V)
        self.fc_o = nn.Linear(dim_V, dim_V)
        nn.init.xavier_uniform_(self.fc_q.weight)
        nn.init.xavier_uniform_(self.fc_k.weight)
        nn.init.xavier_uniform_(self.fc_v.weight)
        nn.init.xavier_uniform_(self.fc_o.weight)

    # pad_mask: (bs, seq), standard Pre-LN architecture
    def forward(self, Q, K, pad_mask=None):
        Q_ = Q if getattr(self, 'ln_q', None) is None else self.ln_q(Q)
        K_ = K if getattr(self, 'ln_kv', None) is None else self.ln_kv(K)
        Q_ = self.fc_q(Q_)
        K_, V_ = self.fc_k(K_), self.fc_v(K_)

        dim_split = self.dim_V // self.num_heads
        Q_ = torch.cat(Q_.split(dim_split, 2), 0)  # (bs*num_head, 1, emb)
        K_ = torch.cat(K_.split(dim_split, 2), 0)  # (bs*num_head, seq, emb)
        V_ = torch.cat(V_.split(dim_split, 2), 0)  # (bs*num_head, seq, emb)
        pad_mask = pad_mask.unsqueeze(1).repeat(self.num_heads, Q.size(1), 1)  # (bs*num_head, 1, seq)

        score = Q_.bmm(K_.transpose(1, 2)) / math.sqrt(self.dim_V)
        score = score.masked_fill(pad_mask == 0, -1e12)
        A = torch.softmax(score, 2)  # (bs*num_head, 1, seq)
        A = A * pad_mask
        O = torch.cat(A.bmm(V_).split(Q.size(0), 0), 2)
        O = Q + O
        O_ = O if getattr(self, 'ln_o', None) is None else self.ln_o(O)  # LayerNorm branch of O
        O_ = O + F.relu(self.fc_o(O_))
        return O_ if getattr(self, 'ln_final', None) is None else self.ln_final(O_)


# PMA building block, GPT-J-style pre-norm (parallel attention and feed-forward branches)
class MAB_PRE_GPTJ(nn.Module):
    def __init__(self, dim_Q, dim_K, dim_V, num_heads, ln=False):
        super(MAB_PRE_GPTJ, self).__init__()
        self.dim_V = dim_V
        self.num_heads = num_heads
        self.fc_q = nn.Linear(dim_Q, dim_V)
        self.fc_k = nn.Linear(dim_K, dim_V)
        self.fc_v = nn.Linear(dim_K, dim_V)
        self.fc_o = nn.Linear(dim_V, dim_V)
        nn.init.xavier_uniform_(self.fc_q.weight)
        nn.init.xavier_uniform_(self.fc_k.weight)
        nn.init.xavier_uniform_(self.fc_v.weight)
        nn.init.xavier_uniform_(self.fc_o.weight)
        if ln:
            self.ln_q = nn.LayerNorm(dim_V)
            self.ln_kv = nn.LayerNorm(dim_V)
            self.ln_final = nn.LayerNorm(dim_V)

    # pad_mask: (bs, seq)
    def forward(self, Q, K, pad_mask=None):
        # layernorm
        Q_ = Q if getattr(self, 'ln_q', None) is None else self.ln_q(Q)
        K_ = K if getattr(self, 'ln_kv', None) is None else self.ln_kv(K)
        Q1 = self.fc_q(Q_)
        K1, V1 = self.fc_k(K_), self.fc_v(K_)

        dim_split = self.dim_V // self.num_heads
        Q1 = torch.cat(Q1.split(dim_split, 2), 0)  # (bs*num_head, 1, emb)
        K1 = torch.cat(K1.split(dim_split, 2), 0)  # (bs*num_head, seq, emb)
        V1 = torch.cat(V1.split(dim_split, 2), 0)  # (bs*num_head, seq, emb)
        pad_mask = pad_mask.unsqueeze(1).repeat(self.num_heads, Q.size(1), 1)  # (bs*num_head, 1, seq)

        score = Q1.bmm(K1.transpose(1, 2)) / math.sqrt(self.dim_V)
        score = score.masked_fill(pad_mask == 0, -1e12)
        A = torch.softmax(score, 2)  # (bs*num_head, 1, seq)
        A = A * pad_mask
        O1 = torch.cat(A.bmm(V1).split(Q.size(0), 0), 2)  # (bs, 1, emb)
        O2 = F.relu(self.fc_o(Q_))  # (bs, 1, emb)
        O_final = Q + O1 + O2
        return O_final if getattr(self, 'ln_final', None) is None else self.ln_final(O_final)
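# Minimal, illustrative shape check for MAB_POST (a sketch with toy dimensions of my own
# choosing; it is not called anywhere by the model code). With a right-padded batch and a
# length-1 query, the block pools a (bs, seq, dim) tensor into a single (bs, 1, dim) vector.
def _demo_mab_post():
    bs, seq, dim, heads = 2, 5, 8, 2
    mab = MAB_POST(dim, dim, dim, heads, ln=True)
    Q = torch.randn(bs, 1, dim)                     # one seed/query vector per example
    K = torch.randn(bs, seq, dim)                   # token embeddings
    pad_mask = torch.tensor([[1, 1, 1, 0, 0],
                             [1, 1, 1, 1, 1]])      # 1 = real token, 0 = padding
    out = mab(Q, K, pad_mask)
    assert out.shape == (bs, 1, dim)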
class PMA(nn.Module):
    def __init__(self, dim, num_heads, num_seeds, ln=False, pma_mode=None):
        super(PMA, self).__init__()
        self.S = nn.Parameter(torch.Tensor(1, num_seeds, dim))
        nn.init.xavier_uniform_(self.S)
        if pma_mode == 'post_normal':
            self.mab = MAB_POST(dim, dim, dim, num_heads, ln=ln)
        elif pma_mode == 'pre_normal':
            self.mab = MAB_PRE_NORMAL(dim, dim, dim, num_heads, ln=ln)
        elif pma_mode == 'pre_gptj':
            self.mab = MAB_PRE_GPTJ(dim, dim, dim, num_heads, ln=ln)
        else:
            raise ValueError(f"Error, the pma_mode {pma_mode} is not implemented!")

    # X: (bs, seq, emb), pad_mask: (bs, seq)
    def forward(self, X, pad_mask):
        if self.S.dtype != torch.bfloat16:
            X = X.float()
        return self.mab(self.S.repeat(X.size(0), 1, 1), X, pad_mask)


# Plain bidirectional transformer encoder layer, post-norm
class EncoderLayer_POST(nn.Module):
    def __init__(self, dim_V, num_heads, ln=False):
        super(EncoderLayer_POST, self).__init__()
        self.dim_V = dim_V
        self.num_heads = num_heads
        self.fc_q = nn.Linear(dim_V, dim_V)
        self.fc_k = nn.Linear(dim_V, dim_V)
        self.fc_v = nn.Linear(dim_V, dim_V)
        self.fc_o = nn.Linear(dim_V, dim_V)
        nn.init.xavier_uniform_(self.fc_q.weight)
        nn.init.xavier_uniform_(self.fc_k.weight)
        nn.init.xavier_uniform_(self.fc_v.weight)
        nn.init.xavier_uniform_(self.fc_o.weight)
        if ln:
            self.ln0 = nn.LayerNorm(dim_V)
            self.ln1 = nn.LayerNorm(dim_V)

    # Q: (bs, seq, emb), pad_mask: (bs, seq)
    def forward(self, Q, pad_mask=None):
        Q_, K_, V_ = self.fc_q(Q), self.fc_k(Q), self.fc_v(Q)

        dim_split = self.dim_V // self.num_heads
        Q_ = torch.cat(Q_.split(dim_split, 2), 0)  # (bs*num_head, seq, emb)
        K_ = torch.cat(K_.split(dim_split, 2), 0)  # (bs*num_head, seq, emb)
        V_ = torch.cat(V_.split(dim_split, 2), 0)  # (bs*num_head, seq, emb)
        pad_mask = pad_mask.unsqueeze(1).repeat(self.num_heads, Q.size(1), 1)  # (bs*num_head, seq, seq)

        score = Q_.bmm(K_.transpose(1, 2)) / math.sqrt(self.dim_V)
        score = score.masked_fill(pad_mask == 0, -1e12)
        A = torch.softmax(score, 2)  # (bs*num_head, seq, seq)
        A = A * pad_mask  # (bs*num_head, seq, seq)
        O = torch.cat(A.bmm(V_).split(Q.size(0), 0), 2)  # (bs, seq, emb)
        O = Q + O
        O = O if getattr(self, 'ln0', None) is None else self.ln0(O)
        O = O + F.relu(self.fc_o(O))
        O = O if getattr(self, 'ln1', None) is None else self.ln1(O)
        return O


# Plain bidirectional transformer encoder layer, standard pre-norm
class EncoderLayer_PRE_NORMAL(nn.Module):
    def __init__(self, dim_V, num_heads, ln=False):
        super(EncoderLayer_PRE_NORMAL, self).__init__()
        self.dim_V = dim_V
        self.num_heads = num_heads
        self.fc_q = nn.Linear(dim_V, dim_V)
        self.fc_k = nn.Linear(dim_V, dim_V)
        self.fc_v = nn.Linear(dim_V, dim_V)
        self.fc_o = nn.Linear(dim_V, dim_V)
        nn.init.xavier_uniform_(self.fc_q.weight)
        nn.init.xavier_uniform_(self.fc_k.weight)
        nn.init.xavier_uniform_(self.fc_v.weight)
        nn.init.xavier_uniform_(self.fc_o.weight)
        if ln:
            self.ln_qkv = nn.LayerNorm(dim_V)
            self.ln_o = nn.LayerNorm(dim_V)

    # Q: (bs, seq, emb), pad_mask: (bs, seq)
    def forward(self, Q, pad_mask=None):
        Q_ = Q if getattr(self, 'ln_qkv', None) is None else self.ln_qkv(Q)  # layernorm
        Q_, K_, V_ = self.fc_q(Q_), self.fc_k(Q_), self.fc_v(Q_)

        dim_split = self.dim_V // self.num_heads
        Q_ = torch.cat(Q_.split(dim_split, 2), 0)  # (bs*num_head, seq, emb)
        K_ = torch.cat(K_.split(dim_split, 2), 0)  # (bs*num_head, seq, emb)
        V_ = torch.cat(V_.split(dim_split, 2), 0)  # (bs*num_head, seq, emb)
        pad_mask = pad_mask.unsqueeze(1).repeat(self.num_heads, Q.size(1), 1)  # (bs*num_head, seq, seq)

        score = Q_.bmm(K_.transpose(1, 2)) / math.sqrt(self.dim_V)
        score = score.masked_fill(pad_mask == 0, -1e12)
        A = torch.softmax(score, 2)  # (bs*num_head, seq, seq)
        A = A * pad_mask
        O = torch.cat(A.bmm(V_).split(Q.size(0), 0), 2)
        O = Q + O
        O_ = O if getattr(self, 'ln_o', None) is None else self.ln_o(O)  # LayerNorm branch of O
        O_ = O + F.relu(self.fc_o(O_))
        return O_


# Plain bidirectional transformer encoder layer, GPT-J-style pre-norm
class EncoderLayer_PRE_GPTJ(nn.Module):
    def __init__(self, dim_V, num_heads, ln=False):
        super(EncoderLayer_PRE_GPTJ, self).__init__()
        self.dim_V = dim_V
        self.num_heads = num_heads
        self.fc_q = nn.Linear(dim_V, dim_V)
        self.fc_k = nn.Linear(dim_V, dim_V)
        self.fc_v = nn.Linear(dim_V, dim_V)
        self.fc_o = nn.Linear(dim_V, dim_V)
        nn.init.xavier_uniform_(self.fc_q.weight)
        nn.init.xavier_uniform_(self.fc_k.weight)
        nn.init.xavier_uniform_(self.fc_v.weight)
        nn.init.xavier_uniform_(self.fc_o.weight)
        if ln:
            self.ln_qkv = nn.LayerNorm(dim_V)

    # Q: (bs, seq, emb), pad_mask: (bs, seq)
    def forward(self, Q, pad_mask=None):
        Q_ = Q if getattr(self, 'ln_qkv', None) is None else self.ln_qkv(Q)  # layernorm
        Q1, K1, V1 = self.fc_q(Q_), self.fc_k(Q_), self.fc_v(Q_)

        dim_split = self.dim_V // self.num_heads
        Q1 = torch.cat(Q1.split(dim_split, 2), 0)  # (bs*num_head, seq, emb)
        K1 = torch.cat(K1.split(dim_split, 2), 0)  # (bs*num_head, seq, emb)
        V1 = torch.cat(V1.split(dim_split, 2), 0)  # (bs*num_head, seq, emb)
        pad_mask = pad_mask.unsqueeze(1).repeat(self.num_heads, Q.size(1), 1)  # (bs*num_head, seq, seq)

        score = Q1.bmm(K1.transpose(1, 2)) / math.sqrt(self.dim_V)
        score = score.masked_fill(pad_mask == 0, -1e12)
        A = torch.softmax(score, 2)  # (bs*num_head, seq, seq)
        A = A * pad_mask
        O1 = torch.cat(A.bmm(V1).split(Q.size(0), 0), 2)  # (bs, seq, emb)
        O2 = F.relu(self.fc_o(Q_))
        O_final = Q + O1 + O2
        return O_final


class Encoder(nn.Module):
    def __init__(self, emb_dim, num_heads, ln, encoder_mode, num_encoder_layers):
        super(Encoder, self).__init__()
        self.num_encoder_layers = num_encoder_layers
        if encoder_mode == 'post_normal':
            self.layers = nn.ModuleList([EncoderLayer_POST(dim_V=emb_dim, num_heads=num_heads, ln=ln) for _ in range(num_encoder_layers)])
        elif encoder_mode == 'pre_normal':
            self.layers = nn.ModuleList([EncoderLayer_PRE_NORMAL(dim_V=emb_dim, num_heads=num_heads, ln=ln) for _ in range(num_encoder_layers)])
        elif encoder_mode == 'pre_gptj':
            self.layers = nn.ModuleList([EncoderLayer_PRE_GPTJ(dim_V=emb_dim, num_heads=num_heads, ln=ln) for _ in range(num_encoder_layers)])
        else:
            raise ValueError(f"Error, the encoder_mode {encoder_mode} is not implemented!")

    # X: (bs, seq, emb), mask: (bs, seq)
    def forward(self, X, mask):
        if self.num_encoder_layers == 0:
            return X
        if self.layers[0].fc_q.weight.dtype != torch.bfloat16:
            X = X.float()
        for layer in self.layers:
            X = layer(X, mask)
        return X
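# Illustrative pooling sketch with toy dimensions of my own choosing (not called by the
# model code): run a padded batch of token states through the bidirectional Encoder, then
# pool with PMA's learned seed vector, going (bs, seq, emb) -> (bs, 1, emb) -> (bs, emb).
def _demo_encoder_pma_pooling():
    bs, seq, emb, heads = 2, 6, 16, 4
    hidden = torch.randn(bs, seq, emb)
    mask = torch.ones(bs, seq, dtype=torch.long)
    mask[0, 4:] = 0  # first example has two padding positions

    encoder = Encoder(emb_dim=emb, num_heads=heads, ln=True,
                      encoder_mode='pre_gptj', num_encoder_layers=2)
    pma = PMA(dim=emb, num_heads=heads, num_seeds=1, ln=True, pma_mode='pre_gptj')

    refined = encoder(hidden, mask)             # (bs, seq, emb)
    pooled = pma(refined, mask).squeeze(1)      # (bs, emb)
    assert pooled.shape == (bs, emb)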
class D2LLMConfig(PretrainedConfig):
    model_type = "qwen2"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=151936,
        hidden_size=4096,
        intermediate_size=22016,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=32,
        hidden_act="silu",
        max_position_embeddings=32768,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        use_sliding_window=False,
        sliding_window=4096,
        max_window_layers=28,
        attention_dropout=0.0,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.use_sliding_window = use_sliding_window
        self.sliding_window = sliding_window if use_sliding_window else None
        self.max_window_layers = max_window_layers

        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads

        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.attention_dropout = attention_dropout
        super().__init__(
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
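# Illustrative config sketch (not called by the model code). D2Coder below reads several
# pooling-related fields (embedding_method, inf_seq_length, encoder_mode, num_encoder_layers,
# padding_side, keep_max_layer, pma_num_heads, pma_ln, pma_norm, pma_norm_mode) that
# D2LLMConfig.__init__ does not declare explicitly; here they are assumed to pass through
# PretrainedConfig's **kwargs. The tiny sizes and all values are demonstration placeholders.
def _demo_build_config():
    return D2LLMConfig(
        vocab_size=1024,
        hidden_size=64,
        intermediate_size=128,
        num_hidden_layers=2,
        num_attention_heads=4,
        num_key_value_heads=4,
        max_position_embeddings=128,
        # fields consumed by D2Coder (assumed to be carried via **kwargs)
        embedding_method='pma',
        inf_seq_length=128,
        encoder_mode='pre_gptj',
        num_encoder_layers=1,
        padding_side='right',
        keep_max_layer=-1,
        pma_num_heads=4,
        pma_ln=True,
        pma_norm=False,
        pma_norm_mode='pre_gptj',
    )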
class D2Coder(PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.plm_model = Qwen2ForCausalLM(config)
        self.embedding_method = config.embedding_method
        self.inf_seq_length = config.inf_seq_length
        self.encoder_mode = config.encoder_mode
        self.num_encoder_layers = config.num_encoder_layers
        self.padding_side = config.padding_side
        self.keep_max_layer = config.keep_max_layer
        self.emb_dim = self.plm_model.model.embed_tokens.weight.size(1)
        self.num_heads = config.pma_num_heads
        self.ln = config.pma_ln
        self.norm = config.pma_norm
        self.pma_mode = config.pma_norm_mode
        self.encoder = Encoder(self.emb_dim, self.num_heads, self.ln, self.encoder_mode, self.num_encoder_layers)
        self.mha_pma = PMA(self.emb_dim, self.num_heads, 1, ln=self.ln, pma_mode=self.pma_mode)

    def forward(self, inputs_all, mode, args):
        # output_embeddings_a = self.get_sentence_embedding(self.embedding_method, **inputs_a)
        # output_embeddings_b = self.get_sentence_embedding(self.embedding_method, **inputs_b)  # (bs, emb_size)
        bs = args.batch_size
        if mode == 'train':
            output_embeddings_all = self.get_sentence_embedding(self.embedding_method, **inputs_all).reshape(2 + args.neg_K, bs, -1)  # (2+K, bs, emb_size)
            # if self.to_compress:
            #     output_embeddings_all = self.projector(output_embeddings_all)
            output_embeddings_hardneg = output_embeddings_all[2:]  # (neg_K, bs, emb)
            hn_norm = torch.nn.functional.normalize(output_embeddings_hardneg, p=2, dim=-1)
        elif mode == 'eval':
            output_embeddings_all = self.get_sentence_embedding(self.embedding_method, **inputs_all).reshape(2, bs, -1)  # (2, bs, emb_size)
            # if self.to_compress:
            #     output_embeddings_all = self.projector(output_embeddings_all)
        else:
            raise ValueError(f"Error, the mode {mode} is not supported!")

        output_embeddings_a = output_embeddings_all[0]  # (bs, emb)
        output_embeddings_b = output_embeddings_all[1]  # (bs, emb)
        a_norm = torch.nn.functional.normalize(output_embeddings_a, p=2, dim=-1)
        b_norm = torch.nn.functional.normalize(output_embeddings_b, p=2, dim=-1)
        # NOTE: gather_across_devices is assumed to be provided elsewhere in the project,
        # and self.world_size is assumed to be set externally before training.
        b_cross_gpus = gather_across_devices(output_embeddings_b, args.global_rank, self.world_size)
        b_norm_cross_gpus = torch.nn.functional.normalize(b_cross_gpus, p=2, dim=-1)
        assert a_norm.size(0) == b_norm.size(0)
        bs = output_embeddings_a.size(0)

        # in-batch similarity
        output_in_batch_local_gpu = torch.matmul(a_norm, b_norm.t())
        output_in_batch_global_gpu = torch.matmul(a_norm, b_norm_cross_gpus.t())

        if mode == 'train':
            # hard-negative similarity
            pos_neg_emb = torch.cat([b_norm.unsqueeze(0), hn_norm], dim=0)  # (1+neg_K, bs, emb)
            output_hardneg_specific_task = torch.matmul(a_norm.unsqueeze(1), pos_neg_emb.permute(1, 2, 0)).squeeze()  # (bs, 1+neg_K)
            # output_pos_hardneg_rep_specific_task = torch.cat([output_embeddings_a.unsqueeze(0).expand(pos_neg_emb.size(0), -1, -1), pos_neg_emb], dim=-1)
        elif mode == 'eval':
            output_hardneg_specific_task = None
            output_pos_hardneg_rep_specific_task = None

        # shapes: (bs, bs), (bs, world_size*bs), (bs, 1+neg_K)
        return output_in_batch_local_gpu, output_in_batch_global_gpu, output_hardneg_specific_task
        # return output_in_batch_specific_task, output_hardneg_specific_task, output_pos_hardneg_rep_specific_task
    def last_embedding(self, A, index):
        bs, seq, emb = A.size()
        res = A[torch.arange(bs), index, :]
        return res

    def mean_embedding(self, A, mask):
        bs, seq, emb = A.size()
        res = (A * (mask.unsqueeze(-1))).sum(1) / (mask.sum(1).unsqueeze(-1))
        return res

    # A: (bs, seq, emb_size), mask: (bs, 1, seq)
    def weighted_embedding(self, A, mask):
        weights = (torch.arange(start=1, end=A.size(1) + 1).unsqueeze(0).unsqueeze(-1).expand(A.size()).float()).to(A.device)
        input_mask_expanded = (mask.squeeze(1).unsqueeze(-1).expand(A.size()).float()).to(A.device)
        sum_embedding = torch.sum(A * input_mask_expanded * weights, dim=1)
        sum_mask = torch.sum(input_mask_expanded * weights, dim=1)
        weighted_embedding = sum_embedding / sum_mask
        return weighted_embedding

    def pma_embedding(self, A, mask):
        res = self.mha_pma(A, mask).squeeze(1)
        return res

    def get_sentence_embedding(self, embedding_method, **inputs):
        outputs = self.plm_model(inputs['input_ids'], inputs['attention_mask'], output_hidden_states=True)
        if embedding_method == 'last':
            embedding = outputs.hidden_states[self.keep_max_layer]
            index = inputs['attention_mask'].sum(-1).long() - 1
            res_embedding = self.last_embedding(embedding, index)
        elif embedding_method == 'mean':
            embedding = outputs.hidden_states[self.keep_max_layer]
            res_embedding = self.mean_embedding(embedding, inputs['attention_mask'])
        elif embedding_method == 'weighted':
            embedding = outputs.hidden_states[self.keep_max_layer]
            res_embedding = self.weighted_embedding(embedding, inputs['attention_mask'])
        elif embedding_method == 'pma':
            embedding = outputs.hidden_states[self.keep_max_layer]  # Qwen hidden_states: (33, bs, seq, emb)
            attention_mask = inputs['attention_mask']
            embedding = self.encoder(embedding, attention_mask)
            res_embedding = self.pma_embedding(embedding, attention_mask)  # embedding: (bs, seq, emb), attention_mask: (bs, seq)
        else:
            logger.debug('Error, no {} way to obtain embeddings'.format(embedding_method))

        if not self.norm:
            res_embedding = torch.nn.functional.normalize(res_embedding, p=2.0, dim=-1, eps=1e-12, out=None)
        return res_embedding
    def encode(self, tokenizer, sentences, batch_size=32, convert_to_numpy=True,
               convert_to_tensor=False, show_progress_bar=True, max_seq_length=None, **kwargs):
        if max_seq_length is None:
            max_seq_length = self.inf_seq_length
        input_is_string = False
        if isinstance(sentences, str) or not hasattr(sentences, "__len__"):
            sentences = [sentences]
            input_is_string = True

        all_embeddings = []
        length_sorted_idx = np.argsort([-len(s) for s in sentences])
        sentences_sorted = [sentences[idx] for idx in length_sorted_idx]  # sort from longest to shortest
        with torch.no_grad():
            for start_index in trange(0, len(sentences), batch_size, desc="Batches", disable=not show_progress_bar):
                sentences_batch = sentences_sorted[start_index: start_index + batch_size]
                # Compute sentence embeddings
                inputs = tokenizer(sentences_batch, padding=True, truncation=True, max_length=max_seq_length,
                                   add_special_tokens=False, return_tensors='pt').to(self.plm_model.device)
                embeddings = self.get_sentence_embedding(self.embedding_method, **inputs)
                # if self.to_compress:
                #     embeddings = self.projector(embeddings)
                embeddings = embeddings.detach()
                if convert_to_numpy:
                    if embeddings.dtype == torch.bfloat16:
                        embeddings = embeddings.cpu().to(torch.float32)
                    else:
                        embeddings = embeddings.cpu()
                all_embeddings.extend(embeddings)

        all_embeddings = [all_embeddings[idx] for idx in np.argsort(length_sorted_idx)]
        if convert_to_tensor:
            all_embeddings = torch.stack(all_embeddings)
        elif convert_to_numpy:
            all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])

        if input_is_string:
            all_embeddings = all_embeddings[0]
        return all_embeddings
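# Illustrative usage sketch. "path/to/d2coder_checkpoint" is a placeholder for a directory
# assumed to contain a saved D2Coder checkpoint (config.json including the PMA/encoder
# fields, model weights, and tokenizer files); it is not a real artifact of this project.
if __name__ == "__main__":
    model_path = "path/to/d2coder_checkpoint"  # placeholder path
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    config = D2LLMConfig.from_pretrained(model_path)
    model = D2Coder.from_pretrained(model_path, config=config).eval()

    sentences = ["def add(a, b): return a + b", "print('hello world')"]
    # encode() returns a numpy array of shape (len(sentences), emb_dim) by default
    embeddings = model.encode(tokenizer, sentences, batch_size=2, show_progress_bar=False)
    print(embeddings.shape)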