Spaces:
Running
Running
File size: 5,238 Bytes
ed4d993 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 |
import asyncio
from functools import partial
from typing import (
Any,
List,
Mapping,
Optional,
)
from ai21.models import CompletionsResponse
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models import BaseLLM
from langchain_core.outputs import Generation, LLMResult
from langchain_ai21.ai21_base import AI21Base
class AI21LLM(BaseLLM, AI21Base):
    """AI21LLM large language models.

    Example:
        .. code-block:: python

            from langchain_ai21 import AI21LLM

            # ``model`` has no default, so it must be supplied.
            model = AI21LLM(model="j2-ultra")
    """

    model: str
    """Model type you wish to interact with.
    You can view the options at https://github.com/AI21Labs/ai21-python?tab=readme-ov-file#model-types"""

    num_results: int = 1
    """The number of responses to generate for a given prompt."""

    max_tokens: int = 16
    """The maximum number of tokens to generate for each response."""

    min_tokens: int = 0
    """The minimum number of tokens to generate for each response."""

    temperature: float = 0.7
    """A value controlling the "creativity" of the model's responses."""

    top_p: float = 1
    """A value controlling the diversity of the model's responses."""

    top_k_return: int = 0
    """The number of top-scoring tokens to consider for each generation step."""

    frequency_penalty: Optional[Any] = None
    """A penalty applied to tokens that are frequently generated."""

    presence_penalty: Optional[Any] = None
    """A penalty applied to tokens that are already present in the prompt."""

    count_penalty: Optional[Any] = None
    """A penalty applied to tokens based on their frequency
    in the generated responses."""

    # The two fields below are forwarded unchanged in the request parameters
    # whenever they are set (see ``_default_params``).
    custom_model: Optional[str] = None
    epoch: Optional[int] = None

    class Config:
        """Configuration for this pydantic object."""

        allow_population_by_field_name = True

    @property
    def _llm_type(self) -> str:
        """Return type of LLM."""
        return "ai21-llm"

    @property
    def _default_params(self) -> Mapping[str, Any]:
        """Return the request parameters derived from this instance's fields.

        The always-present sampling fields are included unconditionally.
        Optional fields are added only when they are not ``None``; the three
        penalty fields are serialized through their ``to_dict()`` method.
        """
        base_params = {
            "model": self.model,
            "num_results": self.num_results,
            "max_tokens": self.max_tokens,
            "min_tokens": self.min_tokens,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "top_k_return": self.top_k_return,
        }

        if self.count_penalty is not None:
            base_params["count_penalty"] = self.count_penalty.to_dict()

        if self.custom_model is not None:
            base_params["custom_model"] = self.custom_model

        if self.epoch is not None:
            base_params["epoch"] = self.epoch

        if self.frequency_penalty is not None:
            base_params["frequency_penalty"] = self.frequency_penalty.to_dict()

        if self.presence_penalty is not None:
            base_params["presence_penalty"] = self.presence_penalty.to_dict()

        return base_params

    def _build_params_for_request(
        self, stop: Optional[List[str]] = None, **kwargs: Any
    ) -> Mapping[str, Any]:
        """Merge defaults, the stop list and per-call overrides into one mapping.

        Args:
            stop: Optional stop sequences; forwarded under the key
                ``stop_sequences``.
            **kwargs: Per-call overrides. They are merged last and therefore
                take precedence over the instance defaults.

        Returns:
            The keyword arguments to use for a single completion request.

        Raises:
            ValueError: If ``stop`` is given and ``kwargs`` also carries
                ``stop`` or ``stop_sequences``.
        """
        params = {}

        if stop is not None:
            # ``stop`` is a named parameter, so Python never places it in
            # ``kwargs``. The key that can really collide is
            # ``stop_sequences``: without this check it would silently
            # replace the caller's ``stop`` in the merge below.
            if "stop" in kwargs or "stop_sequences" in kwargs:
                raise ValueError("stop is defined in both stop and kwargs")
            params["stop_sequences"] = stop

        # Later entries win: defaults < stop list < explicit per-call kwargs.
        return {
            **self._default_params,
            **params,
            **kwargs,
        }

    def _generate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
        """Run one completion request per prompt and collect the results.

        Args:
            prompts: Prompts to complete, processed sequentially.
            stop: Optional stop sequences applied to every prompt.
            run_manager: Accepted for interface compatibility; not used here.
            **kwargs: Per-call overrides of the request parameters.

        Returns:
            An ``LLMResult`` with one list of generations per prompt.
            ``llm_output`` holds ``model_name`` and ``token_count``, the sum of
            ``self.client.count_tokens(prompt)`` over all prompts.
        """
        generations: List[List[Generation]] = []
        token_count = 0

        # The request parameters do not depend on the prompt; build them once.
        params = self._build_params_for_request(stop=stop, **kwargs)

        for prompt in prompts:
            response = self._invoke_completion(prompt=prompt, **params)
            generation = self._response_to_generation(response)
            generations.append(generation)
            token_count += self.client.count_tokens(prompt)

        llm_output = {"token_count": token_count, "model_name": self.model}
        return LLMResult(generations=generations, llm_output=llm_output)

    async def _agenerate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
        """Asynchronous counterpart of ``_generate``.

        Delegates to the synchronous ``_generate`` on the running loop's
        default executor, so the event loop is not blocked while the
        requests are in flight. Arguments and return value are the same as
        for ``_generate``.
        """
        # Change implementation if integration natively supports async generation.
        # ``run_in_executor`` only forwards positional arguments, hence the
        # ``partial`` to carry ``kwargs``.
        return await asyncio.get_running_loop().run_in_executor(
            None, partial(self._generate, **kwargs), prompts, stop, run_manager
        )

    def _invoke_completion(
        self,
        prompt: str,
        **kwargs: Any,
    ) -> CompletionsResponse:
        """Send a single completion request through the AI21 client.

        Args:
            prompt: The prompt text.
            **kwargs: Request parameters passed straight to
                ``self.client.completion.create``.

        Returns:
            The client's response object.
        """
        return self.client.completion.create(
            prompt=prompt,
            **kwargs,
        )

    def _response_to_generation(
        self, response: CompletionsResponse
    ) -> List[Generation]:
        """Convert a completions response into LangChain generations.

        Args:
            response: Response returned by ``_invoke_completion``.

        Returns:
            One ``Generation`` per entry in ``response.completions``, with the
            completion text as ``text`` and the completion's ``to_dict()``
            output as ``generation_info``.
        """
        return [
            Generation(
                text=completion.data.text,
                generation_info=completion.to_dict(),
            )
            for completion in response.completions
        ]
|