|
import inspect |
|
import math |
|
import os |
|
import warnings |
|
from typing import List, Optional, Tuple, Union |
|
|
|
import torch |
|
import torch.nn.functional as F |
|
import torch.utils.checkpoint |
|
from torch import nn |
|
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss |
|
from transformers import PretrainedConfig |
|
|
|
from transformers.activations import ACT2FN |
|
from transformers.cache_utils import Cache, DynamicCache |
|
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa, _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa |
|
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast |
|
from transformers.modeling_utils import PreTrainedModel |
|
from transformers.utils import ( |
|
add_start_docstrings, |
|
add_start_docstrings_to_model_forward, |
|
is_flash_attn_2_available, |
|
is_flash_attn_greater_or_equal_2_10, |
|
logging, |
|
replace_return_docstrings, |
|
) |
|
import numpy as np |
|
from transformers import Phi3ForCausalLM |
|
import inspect |
|
import math |
|
import os |
|
import warnings |
|
from typing import List, Optional, Tuple, Union |
|
from tqdm import tqdm, trange |
|
import torch |
|
import torch.nn.functional as F |
|
import torch.utils.checkpoint |
|
from torch import nn |
|
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss |
|
|
|
from transformers.activations import ACT2FN |
|
from transformers.cache_utils import Cache, DynamicCache |
|
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa, _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa |
|
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast |
|
from transformers.modeling_utils import PreTrainedModel |
|
from transformers.utils import ( |
|
add_start_docstrings, |
|
add_start_docstrings_to_model_forward, |
|
is_flash_attn_2_available, |
|
is_flash_attn_greater_or_equal_2_10, |
|
logging, |
|
replace_return_docstrings, |
|
) |
|
import numpy as np |
|
import torch |
|
import os |
|
import argparse |
|
import json |
|
from tqdm import tqdm |
|
from typing import cast, List, Union, Tuple |
|
from transformers import AutoTokenizer, AutoModel |
|
from peft import LoraConfig, get_peft_model, TaskType |
|
import time |
|
import torch.nn.functional as F |
|
import sys |
|
import time |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
import numpy as np |
|
from tqdm import tqdm, trange |
|
from collections import defaultdict |
|
from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM, AutoConfig |
|
import torch.distributed as dist |
|
from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint |
|
import sys |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
import math |
|
import re |
|
from CodeFuseCGESmallConfig import CodeFuseCGESmallConfig |
|
|
|
|
|
class MAB_POST(nn.Module):
    """Multi-head attention block with post-residual normalisation.

    Queries are projected from ``Q`` and keys/values from ``K``. Heads are formed
    by splitting the projected feature axis into ``num_heads`` chunks and stacking
    them along the batch axis. The output is computed as follows:

    * ``O = Q + Attn(Q, K)``, optionally followed by LayerNorm;
    * ``O = O + ReLU(fc_o(O))``, optionally followed by LayerNorm.

    Args:
        dim_Q: feature size of the query input.
        dim_K: feature size of the key/value input.
        dim_V: projected feature size, which is also the output size. It is
            split evenly across heads.
        num_heads: number of attention heads.
        ln: when true, apply LayerNorm after each residual step.
    """

    def __init__(self, dim_Q, dim_K, dim_V, num_heads, ln=False):
        super(MAB_POST, self).__init__()
        self.dim_V = dim_V
        self.num_heads = num_heads
        self.fc_q = nn.Linear(dim_Q, dim_V)
        self.fc_k = nn.Linear(dim_K, dim_V)
        self.fc_v = nn.Linear(dim_K, dim_V)
        if ln:
            self.ln0 = nn.LayerNorm(dim_V)
            self.ln1 = nn.LayerNorm(dim_V)
        self.fc_o = nn.Linear(dim_V, dim_V)
        nn.init.xavier_uniform_(self.fc_q.weight)
        nn.init.xavier_uniform_(self.fc_k.weight)
        nn.init.xavier_uniform_(self.fc_v.weight)
        nn.init.xavier_uniform_(self.fc_o.weight)

    def forward(self, Q, K, pad_mask=None):
        """Attend from ``Q`` over ``K``.

        Args:
            Q: 3-D query tensor ``(batch, n_q, dim_Q)``. The tensor must be 3-D
                because of ``bmm``.
            K: 3-D key/value tensor ``(batch, n_k, dim_K)``.
            pad_mask: ``(batch, n_k)`` mask in which 0 marks a padded key
                position. ``None`` treats every key position as valid.

        Returns:
            Tensor of shape ``(batch, n_q, dim_V)``.
        """
        if pad_mask is None:
            # Treat every key position as valid. The old code crashed on None.unsqueeze.
            pad_mask = torch.ones(K.shape[:2], dtype=torch.long, device=K.device)

        Q_ = self.fc_q(Q)
        K_, V_ = self.fc_k(K), self.fc_v(K)

        # Heads are stacked along the batch axis in head-major order:
        # index h * batch + b holds head h of sample b.
        dim_split = self.dim_V // self.num_heads
        Q_ = torch.cat(Q_.split(dim_split, 2), 0)
        K_ = torch.cat(K_.split(dim_split, 2), 0)
        V_ = torch.cat(V_.split(dim_split, 2), 0)

        # (batch, n_k) -> (num_heads * batch, n_q, n_k), matching the layout above.
        pad_mask = pad_mask.unsqueeze(1).repeat(self.num_heads, Q.size(1), 1)
        # NOTE: scales by sqrt(dim_V) rather than sqrt(head dim). This is kept
        # as-is, because changing it would alter outputs of existing checkpoints.
        score = Q_.bmm(K_.transpose(1, 2)) / math.sqrt(self.dim_V)
        # Use the dtype's most negative finite value instead of a fixed -1e12.
        # -1e12 overflows float16, and masked_fill rejects it.
        score = score.masked_fill(pad_mask == 0, torch.finfo(score.dtype).min)
        A = torch.softmax(score, 2)
        # Zero the padded weights explicitly. A fully padded row would otherwise
        # keep its uniform softmax weights.
        A = A * pad_mask
        # Undo the head stacking: split back per head and concatenate on features.
        O = torch.cat(A.bmm(V_).split(Q.size(0), 0), 2)
        O = Q + O
        O = O if getattr(self, 'ln0', None) is None else self.ln0(O)
        O = O + F.relu(self.fc_o(O))
        O = O if getattr(self, 'ln1', None) is None else self.ln1(O)
        return O
