import json
from typing import Dict, List

import numpy as np
import torch
from tqdm import tqdm

from FlagEmbedding import BGEM3FlagModel
from langchain.docstore.document import Document
from langchain_community.vectorstores import FAISS
from langchain_community.vectorstores.faiss import DistanceStrategy
from langchain_core.embeddings.embeddings import Embeddings

with open('/mnt/huynv/messi_goat/bge_vi/legalCorpus2024/finetune_ranks.json', 'r') as f:
    ft_ranks = json.load(f)
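# ft_ranks appears to hold indices into `samples`: the mining loop at the
# bottom reads samples[ft_ranks[-(idx + 1)]], i.e. it walks the ranking
# from the end.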

samples = []
with open('/mnt/huynv/messi_goat/bge_vi/legalCorpus2024/finetune_samples_true_filtered.jsonl', 'r') as f:
    for line in f:
        samples.append(json.loads(line))

documents = []
with open('/mnt/huynv/messi_goat/bge_vi/legalCorpus2024/vbpl-texts-chunked512.jsonl', 'r') as f:
    for line in f:
        documents.append(json.loads(line))
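# Record shapes, as inferred from how they are used below:
#   samples[i]   -> {'query': str, 'relevant_articles': [article_id, ...], ...}
#   documents[i] -> {'law_full_id': str, 'description': str,
#                    'articles': {article_idx: [{'text': str, ...}, ...]}}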

corpus = []
for document in documents:
    for article_idx, article_chunks in document['articles'].items():
        for i, article_chunk in enumerate(article_chunks):
            corpus.append({
                "article_id": document["law_full_id"] + "#" + article_idx,
                "chunk_id": i,
                "description": document['description'],
                "text": article_chunk['text'],
            })
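# Each corpus entry is one chunk of one legal article (the source file name
# suggests 512-token chunking). Illustrative shape:
#   {'article_id': '<law_full_id>#<article_idx>', 'chunk_id': 0,
#    'description': ..., 'text': ...}
# Relevance is later matched at the article_id level.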

def load_json(file_path):
    corpus = []
    with open(file_path, "r", encoding="utf-8") as file:
        for line in file:
            corpus.append(json.loads(line.strip()))
    return corpus
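# Note: load_json is a generic JSONL reader; it is not called in the pipeline
# below, which inlines the same logic.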

def load_corpus_data(corpus: List[Dict[str, str]]) -> List[Document]:
    """Wrap each corpus chunk in a LangChain Document, keeping ids in metadata."""
    documents = []
    for record in tqdm(corpus):
        metadata = {'article_id': record['article_id'],
                    'chunk_id': record['chunk_id'],
                    'description': record['description']}
        documents.append(Document(page_content=record['text'], metadata=metadata))
    return documents

# Embedding model: a BGE-M3 checkpoint loaded from a local path
# (a stage-2 fine-tuned bi-encoder, judging by the path).
model = BGEM3FlagModel('hub/bi/stage2/checkpoint',
                       use_fp16=False,
                       device='cuda:3')
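# BGE-M3 can produce dense, sparse, and multi-vector (ColBERT-style) outputs;
# encode() returns them in a dict, and this pipeline only uses 'dense_vecs'.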

class CustomEmbedding(Embeddings):
    """LangChain Embeddings adapter around the module-level BGE-M3 model."""

    def __init__(self, batch_size=1024):
        self.model = model
        self.batch_size = batch_size

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        embeddings = []
        for i in tqdm(range(0, len(texts), self.batch_size), desc="Embedding documents"):
            batch_texts = texts[i:i + self.batch_size]
            embeddings.extend(self.get_batch_embeddings(batch_texts))
            torch.cuda.empty_cache()  # release cached GPU memory between batches
        return np.vstack(embeddings)

    def embed_query(self, text: str) -> List[float]:
        # Queries are short, so a smaller max_length keeps encoding cheap.
        return self.model.encode(text, max_length=128)['dense_vecs']

    def get_batch_embeddings(self, texts: List[str]) -> List[List[float]]:
        with torch.no_grad():
            # encode() returns a dict of representations; keep only the dense ones.
            return self.model.encode(texts, batch_size=self.batch_size,
                                     max_length=1024)['dense_vecs']
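# Illustrative sanity check (not run here; assumes the standard 1024-dim
# BGE-M3 dense head):
#     emb = CustomEmbedding(batch_size=4)
#     emb.embed_query("thuế thu nhập cá nhân").shape  # -> (1024,)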

class VectorDB:
    def __init__(self,
                 documents: List[Document],
                 vector_db=FAISS,
                 index_path=None,
                 embedding=None) -> None:
        self.vector_db = vector_db
        # Build the default embedding here rather than as a default argument,
        # which would be constructed once at function-definition time.
        self.embedding = embedding if embedding is not None else CustomEmbedding()
        self.index_path = index_path
        self.db = self._build_db(documents)

    def _build_db(self, documents: List[Document]):
        if self.index_path:
            # Reload a previously saved FAISS index instead of re-embedding.
            db = self.vector_db.load_local(self.index_path, self.embedding,
                                           allow_dangerous_deserialization=True)
        else:
            # BGE-M3's dense vectors are normalized by default, so dot product
            # coincides with cosine similarity.
            db = self.vector_db.from_documents(documents=documents,
                                               embedding=self.embedding,
                                               distance_strategy=DistanceStrategy.DOT_PRODUCT)
        return db
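    # Note: after the first build, the index could be persisted with
    # self.db.save_local(<path>) (standard langchain FAISS API) and then
    # reloaded cheaply by passing index_path to __init__.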

    def get_retriever(self, search_type: str = "similarity", search_kwargs=None):
        # A dict default argument would be shared across calls; create it here.
        if search_kwargs is None:
            search_kwargs = {"k": 10}
        return self.db.as_retriever(search_type=search_type, search_kwargs=search_kwargs)

documents = load_corpus_data(corpus)

vectordb = VectorDB(documents=documents, vector_db=FAISS, index_path=None)
retriever = vectordb.get_retriever(search_type="similarity", search_kwargs={"k": 100})
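# k=100 over-retrieves on purpose: the head of the ranked list supplies the
# positive chunk and the tail supplies hard negatives in sample_handler below.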

def sample_handler(sample):
    """Mine one (query, positive, hard-negatives) record from the top-100 retrieval."""
    question = sample['query']
    docs = retriever.invoke(question)
    retrieved_article_ids_50 = [doc.metadata['article_id'] for doc in docs[:50]]

    # Best (lowest) rank at which any relevant article appears in the top 50.
    best_rank = 100  # sentinel: nothing relevant retrieved
    for article_id in sample['relevant_articles']:
        if article_id in retrieved_article_ids_50:
            best_rank = min(best_rank, retrieved_article_ids_50.index(article_id))

    if best_rank == 100:
        # No relevant article in the top 50: drop the sample.
        return None

    # Positive: the highest-ranked relevant chunk.
    pos = {'article_id': docs[best_rank].metadata['article_id'],
           'chunk_id': docs[best_rank].metadata['chunk_id']}
    # Hard negatives: chunks ranked at least 5 places below the positive
    # whose articles are not relevant.
    negs = [{'article_id': docs[i].metadata['article_id'],
             'chunk_id': docs[i].metadata['chunk_id']}
            for i in range(best_rank + 5, 100)
            if docs[i].metadata['article_id'] not in sample['relevant_articles']]
    sample['pos'] = [pos]
    sample['neg'] = negs[:15]
    return sample
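# A mined record keeps the original sample fields and adds:
#   'pos': [{'article_id': ..., 'chunk_id': ...}]            (one positive)
#   'neg': up to 15 dicts of the same shape (hard negatives)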

count = 0
idx = 0
with open('/mnt/huynv/messi_goat/bge_vi/legalCorpus2024/finetune_samples_512_200k.jsonl', 'w', encoding='utf-8') as f:
    # Consume ft_ranks from the end until 200k mined samples are written.
    while count < 200000:
        new_sample = sample_handler(samples[ft_ranks[-(idx + 1)]])
        idx += 1
        if new_sample:
            count += 1
            json.dump(new_sample, f, ensure_ascii=False)
            f.write('\n')
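# Each output line is a JSON object with 'query', 'pos', and 'neg' fields,
# mirroring the layout commonly used for FlagEmbedding fine-tuning data;
# here, pos/neg carry chunk ids rather than raw text, presumably resolved to
# text in a later step. The loop also assumes ft_ranks supplies enough
# candidates to reach 200000 accepted samples; if it runs out first, the
# samples[...] lookup raises an IndexError.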