Spaces:
Runtime error
Runtime error
| from langchain.embeddings import OpenAIEmbeddings | |
| from typing import List, Dict | |
| from queue import PriorityQueue | |
| from dotenv import load_dotenv | |
| load_dotenv(".env") | |
| import os | |
class Retriever:
    """Rank registered tools by embedding similarity to a text query.

    Each tool is stored under its name together with a short text document
    (``name_for_model`` + ". " + ``description_for_model``) and the embedding
    of that document. A query is embedded with the same client and tools are
    ranked by the dot product of the two embedding vectors.
    """

    def __init__(
        self, openai_api_key: str = None, model: str = "text-embedding-ada-002"
    ):
        """Create the embedding client and an empty tool index.

        Args:
            openai_api_key: Key passed to ``OpenAIEmbeddings``. When ``None``,
                the ``OPENAI_API_KEY`` environment variable is used instead.
            model: Embedding model name passed to ``OpenAIEmbeddings``.
        """
        if openai_api_key is None:
            openai_api_key = os.environ.get("OPENAI_API_KEY")
        self.embed = OpenAIEmbeddings(openai_api_key=openai_api_key, model=model)
        # Maps tool name -> {"document": <text>, "embedding": <vector>}.
        self.documents = dict()

    def add_tool(self, tool_name: str, api_info: Dict) -> None:
        """Register a tool and store the embedding of its description.

        A tool name that is already registered is left untouched, so the
        first registration wins and no embedding request is made for it.

        Args:
            tool_name: Key under which the tool is stored and later returned
                by ``query``.
            api_info: Mapping that must contain ``"name_for_model"`` and
                ``"description_for_model"`` string entries.

        Raises:
            KeyError: If ``api_info`` lacks one of the required keys.
        """
        if tool_name in self.documents:
            return
        document = api_info["name_for_model"] + ". " + api_info["description_for_model"]
        # embed_documents takes a batch; we send one document and keep its
        # single resulting vector.
        document_embedding = self.embed.embed_documents([document])
        self.documents[tool_name] = {
            "document": document,
            "embedding": document_embedding[0],
        }

    def query(self, query: str, topk: int = 3) -> List[str]:
        """Return the names of the ``topk`` tools most similar to ``query``.

        Args:
            query: Free-text query to embed and compare against every tool.
            topk: Maximum number of tool names to return. Values ``<= 0``
                yield an empty list.

        Returns:
            Tool names ordered by descending similarity; equal scores are
            ordered by ascending tool name. At most ``topk`` names, fewer if
            fewer tools are registered.
        """
        # Nothing can be returned, so skip the embedding request entirely.
        if topk <= 0 or not self.documents:
            return []
        query_embedding = self.embed.embed_query(query)
        # Negated score first so an ascending sort puts the best match first;
        # the tool name is the tie-breaker.
        ranked = sorted(
            (-self.similarity(query_embedding, tool_info["embedding"]), tool_name)
            for tool_name, tool_info in self.documents.items()
        )
        return [tool_name for _, tool_name in ranked[:topk]]

    def similarity(self, query: List[float], document: List[float]) -> float:
        """Return the dot product of two embedding vectors.

        If the vectors differ in length, the extra trailing components of the
        longer one are ignored (``zip`` stops at the shorter input).
        """
        return sum(i * j for i, j in zip(query, document))