File size: 2,119 Bytes
8a41f4d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
from abc import ABC, abstractmethod
import os
from langchain.embeddings import (
    OpenAIEmbeddings,
    HuggingFaceEmbeddings,
    CohereEmbeddings,
    HuggingFaceInstructEmbeddings,
)
from application.core.settings import settings

class BaseVectorStore(ABC):
    """Abstract base for vector stores, providing a shared embeddings factory.

    Subclasses must implement :meth:`search`. :meth:`_get_embeddings` builds
    the LangChain embeddings object for one of the supported model names.
    """

    def __init__(self):
        pass

    @abstractmethod
    def search(self, *args, **kwargs):
        """Search the store; the concrete signature is defined by subclasses."""

    def is_azure_configured(self):
        """Return True when all Azure OpenAI settings are set.

        Checks ``settings.OPENAI_API_BASE``, ``settings.OPENAI_API_VERSION``
        and ``settings.AZURE_DEPLOYMENT_NAME``; all three must be truthy.
        """
        # bool() so callers get a real flag instead of the last setting's raw
        # value (previously the deployment-name string itself was returned).
        return bool(
            settings.OPENAI_API_BASE
            and settings.OPENAI_API_VERSION
            and settings.AZURE_DEPLOYMENT_NAME
        )

    def _get_embeddings(self, embeddings_name, embeddings_key=None):
        """Instantiate the embeddings model identified by ``embeddings_name``.

        Args:
            embeddings_name: One of ``"openai_text-embedding-ada-002"``,
                ``"huggingface_sentence-transformers/all-mpnet-base-v2"``,
                ``"huggingface_hkunlp/instructor-large"`` or ``"cohere_medium"``.
            embeddings_key: API key forwarded to the OpenAI (non-Azure) and
                Cohere embeddings; unused by the other models.

        Returns:
            The constructed LangChain embeddings instance.

        Raises:
            ValueError: If ``embeddings_name`` is not a supported model name.

        Note:
            For the OpenAI model with Azure configured, this sets the
            process-wide environment variable ``OPENAI_API_TYPE`` to "azure".
        """

        def build_openai():
            if self.is_azure_configured():
                # Process-wide side effect; presumably read by the OpenAI
                # client to switch to Azure mode.
                os.environ["OPENAI_API_TYPE"] = "azure"
                # NOTE(review): the Azure embeddings deployment name is passed
                # as `model`; confirm this matches the installed LangChain
                # version's Azure parameter.
                return OpenAIEmbeddings(
                    model=settings.AZURE_EMBEDDINGS_DEPLOYMENT_NAME
                )
            return OpenAIEmbeddings(openai_api_key=embeddings_key)

        # Each entry is a zero-argument builder, so only the requested model's
        # class is touched, and only after the name has been validated.
        builders = {
            "openai_text-embedding-ada-002": build_openai,
            "huggingface_sentence-transformers/all-mpnet-base-v2": lambda: HuggingFaceEmbeddings(
                model_name="./model/all-mpnet-base-v2",
                model_kwargs={"device": "cpu"},
            ),
            "huggingface_hkunlp/instructor-large": lambda: HuggingFaceInstructEmbeddings(),
            "cohere_medium": lambda: CohereEmbeddings(cohere_api_key=embeddings_key),
        }

        if embeddings_name not in builders:
            raise ValueError(f"Invalid embeddings_name: {embeddings_name}")

        return builders[embeddings_name]()