from typing import Any, Dict, List, Optional, Union

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import BaseLLM
from langchain_core.outputs import Generation, LLMResult
from langchain_core.pydantic_v1 import Field, root_validator


class CTranslate2(BaseLLM):
    """CTranslate2 language model."""

    model_path: str = ""
    """Path to the CTranslate2 model directory."""

    tokenizer_name: str = ""
    """Name of the original Hugging Face model needed to load the proper tokenizer."""

    device: str = "cpu"
    """Device to use (possible values are: cpu, cuda, auto)."""

    device_index: Union[int, List[int]] = 0
    """Device ID(s) on which to place this generator."""

    compute_type: Union[str, Dict[str, str]] = "default"
    """
    Model computation type or a dictionary mapping a device name to the
    computation type (possible values are: default, auto, int8, int8_float32,
    int8_float16, int8_bfloat16, int16, float16, bfloat16, float32).
    """

    max_length: int = 512
    """Maximum generation length."""

    sampling_topk: int = 1
    """Randomly sample predictions from the top K candidates."""

    sampling_topp: float = 1
    """Keep the most probable tokens whose cumulative probability exceeds this value."""

    sampling_temperature: float = 1
    """Sampling temperature to generate more random samples."""

    client: Any  #: :meta private:

    tokenizer: Any  #: :meta private:

    ctranslate2_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """
    Holds any model parameters valid for the `ctranslate2.Generator` call not
    explicitly specified.
    """
    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the required python packages exist in the environment."""
        try:
            import ctranslate2
        except ImportError:
            raise ImportError(
                "Could not import ctranslate2 python package. "
                "Please install it with `pip install ctranslate2`."
            )

        try:
            import transformers
        except ImportError:
            raise ImportError(
                "Could not import transformers python package. "
                "Please install it with `pip install transformers`."
            )

        values["client"] = ctranslate2.Generator(
            model_path=values["model_path"],
            device=values["device"],
            device_index=values["device_index"],
            compute_type=values["compute_type"],
            **values["ctranslate2_kwargs"],
        )

        values["tokenizer"] = transformers.AutoTokenizer.from_pretrained(
            values["tokenizer_name"]
        )

        return values
    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters."""
        return {
            "max_length": self.max_length,
            "sampling_topk": self.sampling_topk,
            "sampling_topp": self.sampling_topp,
            "sampling_temperature": self.sampling_temperature,
        }
    def _generate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
        # build sampling parameters, letting call-time kwargs override defaults
        params = {**self._default_params, **kwargs}

        # CTranslate2 expects token strings, so encode the prompts to IDs and
        # then convert the IDs to tokens before calling the model
        encoded_prompts = self.tokenizer(prompts)["input_ids"]
        tokenized_prompts = [
            self.tokenizer.convert_ids_to_tokens(encoded_prompt)
            for encoded_prompt in encoded_prompts
        ]

        results = self.client.generate_batch(tokenized_prompts, **params)

        # keep the first hypothesis of each result and decode it back to text
        sequences = [result.sequences_ids[0] for result in results]
        decoded_sequences = [self.tokenizer.decode(seq) for seq in sequences]

        generations = []
        for text in decoded_sequences:
            generations.append([Generation(text=text)])

        return LLMResult(generations=generations)
    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "ctranslate2"