# Spaces:
# Runtime error
# Runtime error
from __future__ import annotations

import uuid
from typing import Any, Dict, List

from langchain.embeddings import OpenAIEmbeddings
from langchain.embeddings.base import Embeddings
from pinecone import Index  # import doesn't work on plane wifi
from pydantic import BaseModel

from reworkd_platform.settings import settings
from reworkd_platform.timer import timed_function
from reworkd_platform.web.api.memory.memory import AgentMemory
# Presumably the length of an OpenAI embedding vector (inferred from the name;
# nothing in the visible code reads this constant) -- TODO confirm.
OPENAI_EMBEDDING_DIM = 1536
class Row(BaseModel):
    """
    One vector record, serialised with ``.dict()`` and passed to the index
    ``upsert`` call.
    """

    # Unique identifier of the record.
    id: str
    # Embedding vector for the record.
    values: List[float]
    # Free-form payload stored next to the vector; defaults to an empty dict.
    metadata: Dict[str, Any] = {}
class QueryResult(BaseModel):
    """
    One match from a similarity query: the record id, its score, and the
    metadata stored with it.
    """

    # Identifier of the matched record.
    id: str
    # Similarity score reported by the index for this match.
    score: float
    # Payload stored with the matched record; defaults to an empty dict.
    metadata: Dict[str, Any] = {}
class PineconeMemory(AgentMemory):
    """
    Wrapper around pinecone

    Stores task strings as embedding vectors inside one namespace of a
    Pinecone index and retrieves stored tasks by similarity.

    The embeddings client is only created in ``__enter__``, so ``add_tasks``
    and ``get_similar_tasks`` must be called inside a ``with`` block.
    """

    def __init__(self, index_name: str, namespace: str = ""):
        """
        Open the index and choose the namespace to operate on.

        Args:
            index_name: Used as the namespace when ``namespace`` is empty.
                The index itself is always opened from
                ``settings.pinecone_index_name``, not from this argument.
            namespace: Namespace for every operation; falls back to
                ``index_name`` when empty.
        """
        self.index = Index(settings.pinecone_index_name)
        self.namespace = namespace or index_name

    def __enter__(self) -> AgentMemory:
        """Create the embeddings client and return this instance."""
        self.embeddings: Embeddings = OpenAIEmbeddings(
            client=None,  # Meta private value but mypy will complain it's missing
            openai_api_key=settings.openai_api_key,
        )
        return self

    def __exit__(self, *args: Any, **kwargs: Any) -> None:
        """Do nothing on exit; returning ``None`` lets exceptions propagate."""
        pass

    def reset_class(self) -> None:
        """Ask the index to delete everything in this instance's namespace."""
        self.index.delete(delete_all=True, namespace=self.namespace)

    def add_tasks(self, tasks: List[str]) -> List[str]:
        """
        Embed the given tasks and upsert them into the namespace.

        Args:
            tasks: Task texts to store. Each text is kept in the record's
                metadata under the ``"text"`` key.

        Returns:
            The newly generated ids (uuid4 strings), in the same order as
            ``tasks``. An empty input returns ``[]`` without calling the
            embeddings client or the index.

        Raises:
            ValueError: If the embeddings client returns a different number
                of vectors than there are tasks.
        """
        if not tasks:
            return []

        embeds = self.embeddings.embed_documents(tasks)
        if len(tasks) != len(embeds):
            raise ValueError("Embeddings and tasks are not the same length")

        # Lengths were checked above, so zip pairs every task with its vector.
        rows = [
            Row(values=vector, metadata={"text": task}, id=str(uuid.uuid4()))
            for task, vector in zip(tasks, embeds)
        ]

        self.index.upsert(
            vectors=[row.dict() for row in rows], namespace=self.namespace
        )

        return [row.id for row in rows]

    def get_similar_tasks(
        self, text: str, score_threshold: float = 0.95
    ) -> List[QueryResult]:
        """
        Find stored tasks that are similar to ``text``.

        Args:
            text: Query text; it is embedded before querying the index.
            score_threshold: Only matches whose score is strictly greater
                than this value are returned.

        Returns:
            At most five matches (the query uses ``top_k=5``) as
            ``QueryResult`` objects; empty when the response has no
            ``matches`` attribute or nothing passes the threshold.
        """
        vector = self.embeddings.embed_query(text)
        results = self.index.query(
            vector=vector,
            top_k=5,
            include_metadata=True,
            # Only id, score and metadata are read below, so the stored
            # vectors are not requested.
            include_values=False,
            namespace=self.namespace,
        )

        return [
            QueryResult(id=row.id, score=row.score, metadata=row.metadata)
            for row in getattr(results, "matches", [])
            if row.score > score_threshold
        ]

    @staticmethod
    def should_use() -> bool:
        """
        Always return ``False``.

        Declared static because it takes no ``self``; without the decorator a
        call on an instance would raise ``TypeError``.
        NOTE(review): presumably tells callers whether this backend should be
        selected -- confirm against the call sites.
        """
        return False