|
|
|
|
|
class PMA(nn.Module):
    """Pooling by Multihead Attention: learned seed vectors attend over an input set.

    Args:
        dim: feature size of the input set ``X``.
        compress_dim: feature size of the learned seeds and of the inner block's
            value projection.
        num_heads: attention heads used by the inner block.
        num_seeds: number of learned seed vectors.
        ln: forwarded to the inner attention block.
        pma_mode: selects the inner attention block. One of ``'post_normal'``,
            ``'pre_normal'`` or ``'pre_gptj'``.

    Raises:
        ValueError: if ``pma_mode`` is not one of the supported values.
    """

    def __init__(self, dim, compress_dim, num_heads, num_seeds, ln=False, pma_mode=None):
        super(PMA, self).__init__()
        self.S = nn.Parameter(torch.Tensor(1, num_seeds, compress_dim))
        nn.init.xavier_uniform_(self.S)
        if pma_mode == 'post_normal':
            self.mab = MAB_POST(compress_dim, dim, compress_dim, num_heads, ln=ln)
        elif pma_mode == 'pre_normal':
            # NOTE(review): MAB_PRE_NORMAL is not defined in the visible source;
            # confirm it exists elsewhere, otherwise this branch raises NameError.
            self.mab = MAB_PRE_NORMAL(compress_dim, dim, compress_dim, num_heads, ln=ln)
        elif pma_mode == 'pre_gptj':
            # NOTE(review): MAB_PRE_GPTJ is not defined in the visible source either.
            self.mab = MAB_PRE_GPTJ(compress_dim, dim, compress_dim, num_heads, ln=ln)
        else:
            raise ValueError(f"Error, the pma_mode {pma_mode} is not implemented !")

    def forward(self, X, pad_mask):
        """Pool ``X`` with the learned seeds.

        Args:
            X: input set, batch-first (``X.size(0)`` is the batch size).
            pad_mask: key padding mask forwarded unchanged to the inner block.

        Returns:
            The inner block's output for queries ``S`` repeated across the batch.
        """
        # Match the input to the seed/parameter dtype so the inner Linear layers
        # see a single dtype. The old check only special-cased bfloat16 and cast
        # everything else to float32, which mismatched float16/float64 parameters.
        X = X.to(self.S.dtype)
        return self.mab(self.S.repeat(X.size(0), 1, 1), X, pad_mask)
