from typing import Any, Dict, List, Optional, Union

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import BaseLLM
from langchain_core.outputs import Generation, LLMResult
from langchain_core.pydantic_v1 import Field, root_validator


class CTranslate2(BaseLLM):
    """CTranslate2 language model."""

    model_path: str = ""
    """Path to the CTranslate2 model directory."""

    tokenizer_name: str = ""
    """Name of the original Hugging Face model needed to load the proper tokenizer."""

    device: str = "cpu"
    """Device to use (possible values are: cpu, cuda, auto)."""

    device_index: Union[int, List[int]] = 0
    """Device IDs where to place this generator on."""

    compute_type: Union[str, Dict[str, str]] = "default"
    """
    Model computation type or a dictionary mapping a device name to the computation type
    (possible values are: default, auto, int8, int8_float32, int8_float16,
    int8_bfloat16, int16, float16, bfloat16, float32).
    """

    max_length: int = 512
    """Maximum generation length."""

    sampling_topk: int = 1
    """Randomly sample predictions from the top K candidates."""

    sampling_topp: float = 1.0
    """Keep the most probable tokens whose cumulative probability exceeds this value."""

    sampling_temperature: float = 1.0
    """Sampling temperature to generate more random samples."""

    client: Any  #: :meta private:

    tokenizer: Any  #: :meta private:

    ctranslate2_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """
    Holds any parameters valid for the `ctranslate2.Generator` constructor that
    are not explicitly specified.
    """

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that python package exists in environment."""

        try:
            import ctranslate2
        except ImportError:
            raise ImportError(
                "Could not import ctranslate2 python package. "
                "Please install it with `pip install ctranslate2`."
            )

        try:
            import transformers
        except ImportError:
            raise ImportError(
                "Could not import transformers python package. "
                "Please install it with `pip install transformers`."
            )

        values["client"] = ctranslate2.Generator(
            model_path=values["model_path"],
            device=values["device"],
            device_index=values["device_index"],
            compute_type=values["compute_type"],
            **values["ctranslate2_kwargs"],
        )

        values["tokenizer"] = transformers.AutoTokenizer.from_pretrained(
            values["tokenizer_name"]
        )

        return values

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters."""
        return {
            "max_length": self.max_length,
            "sampling_topk": self.sampling_topk,
            "sampling_topp": self.sampling_topp,
            "sampling_temperature": self.sampling_temperature,
        }

    def _generate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
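        # NOTE: ``stop`` sequences and ``run_manager`` callbacks are accepted
        # for interface compatibility but are not applied in this implementation.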
        # merge default sampling parameters with any call-time overrides
        params = {**self._default_params, **kwargs}

        # CTranslate2 works on token strings rather than token IDs, so encode
        # each prompt and map the IDs back to their token strings
        encoded_prompts = self.tokenizer(prompts)["input_ids"]
        tokenized_prompts = [
            self.tokenizer.convert_ids_to_tokens(encoded_prompt)
            for encoded_prompt in encoded_prompts
        ]

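        # generate_batch consumes batches of start tokens and returns one
        # ctranslate2.GenerationResult per prompt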
        results = self.client.generate_batch(tokenized_prompts, **params)

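        # keep the first hypothesis (beam_size defaults to 1); note that
        # CTranslate2 includes the prompt tokens in the result by default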
        sequences = [result.sequences_ids[0] for result in results]
        decoded_sequences = [self.tokenizer.decode(seq) for seq in sequences]

        # wrap each decoded text as a single-candidate generation
        generations = [[Generation(text=text)] for text in decoded_sequences]

        return LLMResult(generations=generations)

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "ctranslate2"