# bkaiHackathon2024 / finetune_sample512.py
import json
from typing import Dict, List

import numpy as np
import pandas as pd
import torch
from tqdm import tqdm

from FlagEmbedding import BGEM3FlagModel
from langchain.docstore.document import Document
from langchain_community.vectorstores import FAISS
from langchain_community.vectorstores.faiss import DistanceStrategy
from langchain_core.embeddings.embeddings import Embeddings
from transformers import AutoModel, AutoTokenizer
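# Overview: this script mines hard-negative training pairs for fine-tuning the BGE-M3
# bi-encoder on what appears to be a Vietnamese legal corpus (vbpl, chunked to 512 tokens).
# It builds a FAISS index over all article chunks, retrieves the top-100 chunks per query,
# keeps queries whose relevant article ranks in the top 50, and writes each kept query with
# one positive chunk and up to 15 hard-negative chunks to a JSONL file.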
# Indices into `samples` in some precomputed order; the main loop below walks this list from the end.
with open('/mnt/huynv/messi_goat/bge_vi/legalCorpus2024/finetune_ranks.json', 'r') as f:
    ft_ranks = json.load(f)

# Filtered fine-tuning samples, one JSON object per line.
samples = []
with open('/mnt/huynv/messi_goat/bge_vi/legalCorpus2024/finetune_samples_true_filtered.jsonl', 'r') as f:
    for line in f:
        samples.append(json.loads(line))

# Law documents pre-chunked to 512 tokens, one JSON object per line.
documents = []
with open('/mnt/huynv/messi_goat/bge_vi/legalCorpus2024/vbpl-texts-chunked512.jsonl', 'r') as f:
    for line in f:
        documents.append(json.loads(line))
# Flatten every law article into per-chunk records keyed by "<law_full_id>#<article_idx>".
corpus = []
for document in documents:
    for article_idx in document['articles']:
        for i in range(len(document['articles'][article_idx])):
            chunk = {
                "article_id": document["law_full_id"] + "#" + article_idx,
                "chunk_id": i,
                "description": document['description'],
                "text": document['articles'][article_idx][i]['text'],
            }
            corpus.append(chunk)
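# Each corpus record ends up shaped roughly like the following (values are illustrative,
# not taken from the actual data):
# {
#     "article_id": "01/2024/QH15#1",
#     "chunk_id": 0,
#     "description": "<law description>",
#     "text": "<512-token chunk of the article text>",
# }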
def load_json(file_path):
    # Helper to read a JSONL file into a list of dicts (not called below).
    corpus = []
    with open(file_path, "r", encoding="utf-8") as file:
        for line in file:
            data = json.loads(line.strip())
            corpus.append(data)
    return corpus
def load_corpus_data(corpus: List[Dict[str, str]]) -> List[Document]:
    # Wrap each corpus chunk as a LangChain Document, keeping ids and description in metadata.
    documents = []
    for i in tqdm(range(len(corpus))):
        context = corpus[i]['text']
        metadata = {'article_id': corpus[i]['article_id'],
                    'chunk_id': corpus[i]['chunk_id'],
                    'description': corpus[i]['description']}
        documents.append(Document(page_content=context, metadata=metadata))
    return documents
# BGE-M3 bi-encoder checkpoint (stage 2), loaded in fp32 on GPU 3.
model = BGEM3FlagModel('hub/bi/stage2/checkpoint',
                       use_fp16=False,
                       device='cuda:3')
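# BGEM3FlagModel.encode returns a dict; only the dense vectors ('dense_vecs') are used
# here, not the sparse or ColBERT outputs that BGE-M3 can also produce.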
class CustomEmbedding(Embeddings):
    # LangChain Embeddings wrapper around the global BGE-M3 model.
    def __init__(self, batch_size=1024):
        self.model = model
        self.batch_size = batch_size

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        embeddings = []
        for i in tqdm(range(0, len(texts), self.batch_size), desc="Embedding documents"):
            batch_texts = texts[i:i + self.batch_size]
            batch_embeddings = self.get_batch_embeddings(batch_texts)
            embeddings.extend(batch_embeddings)
            torch.cuda.empty_cache()
        return np.vstack(embeddings)

    def embed_query(self, text: str) -> List[float]:
        # Queries are short, so a smaller max_length is used than for documents.
        embedding = self.model.encode(text, max_length=128)['dense_vecs']
        return embedding

    def get_batch_embeddings(self, texts: List[str]) -> List[List[float]]:
        with torch.no_grad():
            outputs = self.model.encode(texts, batch_size=self.batch_size, max_length=1024)['dense_vecs']
            batch_embeddings = outputs
            del outputs
        return batch_embeddings
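# A standalone sketch of how this wrapper behaves (illustrative only; the example texts
# are hypothetical). embed_documents stacks the per-batch 'dense_vecs' into an
# (n_texts, 1024) array for the base BGE-M3 architecture, and the vectors are meant to be
# compared by inner product (hence DistanceStrategy.DOT_PRODUCT below):
#
#   emb = CustomEmbedding(batch_size=4)
#   doc_vecs = emb.embed_documents(["Điều 1. Phạm vi điều chỉnh ...", "Điều 2. ..."])
#   q_vec = emb.embed_query("Phạm vi điều chỉnh của luật này là gì?")
#   scores = doc_vecs @ q_vec   # inner-product similarity against each chunk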
class VectorDB:
    # Thin wrapper that builds (or loads) a FAISS index over the corpus Documents.
    def __init__(self,
                 documents: List[Document],
                 vector_db=FAISS,
                 index_path=None,
                 embedding=CustomEmbedding()
                 ) -> None:
        self.vector_db = vector_db
        self.embedding = embedding
        self.index_path = index_path
        self.db = self._build_db(documents)

    def _build_db(self, documents: List[Document]):
        if self.index_path:
            # Reuse a previously saved index instead of re-embedding the whole corpus.
            db = self.vector_db.load_local(self.index_path, self.embedding, allow_dangerous_deserialization=True)
        else:
            db = self.vector_db.from_documents(documents=documents, embedding=self.embedding, distance_strategy=DistanceStrategy.DOT_PRODUCT)
            # db.save_local('stage2_bi_index')
        return db

    def get_retriever(self, search_type: str = "similarity", search_kwargs: dict = {"k": 10}):
        retriever = self.db.as_retriever(search_type=search_type, search_kwargs=search_kwargs)
        return retriever
documents = load_corpus_data(corpus)

# Build the vector DB from scratch (no saved index) and expose a top-100 similarity retriever.
vectordb = VectorDB(documents=documents, vector_db=FAISS, index_path=None)
retriever = vectordb.get_retriever(search_type="similarity", search_kwargs={"k": 100})
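# Sketch of persisting and reusing the index (the directory name 'stage2_bi_index' comes
# from the commented-out save_local call above; load_local needs
# allow_dangerous_deserialization=True because the pickled docstore is deserialized):
#
#   vectordb.db.save_local('stage2_bi_index')
#   vectordb = VectorDB(documents=documents, vector_db=FAISS, index_path='stage2_bi_index')
#   retriever = vectordb.get_retriever(search_type="similarity", search_kwargs={"k": 100})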
def sample_handler(sample):
    # Retrieve the top-100 chunks for the query and find the best-ranked relevant article
    # within the top 50; skip the sample if no relevant article is retrieved that high.
    question = sample['query']
    docs = retriever.invoke(question)
    retrieved_article_ids_50 = [docs[i].metadata['article_id'] for i in range(50)]
    indexes = 100
    for article_id in sample['relevant_articles']:
        if article_id in retrieved_article_ids_50:
            indexes = min(indexes, retrieved_article_ids_50.index(article_id))
    if indexes == 100:
        return None
    else:
        # Positive: the best-ranked relevant chunk.
        pos = {'article_id': docs[indexes].metadata['article_id'],
               'chunk_id': docs[indexes].metadata['chunk_id']}
        # Hard negatives: chunks ranked at least 5 places below the positive that do not
        # belong to any relevant article; keep at most 15 of them.
        negs = [{'article_id': docs[i].metadata['article_id'],
                 'chunk_id': docs[i].metadata['chunk_id']}
                for i in range(indexes + 5, 100)
                if docs[i].metadata['article_id'] not in sample['relevant_articles']]
        sample['pos'] = [pos]
        sample['neg'] = negs[:15]
        return sample
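# Shape of a written training sample (illustrative values; the sample keeps whatever
# fields the input JSONL already had, with 'pos' and 'neg' added):
# {
#     "query": "...",
#     "relevant_articles": ["01/2024/QH15#3"],
#     "pos": [{"article_id": "01/2024/QH15#3", "chunk_id": 0}],
#     "neg": [{"article_id": "12/2023/ND-CP#7", "chunk_id": 1}, ...]
# }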
# Walk ft_ranks from the end and keep writing accepted samples until 200k are collected.
# (Assumes enough samples pass sample_handler before ft_ranks is exhausted.)
count = 0
idx = 0
with open('/mnt/huynv/messi_goat/bge_vi/legalCorpus2024/finetune_samples_512_200k.jsonl', 'w', encoding='utf-8') as f:
    while count < 200000:
        new_sample = sample_handler(samples[ft_ranks[-(idx + 1)]])
        idx += 1
        if new_sample:
            count += 1
            json.dump(new_sample, f, ensure_ascii=False)
            f.write('\n')