import os

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import AutoConfig, AutoTokenizer
from optimum.onnxruntime import ORTModel

# Hugging Face access token; read it from the environment instead of
# hard-coding a secret in the script.
AUTH_TOKEN = os.environ.get("HF_AUTH_TOKEN")

tokenizer = AutoTokenizer.from_pretrained(
    "nguyenvulebinh/vi-mrc-base", use_auth_token=AUTH_TOKEN
)
pad_token_id = tokenizer.pad_token_id


class PairwiseModel(nn.Module):
    """Cross-encoder reranker: an ONNX Runtime backbone plus a linear head
    on the [CLS] embedding."""

    def __init__(self, model_name, max_length=384, batch_size=16, device="cpu"):
        super().__init__()
        self.max_length = max_length
        self.batch_size = batch_size
        self.device = device
        # Export the transformers checkpoint to ONNX and run it with ONNX Runtime.
        self.model = ORTModel.from_pretrained(
            model_name, use_auth_token=AUTH_TOKEN, from_transformers=True
        )
        self.model.to(self.device)
        self.model.eval()
        self.config = AutoConfig.from_pretrained(model_name, use_auth_token=AUTH_TOKEN)
        self.fc = nn.Linear(768, 1).to(self.device)  # 768 = backbone hidden size

    def forward(self, ids, masks):
        out = self.model(
            input_ids=ids, attention_mask=masks, output_hidden_states=False
        ).last_hidden_state
        out = out[:, 0]  # [CLS] embedding
        return self.fc(out)  # one logit per (question, candidate) pair

    def stage1_ranking(self, question, texts):
        """Score each raw passage in `texts` against `question`."""
        tmp = pd.DataFrame()
        tmp["text"] = [" ".join(x.split()) for x in texts]  # collapse whitespace
        tmp["question"] = question
        valid_dataset = SiameseDatasetStage1(tmp, tokenizer, self.max_length, is_test=True)
        valid_loader = DataLoader(
            valid_dataset,
            batch_size=self.batch_size,
            collate_fn=collate_fn,
            num_workers=0,
            shuffle=False,
            pin_memory=True,
        )
        preds = []
        with torch.no_grad():
            for step, data in enumerate(valid_loader):
                ids = data["ids"].to(self.device)
                masks = data["masks"].to(self.device)
                preds.append(torch.sigmoid(self(ids, masks)).view(-1))
        preds = torch.cat(preds)
        return preds.cpu().numpy()

    def stage2_ranking(self, question, answer, titles, texts):
        """Score each (title, candidate) pair against the (question, answer) pair."""
        tmp = pd.DataFrame()
        tmp["candidate"] = texts
        tmp["question"] = question
        tmp["answer"] = answer
        tmp["title"] = titles
        valid_dataset = SiameseDatasetStage2(tmp, tokenizer, self.max_length, is_test=True)
        valid_loader = DataLoader(
            valid_dataset,
            batch_size=self.batch_size,
            collate_fn=collate_fn,
            num_workers=0,
            shuffle=False,
            pin_memory=True,
        )
        preds = []
        with torch.no_grad():
            for step, data in enumerate(valid_loader):
                ids = data["ids"].to(self.device)
                masks = data["masks"].to(self.device)
                preds.append(torch.sigmoid(self(ids, masks)).view(-1))
        preds = torch.cat(preds)
        return preds.cpu().numpy()


class SiameseDatasetStage1(Dataset):
    def __init__(self, df, tokenizer, max_length, is_test=False):
        self.df = df
        self.max_length = max_length
        self.tokenizer = tokenizer
        self.is_test = is_test
        self.content1 = tokenizer.batch_encode_plus(
            list(df.question.values), max_length=max_length, truncation=True
        )["input_ids"]
        self.content2 = tokenizer.batch_encode_plus(
            list(df.text.values), max_length=max_length, truncation=True
        )["input_ids"]
        if not self.is_test:
            self.targets = self.df.label

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        return {
            "ids1": torch.tensor(self.content1[index], dtype=torch.long),
            # Drop the leading CLS of the second segment; collate_fn will
            # concatenate ids1 and ids2 into a single sequence.
            "ids2": torch.tensor(self.content2[index][1:], dtype=torch.long),
            "target": torch.tensor(0)
            if self.is_test
            else torch.tensor(self.targets[index], dtype=torch.float),
        }
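
# Two-stage retrieval pipeline: stage 1 scores raw passages directly against
# the question, while stage 2 re-scores "title <sep> candidate" strings against
# "question <sep> answer". Both stages share the same encoder, linear head, and
# collate_fn (defined below), which joins each item's ids1/ids2 into one sequence.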

class SiameseDatasetStage2(Dataset):
    def __init__(self, df, tokenizer, max_length, is_test=False):
        self.df = df
        self.max_length = max_length
        self.tokenizer = tokenizer
        self.is_test = is_test
        # Build the two sides of each pair as sep-joined strings.
        self.df["content1"] = self.df.apply(
            lambda row: row.question + f" {tokenizer.sep_token} " + row.answer, axis=1
        )
        self.df["content2"] = self.df.apply(
            lambda row: row.title + f" {tokenizer.sep_token} " + row.candidate, axis=1
        )
        self.content1 = tokenizer.batch_encode_plus(
            list(df.content1.values), max_length=max_length, truncation=True
        )["input_ids"]
        self.content2 = tokenizer.batch_encode_plus(
            list(df.content2.values), max_length=max_length, truncation=True
        )["input_ids"]
        if not self.is_test:
            self.targets = self.df.label

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        return {
            "ids1": torch.tensor(self.content1[index], dtype=torch.long),
            "ids2": torch.tensor(self.content2[index][1:], dtype=torch.long),
            "target": torch.tensor(0)
            if self.is_test
            else torch.tensor(self.targets[index], dtype=torch.float),
        }


def collate_fn(batch):
    """Concatenate each pair's two encodings, right-pad to the longest
    sequence in the batch, and derive attention masks from the padding."""
    ids = [torch.cat([x["ids1"], x["ids2"]]) for x in batch]
    targets = [x["target"] for x in batch]
    max_len = max(len(x) for x in ids)
    masks = []
    for i in range(len(ids)):
        if len(ids[i]) < max_len:
            padding = torch.full((max_len - len(ids[i]),), pad_token_id, dtype=torch.long)
            ids[i] = torch.cat((ids[i], padding))
        masks.append(ids[i] != pad_token_id)
    return {
        "ids": torch.vstack(ids),
        "masks": torch.vstack(masks),
        "target": torch.vstack(targets).view(-1),
    }
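
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). The question and passages below are
# hypothetical placeholders; note that the linear head (self.fc) is randomly
# initialized here, so meaningful scores require loading trained head weights.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    ranker = PairwiseModel("nguyenvulebinh/vi-mrc-base", device="cpu")
    question = "Ai là tổng thống đầu tiên của Hoa Kỳ?"  # "Who was the first US president?"
    passages = [
        "George Washington là tổng thống đầu tiên của Hoa Kỳ.",
        "Hà Nội là thủ đô của Việt Nam.",
    ]
    scores = ranker.stage1_ranking(question, passages)
    print(scores)  # one sigmoid score per passage; higher = more relevant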