File size: 1,230 Bytes
37b6839
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
from typing import Any, Dict, List

import google.generativeai as genai
from llama_index.core.embeddings import BaseEmbedding

class GEmbeddings(BaseEmbedding):
    """LlamaIndex embedding adapter backed by ``google.generativeai``.

    Every embedding request is sent through ``genai.embed_content`` with the
    configured model name, one request per text. This class does not
    configure the ``genai`` client itself; presumably the caller has already
    done so (e.g. set an API key) -- TODO confirm.
    """

    def __init__(
        self,
        model_name: str = 'models/text-embedding-004',
        **kwargs: Any,
    ) -> None:
        """Create the adapter.

        Args:
            model_name: Embedding model identifier, passed as ``model=`` to
                ``genai.embed_content``.
            **kwargs: Forwarded unchanged to ``BaseEmbedding.__init__``.
        """
        # BaseEmbedding declares a public ``model_name`` field; forward ours so
        # it reports the real model rather than the base-class default.
        super().__init__(model_name=model_name, **kwargs)
        # Kept for backward compatibility with code that reads the private name.
        self._model_name = model_name

    def gai_embed_content(self, text: str) -> Dict[str, Any]:
        """Call ``genai.embed_content`` for a single piece of text.

        Args:
            text: The content to embed.

        Returns:
            The raw response mapping from ``genai.embed_content``; the
            embedding vector is stored under the ``'embedding'`` key.
        """
        return genai.embed_content(model=self._model_name, content=text)

    def _gai_embedding_vector(self, text: str) -> List[float]:
        """Return only the embedding vector for *text*."""
        return self.gai_embed_content(text)['embedding']

    def _get_query_embedding(self, query: str) -> List[float]:
        """Embed a search query (same request shape as document text)."""
        return self._gai_embedding_vector(query)

    def _get_text_embedding(self, text: str) -> List[float]:
        """Embed a single document text."""
        return self._gai_embedding_vector(text)

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Embed each text with one request per item, preserving input order.

        An empty ``texts`` list yields an empty result without any request.
        """
        return [self._gai_embedding_vector(text) for text in texts]

    # NOTE(review): the async variants below call the blocking synchronous
    # methods directly, so each network request blocks the running event loop.
    async def _aget_query_embedding(self, query: str) -> List[float]:
        """Async entry point for query embedding (executes synchronously)."""
        return self._get_query_embedding(query)

    async def _aget_text_embedding(self, text: str) -> List[float]:
        """Async entry point for text embedding (executes synchronously)."""
        return self._get_text_embedding(text)