import asyncio
from functools import partial
from typing import (
    Any,
    List,
    Mapping,
    Optional,
)

from ai21.models import CompletionsResponse
from langchain_core.callbacks import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain_core.language_models import BaseLLM
from langchain_core.outputs import Generation, LLMResult

from langchain_ai21.ai21_base import AI21Base


class AI21LLM(BaseLLM, AI21Base):
    """AI21LLM large language models.

    Example:
        .. code-block:: python

            from langchain_ai21 import AI21LLM

            model = AI21LLM()
    """

    model: str
    """Model type you wish to interact with.
    You can view the options at
    https://github.com/AI21Labs/ai21-python?tab=readme-ov-file#model-types"""

    num_results: int = 1
    """The number of responses to generate for a given prompt."""

    max_tokens: int = 16
    """The maximum number of tokens to generate for each response."""

    min_tokens: int = 0
    """The minimum number of tokens to generate for each response."""

    temperature: float = 0.7
    """A value controlling the "creativity" of the model's responses."""

    top_p: float = 1
    """A value controlling the diversity of the model's responses."""

    top_k_return: int = 0
    """The number of top-scoring tokens to consider for each generation step."""

    frequency_penalty: Optional[Any] = None
    """A penalty applied to tokens that are frequently generated."""

    presence_penalty: Optional[Any] = None
    """A penalty applied to tokens that are already present in the prompt."""

    count_penalty: Optional[Any] = None
    """A penalty applied to tokens based on their frequency
    in the generated responses."""
    custom_model: Optional[str] = None
    """The name of a custom (fine-tuned) model to use, if any."""

    epoch: Optional[int] = None
    """The epoch of the custom model to use, if applicable."""

    class Config:
        """Configuration for this pydantic object."""

        allow_population_by_field_name = True

    @property
    def _llm_type(self) -> str:
        """Return type of LLM."""
        return "ai21-llm"

    @property
    def _default_params(self) -> Mapping[str, Any]:
        base_params = {
            "model": self.model,
            "num_results": self.num_results,
            "max_tokens": self.max_tokens,
            "min_tokens": self.min_tokens,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "top_k_return": self.top_k_return,
        }

        # Penalty objects are serialized to plain dicts before being sent to
        # the API, and are only included when explicitly configured.
        if self.count_penalty is not None:
            base_params["count_penalty"] = self.count_penalty.to_dict()

        if self.custom_model is not None:
            base_params["custom_model"] = self.custom_model

        if self.epoch is not None:
            base_params["epoch"] = self.epoch

        if self.frequency_penalty is not None:
            base_params["frequency_penalty"] = self.frequency_penalty.to_dict()

        if self.presence_penalty is not None:
            base_params["presence_penalty"] = self.presence_penalty.to_dict()

        return base_params

    def _build_params_for_request(
        self, stop: Optional[List[str]] = None, **kwargs: Any
    ) -> Mapping[str, Any]:
        params = {}

        if stop is not None:
            if "stop" in kwargs:
                raise ValueError("stop is defined in both stop and kwargs")
            # The AI21 API calls this parameter "stop_sequences".
            params["stop_sequences"] = stop

        return {
            **self._default_params,
            **params,
            **kwargs,
        }

    def _generate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
        generations: List[List[Generation]] = []
        token_count = 0

        params = self._build_params_for_request(stop=stop, **kwargs)

        # The Completions API takes a single prompt, so each prompt in the
        # batch is sent as its own request.
        for prompt in prompts:
            response = self._invoke_completion(prompt=prompt, **params)
            generation = self._response_to_generation(response)
            generations.append(generation)
            token_count += self.client.count_tokens(prompt)

        llm_output = {"token_count": token_count, "model_name": self.model}
        return LLMResult(generations=generations, llm_output=llm_output)

    async def _agenerate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
        # The integration has no native async support, so the sync
        # implementation is delegated to a thread-pool executor. Change this
        # if the underlying client gains native async generation.
        return await asyncio.get_running_loop().run_in_executor(
            None, partial(self._generate, **kwargs), prompts, stop, run_manager
        )

    def _invoke_completion(
        self,
        prompt: str,
        **kwargs: Any,
    ) -> CompletionsResponse:
        return self.client.completion.create(
            prompt=prompt,
            **kwargs,
        )

    def _response_to_generation(
        self, response: CompletionsResponse
    ) -> List[Generation]:
        # Each completion contributes its text plus the full raw payload as
        # generation_info, so callers can inspect provider-specific fields.
        return [
            Generation(
                text=completion.data.text,
                generation_info=completion.to_dict(),
            )
            for completion in response.completions
        ]
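

# A minimal usage sketch (an addition, not part of the original module). It
# assumes an AI21 API key is available to AI21Base (e.g. via the AI21_API_KEY
# environment variable) and that "j2-ultra" is an available model name; both
# are assumptions about the deployment, not guarantees made by this module.
if __name__ == "__main__":
    llm = AI21LLM(model="j2-ultra", max_tokens=32, temperature=0.7)

    # BaseLLM provides the standard Runnable interface, so `invoke` works out
    # of the box; `agenerate`/`ainvoke` also work, because _agenerate falls
    # back to running _generate in a thread-pool executor.
    print(llm.invoke("Write a one-line summary of what an LLM is."))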