Spaces:
Runtime error
Runtime error
from typing import Any, Dict, List, Optional | |
from langchain_core.embeddings import Embeddings | |
from langchain_core.pydantic_v1 import ( | |
BaseModel, | |
Extra, | |
Field, | |
SecretStr, | |
root_validator, | |
) | |
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env | |
from langchain_community.utilities.requests import Requests | |
class EdenAiEmbeddings(BaseModel, Embeddings):
    """Embeddings computed through the EdenAI text-embeddings API.

    To use, set the ``EDENAI_API_KEY`` environment variable to your API key,
    or pass it as the named parameter ``edenai_api_key``.

    Example:
        .. code-block:: python

            embeddings = EdenAiEmbeddings(provider="openai")
            vectors = embeddings.embed_documents(["hello", "world"])
    """

    edenai_api_key: Optional[SecretStr] = Field(None, description="EdenAI API Token")

    provider: str = "openai"
    """Embedding provider to use (e.g. openai, google, etc.)."""

    model: Optional[str] = None
    """Model name for the selected ``provider``.

    When ``None``, no ``settings`` entry is sent with the request (EdenAI then
    presumably falls back to its default model for the provider). Available
    models are shown on https://docs.edenai.co/ under 'available providers'.
    """

    class Config:
        """Configuration for this pydantic object."""

        # Reject unknown constructor keyword arguments.
        extra = Extra.forbid

    @root_validator(pre=False, skip_on_failure=True)
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that an API key is available and store it as a secret.

        The key is read from ``values["edenai_api_key"]`` or, failing that,
        from the ``EDENAI_API_KEY`` environment variable;
        ``get_from_dict_or_env`` raises if neither provides one.
        """
        values["edenai_api_key"] = convert_to_secret_str(
            get_from_dict_or_env(values, "edenai_api_key", "EDENAI_API_KEY")
        )
        return values

    @staticmethod
    def get_user_agent() -> str:
        """Return the ``User-Agent`` header value sent with every request."""
        # Imported lazily, presumably to avoid an import cycle at module load.
        from langchain_community import __version__

        return f"langchain/{__version__}"

    @staticmethod
    def _raise_for_status(response: Any) -> None:
        """Raise unless ``response`` carries HTTP status 200.

        Raises:
            ValueError: For 4xx statuses (EdenAI rejected the request).
            Exception: For 5xx statuses or any other non-200 status.
        """
        status = response.status_code
        if status >= 500:
            raise Exception(f"EdenAI Server: Error {status}")
        if status >= 400:
            raise ValueError(f"EdenAI received an invalid payload: {response.text}")
        if status != 200:
            raise Exception(
                f"EdenAI returned an unexpected response with status "
                f"{status}: {response.text}"
            )

    def _generate_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Compute embeddings for ``texts`` with a single EdenAI API call.

        Args:
            texts: The texts to embed.

        Returns:
            The ``embedding`` of each item the provider returned, in response
            order. An empty ``texts`` list yields ``[]`` without calling the API.

        Raises:
            ValueError: If EdenAI answers with a 4xx status.
            Exception: On a 5xx or other non-200 status, or when the provider
                reports ``"status": "fail"``.
        """
        if not texts:
            # Nothing to embed; skip the network round trip.
            return []

        url = "https://api.edenai.run/v2/text/embeddings"
        headers = {
            "accept": "application/json",
            "content-type": "application/json",
            "authorization": f"Bearer {self.edenai_api_key.get_secret_value()}",  # type: ignore[union-attr]
            "User-Agent": self.get_user_agent(),
        }
        payload: Dict[str, Any] = {"texts": texts, "providers": self.provider}
        if self.model is not None:
            # A specific model is selected per provider through ``settings``.
            payload["settings"] = {self.provider: self.model}

        response = Requests(headers=headers).post(url=url, data=payload)
        self._raise_for_status(response)

        # The JSON body is keyed by provider name.
        provider_response = response.json()[self.provider]
        if provider_response.get("status") == "fail":
            # ``error`` may be missing, null, a dict, or a bare string; avoid
            # an AttributeError and never raise an Exception with a None message.
            error = provider_response.get("error")
            err_msg = error.get("message") if isinstance(error, dict) else error
            raise Exception(
                err_msg or f"EdenAI provider '{self.provider}' reported a failure"
            )

        return [item["embedding"] for item in provider_response["items"]]

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a list of documents using EdenAI.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """
        return self._generate_embeddings(texts)

    def embed_query(self, text: str) -> List[float]:
        """Embed a query using EdenAI.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.
        """
        return self._generate_embeddings([text])[0]