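"""Evidence retrieval tools for factool's knowledge_qa pipeline: a Google Serper
web-search backend and a local-corpus backend that ranks snippets with OpenAI
embeddings."""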
import asyncio
import os

import jsonlines
import numpy as np

from factool.knowledge_qa.google_serper import GoogleSerperAPIWrapper
from factool.utils.openai_wrapper import OpenAIEmbed
class google_search:
    """Web evidence retrieval backed by the Google Serper API."""

    def __init__(self, snippet_cnt):
        self.serper = GoogleSerperAPIWrapper(snippet_cnt=snippet_cnt)

    async def run(self, queries):
        # Delegate the nested list of queries to the Serper wrapper.
        return await self.serper.run(queries)
class local_search:
    """Evidence retrieval over a local JSON Lines corpus using OpenAI embeddings."""

    def __init__(self, snippet_cnt, data_link, embedding_link=None):
        self.snippet_cnt = snippet_cnt
        self.data_link = data_link
        self.embedding_link = embedding_link
        self.openai_embed = OpenAIEmbed()
        self.data = None
        self.embedding = None
        asyncio.run(self.init_async())

    async def init_async(self):
        print("init local search")
        self.load_data_by_link()
        if self.embedding_link is None:
            # No precomputed embeddings supplied: embed the corpus now and cache it.
            await self.calculate_embedding()
        else:
            self.load_embedding_by_link()
        print("loaded data and embedding")
    def add_suffix_to_json_filename(self, filename):
        # e.g. "corpus.jsonl" -> "corpus_embed.jsonl"
        base_name, extension = os.path.splitext(filename)
        return base_name + '_embed' + extension
    def load_data_by_link(self):
        # Load the corpus from a JSON Lines file: one object per line, each with a "text" field.
        self.data = []
        with jsonlines.open(self.data_link) as reader:
            for obj in reader:
                self.data.append(obj['text'])

    def load_embedding_by_link(self):
        # Load precomputed embeddings: one vector per line, aligned with self.data.
        self.embedding = []
        with jsonlines.open(self.embedding_link) as reader:
            for obj in reader:
                self.embedding.append(obj)
    def save_embeddings(self):
        # Cache the embeddings next to the corpus file (e.g. "corpus_embed.jsonl").
        with jsonlines.open(self.add_suffix_to_json_filename(self.data_link), mode='w') as writer:
            writer.write_all(self.embedding)

    async def calculate_embedding(self):
        # Embed every snippet in the corpus, then persist the result for reuse.
        result = await self.openai_embed.process_batch(self.data, retry=3)
        self.embedding = [emb["data"][0]["embedding"] for emb in result]
        self.save_embeddings()
    async def search(self, query):
        # Embed the query and rank corpus snippets by dot-product similarity.
        result = await self.openai_embed.create_embedding(query)
        query_embed = result["data"][0]["embedding"]
        dot_product = np.dot(self.embedding, query_embed)
        sorted_indices = np.argsort(dot_product)[::-1]
        top_k_indices = sorted_indices[:self.snippet_cnt]
        return [{"content": self.data[i], "source": "local"} for i in top_k_indices]
    async def run(self, queries):
        # Each claim contributes a pair of queries; a None entry is replaced by placeholders
        # so the output stays aligned with the input claims.
        flattened_queries = []
        for sublist in queries:
            if sublist is None:
                sublist = ['None', 'None']
            for item in sublist:
                flattened_queries.append(item)
        snippets = await asyncio.gather(*[self.search(query) for query in flattened_queries])
        # Merge the results back into one evidence list per claim (two queries per claim).
        snippets_split = [snippets[i] + snippets[i + 1] for i in range(0, len(snippets), 2)]
        return snippets_split
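
# Illustrative usage sketch (not part of the original module): assuming a JSON Lines
# corpus at "corpus.jsonl" whose records carry a "text" field, and valid OpenAI
# credentials available to OpenAIEmbed, the two backends could be driven like this:
#
#   searcher = local_search(snippet_cnt=5, data_link="corpus.jsonl")
#   evidences = asyncio.run(searcher.run([["query about claim 1", "another query about claim 1"]]))
#
#   web = google_search(snippet_cnt=5)
#   web_evidences = asyncio.run(web.run([["query about claim 1", "another query about claim 1"]]))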