rag / gemini_embedding.py
jessica45's picture
initial commit
8953dfc verified
raw
history blame contribute delete
749 Bytes
import os
import google.generativeai as genai
from chromadb.api.types import Documents, Embeddings
from chromadb import EmbeddingFunction
from dotenv import load_dotenv
load_dotenv()
# Read the key leniently: os.environ[...] would raise a bare KeyError at import
# time when GEMINI_API_KEY is unset, which made the descriptive ValueError in
# GeminiEmbeddingFunction.__call__ unreachable for the "not set" case. With
# .get() a missing variable yields None and that check reports it on first use.
gemini_api_key = os.environ.get("GEMINI_API_KEY")
class GeminiEmbeddingFunction(EmbeddingFunction):
    """
    Custom embedding function using Gemini AI API.

    Chroma calls an instance with a batch of documents and expects one
    embedding vector back per document.

    The model and task type default to the values this class originally
    hard-coded ("models/text-embedding-004", "retrieval_document"), so
    ``GeminiEmbeddingFunction()`` behaves exactly as before. They can now be
    overridden. For example, pass ``task_type="retrieval_query"`` to the
    instance used for embedding search queries; the Gemini embedding API
    accepts distinct task types for documents and queries.
    """

    DEFAULT_MODEL = "models/text-embedding-004"
    DEFAULT_TASK_TYPE = "retrieval_document"

    def __init__(
        self,
        model_name: str = DEFAULT_MODEL,
        task_type: str = DEFAULT_TASK_TYPE,
    ) -> None:
        """Store the Gemini model name and embedding task type to use.

        Args:
            model_name: Gemini embedding model passed to ``genai.embed_content``.
            task_type: Task type passed to ``genai.embed_content``.
        """
        self.model_name = model_name
        self.task_type = task_type

    def __call__(self, input: Documents) -> Embeddings:
        """Embed ``input`` with Gemini and return the resulting vectors.

        Args:
            input: The documents Chroma asks to embed. The name ``input`` is
                kept because it is the parameter name of Chroma's
                ``EmbeddingFunction.__call__`` interface.

        Returns:
            The ``"embedding"`` entry of the ``genai.embed_content`` response.

        Raises:
            ValueError: If the module-level ``gemini_api_key`` is missing/empty.
        """
        if not gemini_api_key:
            raise ValueError("Gemini API Key not provided. Please set GEMINI_API_KEY as an environment variable.")
        # Configured per call (as in the original) rather than at import time,
        # so the key check above always runs before any API use.
        genai.configure(api_key=gemini_api_key)
        response = genai.embed_content(
            model=self.model_name,
            content=input,
            task_type=self.task_type,
        )
        return response["embedding"]