# 04_prepare_hard_neg_cross.py
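"""Mine hard negatives for cross-encoder training.

For each training question, retrieve top passages from a FAISS index built
over BGE-M3 bi-encoder embeddings, drop passages that belong to the gold
sids, and keep the remaining ones as hard negatives in query/pos/neg JSONL
records.
"""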
import json
from typing import Dict, List

import numpy as np
import torch
from tqdm import tqdm
from FlagEmbedding import BGEM3FlagModel
from langchain.docstore.document import Document
from langchain_community.vectorstores import FAISS
from langchain_community.vectorstores.faiss import DistanceStrategy
from langchain_core.embeddings.embeddings import Embeddings
def load_json(file_path):
    # Read a JSON Lines file: one JSON object per line.
    corpus = []
    with open(file_path, "r", encoding="utf-8") as file:
        for line in file:
            data = json.loads(line.strip())
            corpus.append(data)
    return corpus
def load_corpus_data(corpus: List[Dict[str, str]]) -> List[Document]:
    # Wrap each corpus chunk in a LangChain Document, keeping the sid/cid
    # identifiers in the metadata so retrieved negatives can be filtered later.
    documents = []
    for i in tqdm(range(len(corpus))):
        context = corpus[i]['text']
        metadata = {'sid': corpus[i]['sid'],
                    'cid': corpus[i]['cid']}
        documents.append(Document(page_content=context, metadata=metadata))
    return documents
# Fine-tuned BGE-M3 bi-encoder checkpoint used to embed both corpus chunks and queries.
model = BGEM3FlagModel('hub/bi/stage2/checkpoint',
                       use_fp16=False)
class CustomEmbedding(Embeddings):
    """LangChain Embeddings wrapper around the BGE-M3 dense encoder."""

    def __init__(self, batch_size=4096):
        self.model = model
        self.batch_size = batch_size

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        embeddings = []
        for i in tqdm(range(0, len(texts), self.batch_size), desc="Embedding documents"):
            batch_texts = texts[i:i + self.batch_size]
            batch_embeddings = self.get_batch_embeddings(batch_texts)
            embeddings.extend(batch_embeddings)
            torch.cuda.empty_cache()
        return np.vstack(embeddings)

    def embed_query(self, text: str) -> List[float]:
        # Queries are short, so a smaller max_length is sufficient.
        embedding = self.model.encode(text, max_length=128)['dense_vecs']
        return embedding

    def get_batch_embeddings(self, texts: List[str]) -> List[List[float]]:
        with torch.no_grad():
            outputs = self.model.encode(texts, batch_size=self.batch_size, max_length=1024)['dense_vecs']
        return outputs
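# Sketch of standalone use (hypothetical texts), assuming the checkpoint above is available:
#   emb = CustomEmbedding(batch_size=8)
#   vecs = emb.embed_documents(["chunk one", "chunk two"])  # -> (2, 1024) array of dense vectors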
class VectorDB:
    def __init__(self,
                 documents: List[Document],
                 vector_db=FAISS,
                 index_path=None,
                 embedding=CustomEmbedding()
                 ) -> None:
        self.vector_db = vector_db
        self.embedding = embedding
        self.index_path = index_path
        self.db = self._build_db(documents)

    def _build_db(self, documents: List[Document]):
        # Reload a saved FAISS index when a path is given; otherwise embed the
        # documents and build a fresh index scored by dot product.
        if self.index_path:
            db = self.vector_db.load_local(self.index_path, self.embedding, allow_dangerous_deserialization=True)
        else:
            db = self.vector_db.from_documents(documents=documents, embedding=self.embedding, distance_strategy=DistanceStrategy.DOT_PRODUCT)
            # db.save_local('stage2_bi_index')
        return db

    def get_retriever(self, search_type: str = "similarity", search_kwargs: dict = {"k": 10}):
        retriever = self.db.as_retriever(search_type=search_type, search_kwargs=search_kwargs)
        return retriever
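# NOTE (assumption): FlagEmbedding returns L2-normalised BGE-M3 dense vectors by
# default, so DOT_PRODUCT scoring here behaves like cosine similarity.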
file_path = "data/processed/corpus_chunks1024.jsonl"
corpus = load_json(file_path)
documents = load_corpus_data(corpus)
# vector DB
vectordb = VectorDB(documents=documents, vector_db=FAISS, index_path='hub/bi/stage2/index')
retriever = vectordb.get_retriever(search_type="similarity", search_kwargs={"k": 20})
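# k=20 retrieves more candidates than needed so that, after filtering out chunks
# whose sid matches a gold passage, 11 hard negatives (no_negs below) remain.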
file_path = 'data/processed/final_data_train_chunk_1024.jsonl'
data_train = load_json(file_path)

# Keep only training samples that actually have a gold context.
final_dataset = []
for sample in data_train:
    if sample['context'] is None or len(sample['context']) == 0:
        continue
    final_dataset.append(sample)
# Mine hard negatives: retrieved chunks that are similar to the query but do
# not belong to any of the gold sids.
no_negs = 11
final_list = []
for i in tqdm(range(len(final_dataset))):
    diction = {}
    question = final_dataset[i]['question']
    diction['query'] = question
    sids = final_dataset[i]['sid']
    docs = retriever.invoke(question)
    neg_docs = [doc for doc in docs if doc.metadata['sid'] not in sids][:no_negs]
    hard_negative = [neg_doc.page_content for neg_doc in neg_docs]
    diction['pos'] = final_dataset[i]['context']
    diction['neg'] = hard_negative
    final_list.append(diction)
# Write one JSON object per line (JSONL) for cross-encoder training.
file_path = 'data/cross/train_1024_chunks_cross.jsonl'
with open(file_path, 'w', encoding='utf-8') as f:
    for item in final_list:
        f.write(json.dumps(item, ensure_ascii=False) + '\n')
print("Data has been saved to the JSONL file.")