# Page-header residue from the web file viewer (not part of the module):
# anpigon's picture | add langchain docs | ed4d993 | raw | history blame | 9.45 kB
"""Wrapper around EdenAI's Generation API."""
import logging
from typing import Any, Dict, List, Literal, Optional
from aiohttp import ClientSession
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models.llms import LLM
from langchain_core.pydantic_v1 import Extra, Field, root_validator
from langchain_core.utils import get_from_dict_or_env
from langchain_community.llms.utils import enforce_stop_tokens
from langchain_community.utilities.requests import Requests
logger = logging.getLogger(__name__)
class EdenAI(LLM):
    """EdenAI models.

    To use, you should have the environment variable ``EDENAI_API_KEY`` set
    with your API token (or pass ``edenai_api_key`` explicitly).
    You can find your token here: https://app.edenai.run/admin/account/settings

    ``provider`` must be supplied. ``feature`` and ``subfeature`` default to
    ``"text"`` / ``"generation"``. Other model parameters can be passed through
    the dedicated fields (``temperature``, ``max_tokens``, ``resolution``) or,
    for backward compatibility, with ``params={model_param: value, ...}``.

    For the API reference check the EdenAI documentation: http://docs.edenai.co.
    """

    base_url: str = "https://api.edenai.run/v2"

    edenai_api_key: Optional[str] = None

    feature: Literal["text", "image"] = "text"
    """Which generative feature to use, use text by default"""

    subfeature: Literal["generation"] = "generation"
    """Subfeature of above feature, use generation by default"""

    provider: str
    """Generative provider to use (eg: openai,stabilityai,cohere,google etc.)"""

    model: Optional[str] = None
    """
    model name for above provider (eg: 'gpt-3.5-turbo-instruct' for openai)
    available models are shown on https://docs.edenai.co/ under 'available providers'
    """

    # Optional parameters to add depending of chosen feature
    # see api reference for more infos
    temperature: Optional[float] = Field(default=None, ge=0, le=1)  # for text
    max_tokens: Optional[int] = Field(default=None, ge=0)  # for text
    resolution: Optional[Literal["256x256", "512x512", "1024x1024"]] = None  # for image

    params: Dict[str, Any] = Field(default_factory=dict)
    """
    DEPRECATED: use temperature, max_tokens, resolution directly
    optional parameters to pass to api
    """

    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """extra parameters"""

    stop_sequences: Optional[List[str]] = None
    """Stop sequences to use."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Resolve the API key from the given values or ``EDENAI_API_KEY``."""
        values["edenai_api_key"] = get_from_dict_or_env(
            values, "edenai_api_key", "EDENAI_API_KEY"
        )
        return values

    @root_validator(pre=True)
    def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Build extra kwargs from additional params that were passed in.

        Any key that is not a declared field is moved into ``model_kwargs``
        (with a warning).

        Raises:
            ValueError: if such a key is already present in ``model_kwargs``.
        """
        all_required_field_names = {field.alias for field in cls.__fields__.values()}

        extra = values.get("model_kwargs", {})
        # Iterate over a snapshot of the keys because entries are popped below.
        for field_name in list(values):
            if field_name not in all_required_field_names:
                if field_name in extra:
                    raise ValueError(f"Found {field_name} supplied twice.")
                logger.warning(
                    f"{field_name} was transferred to model_kwargs.\n"
                    f"Please confirm that {field_name} is what you intended."
                )
                extra[field_name] = values.pop(field_name)
        values["model_kwargs"] = extra
        return values

    @property
    def _llm_type(self) -> str:
        """Return type of model."""
        return "edenai"

    def _format_output(self, output: dict) -> str:
        """Extract the generated text (or first image) for the configured provider."""
        if self.feature == "text":
            return output[self.provider]["generated_text"]
        else:
            return output[self.provider]["items"][0]["image"]

    @staticmethod
    def get_user_agent() -> str:
        """Return the ``User-Agent`` header value sent with every request."""
        from langchain_community import __version__

        return f"langchain/{__version__}"

    def _resolve_stops(self, stop: Optional[List[str]]) -> Optional[List[str]]:
        """Pick the stop sequences to apply for one call.

        Raises:
            ValueError: if stop sequences are given both on the instance and
                at call time.
        """
        if self.stop_sequences is not None and stop is not None:
            raise ValueError(
                "stop sequences found in both the input and default params."
            )
        if self.stop_sequences is not None:
            return self.stop_sequences
        return stop

    def _request_url(self) -> str:
        """Return the endpoint URL for the configured feature/subfeature."""
        return f"{self.base_url}/{self.feature}/{self.subfeature}"

    def _request_headers(self) -> Dict[str, str]:
        """Return the HTTP headers (auth + user agent) for a request."""
        return {
            "Authorization": f"Bearer {self.edenai_api_key}",
            "User-Agent": self.get_user_agent(),
        }

    def _build_payload(self, prompt: str, **kwargs: Any) -> Dict[str, Any]:
        """Assemble the JSON payload shared by the sync and async calls.

        Later entries override earlier ones: dedicated fields, then ``params``,
        then call-time ``kwargs``; ``num_images`` is always forced to 1.
        """
        # NOTE(review): ``model_kwargs`` is not merged into the payload here;
        # only ``params`` and call-time kwargs are forwarded.
        payload: Dict[str, Any] = {
            "providers": self.provider,
            "text": prompt,
            "max_tokens": self.max_tokens,
            "temperature": self.temperature,
            "resolution": self.resolution,
            **self.params,
            **kwargs,
            "num_images": 1,  # always limit to 1 (ignored for text)
        }
        # filter `None` values to not pass them to the http payload as null
        payload = {k: v for k, v in payload.items() if v is not None}
        if self.model is not None:
            payload["settings"] = {self.provider: self.model}
        return payload

    @staticmethod
    def _raise_for_status(status: int, body: str) -> None:
        """Raise the appropriate error for a non-200 HTTP status.

        Args:
            status: HTTP status code of the response.
            body: Response body text, included in 4xx/unexpected-status errors.

        Raises:
            ValueError: for 4xx statuses.
            Exception: for 5xx statuses and any other non-200 status.
        """
        if status >= 500:
            raise Exception(f"EdenAI Server: Error {status}")
        if status >= 400:
            raise ValueError(f"EdenAI received an invalid payload: {body}")
        if status != 200:
            raise Exception(
                f"EdenAI returned an unexpected response with status "
                f"{status}: {body}"
            )

    def _parse_response(self, data: dict, stops: Optional[List[str]]) -> str:
        """Validate the provider section of a response and return the output.

        Raises:
            Exception: if the provider reports ``status == "fail"``.
        """
        provider_response = data[self.provider]
        if provider_response.get("status") == "fail":
            # ``error`` may be present but null; fall back to an empty dict.
            err_msg = (provider_response.get("error") or {}).get("message")
            raise Exception(err_msg)

        output = self._format_output(data)
        # Truthiness check: an empty stop list means there is nothing to enforce.
        if stops:
            output = enforce_stop_tokens(output, stops)
        return output

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call out to EdenAI's generation endpoint.

        Args:
            prompt: The prompt to pass into the model.
            stop: A list of stop words (optional).
            run_manager: Callback manager for the run (unused here).
            **kwargs: Extra payload entries; override ``params`` and fields.

        Returns:
            The string generated by the model (or the image payload when
            ``feature`` is ``"image"``).

        Raises:
            ValueError: on conflicting stop sequences or a 4xx response.
            Exception: on a 5xx/unexpected status or a provider failure.
        """
        stops = self._resolve_stops(stop)
        payload = self._build_payload(prompt, **kwargs)

        request = Requests(headers=self._request_headers())
        response = request.post(url=self._request_url(), data=payload)

        if response.status_code != 200:
            self._raise_for_status(response.status_code, response.text)

        return self._parse_response(response.json(), stops)

    async def _acall(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call EdenAi model to get predictions based on the prompt.

        Args:
            prompt: The prompt to pass into the model.
            stop: A list of stop words (optional).
            run_manager: A callback manager for async interaction with LLMs
                (unused here).
            **kwargs: Extra payload entries; override ``params`` and fields.

        Returns:
            The string generated by the model.

        Raises:
            ValueError: on conflicting stop sequences or a 4xx response.
            Exception: on a 5xx/unexpected status or a provider failure.
        """
        stops = self._resolve_stops(stop)
        payload = self._build_payload(prompt, **kwargs)
        url = self._request_url()
        headers = self._request_headers()

        async with ClientSession() as session:
            async with session.post(url, json=payload, headers=headers) as response:
                status = response.status
                if status != 200:
                    # On aiohttp responses ``text`` is a coroutine method, so it
                    # must be awaited to get the body. The body is not used in
                    # the 5xx message, so skip reading it in that case.
                    body = "" if status >= 500 else await response.text()
                    self._raise_for_status(status, body)

                response_json = await response.json()

        return self._parse_response(response_json, stops)