Spaces:
Runtime error
Runtime error
File size: 4,356 Bytes
ed4d993 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 |
from typing import Any, Dict, List, Optional, Union, cast
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.pydantic_v1 import Extra, SecretStr, root_validator
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
from langchain_community.utilities.arcee import ArceeWrapper, DALMFilter
class Arcee(LLM):
    """Arcee's Domain Adapted Language Models (DALMs).

    To use, set the ``ARCEE_API_KEY`` environment variable with your Arcee API key,
    or pass ``arcee_api_key`` as a named parameter.

    Example:
        .. code-block:: python

            from langchain_community.llms import Arcee

            arcee = Arcee(
                model="DALM-PubMed",
                arcee_api_key="ARCEE-API-KEY"
            )

            response = arcee("AI-driven music therapy")
    """

    _client: Optional[ArceeWrapper] = None  #: :meta private:
    """Arcee _client."""

    arcee_api_key: Union[SecretStr, str, None] = None
    """Arcee API Key"""

    model: str
    """Arcee DALM name"""

    arcee_api_url: str = "https://api.arcee.ai"
    """Arcee API URL"""

    arcee_api_version: str = "v2"
    """Arcee API Version"""

    arcee_app_url: str = "https://app.arcee.ai"
    """Arcee App URL"""

    model_id: str = ""
    """Arcee Model ID"""

    model_kwargs: Optional[Dict[str, Any]] = None
    """Keyword arguments to pass to the model."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        underscore_attrs_are_private = True

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "arcee"

    def __init__(self, **data: Any) -> None:
        """Validate the fields, then build the private ``ArceeWrapper`` client.

        Args:
            **data: Field values for this model (``model``, ``arcee_api_key``,
                ``arcee_api_url``, ``arcee_api_version``, ``model_kwargs``, ...).
        """
        super().__init__(**data)

        # ``validate_environments`` stores the key via ``convert_to_secret_str``;
        # the cast only narrows the declared Union for type checkers.
        api_key = cast(SecretStr, self.arcee_api_key)
        self._client = ArceeWrapper(
            arcee_api_key=api_key,
            arcee_api_url=self.arcee_api_url,
            arcee_api_version=self.arcee_api_version,
            model_kwargs=self.model_kwargs,
            model_name=self.model,
        )

    @root_validator(pre=False)
    def validate_environments(cls, values: Dict) -> Dict:
        """Validate Arcee environment variables and ``model_kwargs``.

        Each setting is read from ``values`` first and from the matching
        environment variable otherwise (via ``get_from_dict_or_env``).

        Args:
            values: Field values collected for the model.

        Returns:
            The same ``values`` dict with the API key wrapped by
            ``convert_to_secret_str`` and the URL/version entries resolved.

        Raises:
            ValueError: If ``model_kwargs["size"]`` is negative or
                ``model_kwargs["filters"]`` is not a list.
        """
        # validate env vars
        values["arcee_api_key"] = convert_to_secret_str(
            get_from_dict_or_env(
                values,
                "arcee_api_key",
                "ARCEE_API_KEY",
            )
        )

        for field_name, env_name in (
            ("arcee_api_url", "ARCEE_API_URL"),
            ("arcee_app_url", "ARCEE_APP_URL"),
            ("arcee_api_version", "ARCEE_API_VERSION"),
        ):
            values[field_name] = get_from_dict_or_env(values, field_name, env_name)

        # validate model kwargs
        kw = values.get("model_kwargs")
        if kw:
            # validate size: zero is accepted, so the message says
            # "non-negative" rather than "positive".
            size = kw.get("size")
            if size is not None and not size >= 0:
                raise ValueError("`size` must be non-negative")

            # validate filters
            filters = kw.get("filters")
            if filters is not None:
                if not isinstance(filters, list):
                    raise ValueError("`filters` must be a list")
                for f in filters:
                    # Constructed only so that DALMFilter can reject a bad
                    # filter spec (presumably by raising); result is discarded.
                    DALMFilter(**f)

        return values

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Generate text from Arcee DALM.

        Note:
            ``stop`` and ``run_manager`` are accepted for interface
            compatibility but are not used by this implementation.

        Args:
            prompt: Prompt to generate text from.
            stop: Unused.
            run_manager: Unused.
            **kwargs: Forwarded to the client's ``generate`` call, e.g.:
                size: The max number of context results to retrieve.
                    Defaults to 3. (Can be less if filters are provided).
                filters: Filters to apply to the context dataset.

        Returns:
            The value returned by the client's ``generate`` call.

        Raises:
            Exception: If the client is missing or generation fails; the
                original error is chained as ``__cause__``.
        """
        try:
            if not self._client:
                raise ValueError("Client is not initialized.")
            return self._client.generate(prompt=prompt, **kwargs)
        except Exception as e:
            raise Exception(f"Failed to generate text: {e}") from e
|