|
|
|
|
|
class CodeFuse_CGE_Small(PreTrainedModel):
    """Sentence-embedding model built on a wrapped ``Phi3ForCausalLM``.

    The hidden states of layer ``config.keep_max_layer`` are pooled into one
    vector per input. The pooling method is ``config.embedding_method``: one of
    ``'last'``, ``'mean'``, ``'weighted'`` or ``'pma'``.
    """

    config_class = CodeFuseCGESmallConfig

    def __init__(self, config):
        super().__init__(config)
        self.plm_model = Phi3ForCausalLM(config)
        self.embedding_method = config.embedding_method
        self.inf_seq_length = config.inf_seq_length
        # NOTE(review): stored but not read by this class. The tokenizer passed
        # to ``encode`` decides the actual padding side.
        self.padding_side = config.padding_side

        self.keep_max_layer = config.keep_max_layer
        self.emb_dim = self.plm_model.model.embed_tokens.weight.size(1)
        self.num_heads = config.pma_num_heads
        self.ln = config.pma_ln
        self.norm = config.pma_norm
        self.compress_dim = config.compress_dim
        self.pma_mode = config.pma_norm_mode
        # Attention-pooling head with a single learned seed (num_seeds=1).
        self.mha_pma = PMA(self.emb_dim, self.compress_dim, self.num_heads, 1, ln=self.ln, pma_mode=self.pma_mode)

    def last_embedding(self, A, index):
        """Return ``A[b, index[b], :]`` for every batch row ``b``.

        Args:
            A: ``(batch, seq, emb)`` hidden states.
            index: ``(batch,)`` integer positions to gather.
        """
        bs, seq, emb = A.size()
        # Build the row index on A's device so both index tensors live together.
        batch_rows = torch.arange(bs, device=A.device)
        return A[batch_rows, index, :]

    def mean_embedding(self, A, mask):
        """Average ``A`` over the sequence axis, counting only positions where ``mask`` is nonzero.

        Args:
            A: ``(batch, seq, emb)`` hidden states.
            mask: ``(batch, seq)`` attention mask.
        """
        bs, seq, emb = A.size()
        summed = (A * mask.unsqueeze(-1)).sum(1)
        # clamp guards rows with no real tokens, which previously divided 0/0 -> NaN.
        counts = mask.sum(1).unsqueeze(-1).clamp(min=1)
        return summed / counts

    def weighted_embedding(self, A, mask):
        """Position-weighted mean over the sequence axis.

        Position ``i`` (0-based) gets weight ``i + 1``, so later tokens count more.
        Padded positions (mask == 0) get weight 0.

        Args:
            A: ``(batch, seq, emb)`` hidden states.
            mask: attention mask with ``batch * seq`` elements, either
                ``(batch, seq)`` or ``(batch, 1, seq)``.

        Returns:
            ``(batch, emb)`` float32 tensor.
        """
        bs, seq_len = A.size(0), A.size(1)
        # reshape instead of squeeze(1): squeeze(1) collapsed a (batch, 1) mask
        # when seq_len == 1, which then failed to broadcast for batch > 1.
        mask = mask.reshape(bs, seq_len).to(device=A.device, dtype=torch.float32)
        positions = torch.arange(1, seq_len + 1, device=A.device, dtype=torch.float32)
        weights = (mask * positions).unsqueeze(-1)  # (batch, seq, 1)
        sum_embedding = torch.sum(A * weights, dim=1)
        # Weights are >= 1 wherever mask is set; clamp only affects fully padded rows.
        sum_mask = torch.sum(weights, dim=1).clamp(min=1.0)
        return sum_embedding / sum_mask

    def pma_embedding(self, A, mask):
        """Pool ``A`` with the attention-pooling head and drop the seed axis."""
        return self.mha_pma(A, mask).squeeze(1)

    @staticmethod
    def _last_token_index(attention_mask):
        """Return the position of the last nonzero mask entry in each row.

        The result is correct for both left- and right-padded batches. The old
        ``mask.sum(-1) - 1`` was only correct for right padding. A fully padded
        row maps to ``seq_len - 1``, the same position ``-1`` selected before.
        """
        seq_len = attention_mask.size(-1)
        valid = attention_mask.ne(0).long()
        # argmax returns the first maximal entry, so on the reversed row it
        # finds the last valid position of the original row.
        return seq_len - 1 - valid.flip(-1).argmax(-1)

    def get_sentence_embedding(self, embedding_method, **inputs):
        """Run the language model and pool one embedding per input row.

        Args:
            embedding_method: ``'last'``, ``'mean'``, ``'weighted'`` or ``'pma'``.
            **inputs: tokenizer output. It must contain ``input_ids`` and
                ``attention_mask``.

        Returns:
            The pooled embeddings. They are L2-normalised when ``self.norm`` is falsy.

        Raises:
            ValueError: for an unknown ``embedding_method``. Previously this
                path referenced an undefined ``logger`` and left
                ``res_embedding`` unbound.
        """
        if embedding_method not in ('last', 'mean', 'weighted', 'pma'):
            # Fail before paying for the forward pass.
            raise ValueError(
                f"Unknown embedding_method {embedding_method!r}; "
                "expected one of 'last', 'mean', 'weighted', 'pma'"
            )

        attention_mask = inputs['attention_mask']
        outputs = self.plm_model(inputs['input_ids'], attention_mask, output_hidden_states=True)
        embedding = outputs.hidden_states[self.keep_max_layer]

        if embedding_method == 'last':
            res_embedding = self.last_embedding(embedding, self._last_token_index(attention_mask))
        elif embedding_method == 'mean':
            res_embedding = self.mean_embedding(embedding, attention_mask)
        elif embedding_method == 'weighted':
            res_embedding = self.weighted_embedding(embedding, attention_mask)
        else:  # 'pma'
            res_embedding = self.pma_embedding(embedding, attention_mask)

        # NOTE(review): normalisation runs when config.pma_norm is *falsy*. The
        # polarity looks inverted relative to the flag name; confirm against the
        # training code before changing it.
        if not self.norm:
            res_embedding = F.normalize(res_embedding, p=2.0, dim=-1, eps=1e-12)
        return res_embedding

    def encode(self, tokenizer, sentences, batch_size=32, convert_to_numpy=True,
               convert_to_tensor=False, show_progress_bar=True, max_seq_length=None, **kwargs):
        """Embed ``sentences`` in batches without tracking gradients.

        Args:
            tokenizer: callable tokenizer. It is invoked with ``padding``,
                ``truncation``, ``max_length``, ``add_special_tokens=False`` and
                ``return_tensors='pt'``.
            sentences: one string, or a sized sequence of strings.
            batch_size: sentences per forward pass.
            convert_to_numpy: move results to CPU (upcasting bfloat16 to float32)
                and, unless ``convert_to_tensor`` is set, return a NumPy array.
            convert_to_tensor: return one stacked tensor. This takes precedence
                over ``convert_to_numpy`` for the container type.
            show_progress_bar: show a tqdm progress bar over batches.
            max_seq_length: truncation length. Defaults to ``self.inf_seq_length``.
            **kwargs: accepted for call-site compatibility and ignored.

        Returns:
            Embeddings in the input order. A single embedding is returned when
            ``sentences`` was a single string.
        """
        if max_seq_length is None:
            max_seq_length = self.inf_seq_length

        input_is_string = isinstance(sentences, str) or not hasattr(sentences, "__len__")
        if input_is_string:
            sentences = [sentences]

        # Process longest-first so each batch pads to similar lengths. The
        # original order is restored after all batches are embedded.
        length_sorted_idx = np.argsort([-len(s) for s in sentences])
        sentences_sorted = [sentences[idx] for idx in length_sorted_idx]

        all_embeddings = []
        with torch.no_grad():
            for start_index in trange(0, len(sentences), batch_size, desc="Batches", disable=not show_progress_bar):
                sentences_batch = sentences_sorted[start_index: start_index + batch_size]
                inputs = tokenizer(sentences_batch, padding=True, truncation=True, max_length=max_seq_length, add_special_tokens=False, return_tensors='pt').to(self.plm_model.device)
                embeddings = self.get_sentence_embedding(self.embedding_method, **inputs).detach()
                if convert_to_numpy:
                    embeddings = embeddings.cpu()
                    # NumPy has no bfloat16, so upcast before .numpy() is called below.
                    if embeddings.dtype == torch.bfloat16:
                        embeddings = embeddings.to(torch.float32)
                all_embeddings.extend(embeddings)

        all_embeddings = [all_embeddings[idx] for idx in np.argsort(length_sorted_idx)]
        if convert_to_tensor:
            all_embeddings = torch.stack(all_embeddings)
        elif convert_to_numpy:
            all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])

        if input_is_string:
            all_embeddings = all_embeddings[0]
        return all_embeddings
|
|