File size: 2,738 Bytes
cd6f98e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
from __future__ import annotations

import uuid
from typing import Any, Dict, List

from langchain.embeddings import OpenAIEmbeddings
from langchain.embeddings.base import Embeddings
from pinecone import Index  # import doesnt work on plane wifi
from pydantic import BaseModel

from reworkd_platform.settings import settings
from reworkd_platform.timer import timed_function
from reworkd_platform.web.api.memory.memory import AgentMemory

# Vector dimensionality, presumably that of the OpenAI embeddings produced by
# PineconeMemory below -- TODO confirm against the embedding model in use.
# NOTE(review): not referenced anywhere in the visible part of this module.
OPENAI_EMBEDDING_DIM = 1536


class Row(BaseModel):
    """
    A single vector record, serialised with ``.dict()`` and handed to the
    index's ``upsert`` call by ``PineconeMemory.add_tasks``.
    """

    # Unique identifier of the record (add_tasks uses a random UUID4 string).
    id: str
    # The embedding vector itself.
    values: List[float]
    # Arbitrary payload stored next to the vector; add_tasks stores the
    # original task under the "text" key. Defaults to an empty dict.
    metadata: Dict[str, Any] = {}


class QueryResult(BaseModel):
    """
    One match returned by ``PineconeMemory.get_similar_tasks``.
    """

    # Identifier of the matched record.
    id: str
    # Score reported by the index for this match; get_similar_tasks only
    # returns matches whose score is above its threshold.
    score: float
    # Metadata stored with the matched record. Defaults to an empty dict.
    metadata: Dict[str, Any] = {}


class PineconeMemory(AgentMemory):
    """
    Wrapper around pinecone

    Stores task texts as embedding vectors in a Pinecone index and looks up
    previously stored tasks that are similar to a given text.

    The embedding client is only created in ``__enter__``, so ``add_tasks``
    and ``get_similar_tasks`` must be called inside a ``with`` block;
    calling them on a bare instance fails with ``AttributeError``.
    """

    def __init__(self, index_name: str, namespace: str = ""):
        """
        Bind to the configured Pinecone index.

        Args:
            index_name: Used only as the fallback namespace. The index that
                is opened is always ``settings.pinecone_index_name``.
            namespace: Namespace all reads and writes are scoped to. When
                empty, ``index_name`` is used instead.
        """
        self.index = Index(settings.pinecone_index_name)
        self.namespace = namespace or index_name

    @timed_function(level="DEBUG")
    def __enter__(self) -> AgentMemory:
        """Create the OpenAI embedding client and return this instance."""
        self.embeddings: Embeddings = OpenAIEmbeddings(
            client=None,  # Meta private value but mypy will complain its missing
            openai_api_key=settings.openai_api_key,
        )

        return self

    def __exit__(self, *args: Any, **kwargs: Any) -> None:
        """Nothing to release; exceptions from the ``with`` body propagate."""
        pass

    @timed_function(level="DEBUG")
    def reset_class(self) -> None:
        """Ask the index to delete everything in this instance's namespace."""
        self.index.delete(delete_all=True, namespace=self.namespace)

    @timed_function(level="DEBUG")
    def add_tasks(self, tasks: List[str]) -> List[str]:
        """
        Embed the given task texts and upsert them into the index.

        Args:
            tasks: Task texts to store.

        Returns:
            The generated ids of the stored records, in the same order as
            ``tasks``. An empty list when ``tasks`` is empty (no embedding
            or index call is made in that case).

        Raises:
            ValueError: If the embedding client returns a different number
                of vectors than texts supplied.
        """
        if not tasks:
            return []

        embeds = self.embeddings.embed_documents(tasks)

        # Checked explicitly because zip() below would silently truncate.
        if len(tasks) != len(embeds):
            raise ValueError("Embeddings and tasks are not the same length")

        # Each record gets a fresh random id; the original text travels in
        # the metadata so it can be read back from query results.
        rows = [
            Row(id=str(uuid.uuid4()), values=vector, metadata={"text": task})
            for task, vector in zip(tasks, embeds)
        ]

        self.index.upsert(
            vectors=[row.dict() for row in rows], namespace=self.namespace
        )

        return [row.id for row in rows]

    @timed_function(level="DEBUG")
    def get_similar_tasks(
        self, text: str, score_threshold: float = 0.95
    ) -> List[QueryResult]:
        """
        Find stored tasks similar to ``text``.

        Args:
            text: Text to embed and search for.
            score_threshold: Only matches with a score strictly greater
                than this value are returned.

        Returns:
            At most five matches (the query's ``top_k``), each carrying the
            record id, its score and its stored metadata.
        """
        vector = self.embeddings.embed_query(text)
        results = self.index.query(
            vector=vector,
            top_k=5,
            include_metadata=True,
            # Only id, score and metadata are read below, so do not ask the
            # index to send the full embedding of every match back.
            include_values=False,
            namespace=self.namespace,
        )

        # Tolerate a response whose ``matches`` is missing or None.
        matches = getattr(results, "matches", None) or []

        return [
            QueryResult(id=row.id, score=row.score, metadata=row.metadata)
            for row in matches
            if row.score > score_threshold
        ]

    @staticmethod
    def should_use() -> bool:
        """Report whether this memory backend should be used; always False."""
        return False