|
import abc |
|
import asyncio |
|
import dataclasses |
|
import json |
|
import logging |
|
import os |
|
import re |
|
import sys |
|
import time |
|
import uuid |
|
from collections import Counter |
|
from typing import ( |
|
Any, |
|
Dict, |
|
Iterable, |
|
List, |
|
Literal, |
|
Mapping, |
|
Optional, |
|
Sequence, |
|
Tuple, |
|
TypedDict, |
|
Union, |
|
) |
|
|
|
from datasets import DatasetDict |
|
from tqdm import tqdm, trange |
|
from tqdm.asyncio import tqdm_asyncio |
|
|
|
from .artifact import Artifact |
|
from .dataclass import InternalField, NonPositionalField |
|
from .deprecation_utils import deprecation |
|
from .error_utils import UnitxtError |
|
from .image_operators import EncodeImageToString, data_url_to_image, extract_images |
|
from .logging_utils import get_logger |
|
from .operator import PackageRequirementsMixin |
|
from .operators import ArtifactFetcherMixin |
|
from .settings_utils import get_constants, get_settings |
|
from .type_utils import isoftype |
|
|
|
constants = get_constants() |
|
settings = get_settings() |
|
logger = get_logger() |
|
|
|
|
|
class StandardAPIParamsMixin(Artifact): |
|
model: str |
|
frequency_penalty: Optional[float] = None |
|
presence_penalty: Optional[float] = None |
|
max_tokens: Optional[int] = None |
|
seed: Optional[int] = None |
|
stop: Union[Optional[str], List[str]] = None |
|
temperature: Optional[float] = None |
|
top_p: Optional[float] = None |
|
top_logprobs: Optional[int] = None |
|
logit_bias: Optional[Dict[str, int]] = None |
|
logprobs: Optional[bool] = None |
|
n: Optional[int] = None |
|
parallel_tool_calls: Optional[bool] = None |
|
service_tier: Optional[Literal["auto", "default"]] = None |
|
|
|
|
|
def get_model_and_label_id(model_name, label): |
|
model_id = model_name.split("/")[-1].replace("-", "_").replace(".", ",").lower() |
|
return f"{model_id}_{label}" |
|
|
|
|
|
@dataclasses.dataclass |
|
class TextGenerationInferenceOutput: |
|
"""Contains the prediction results and metadata for the inference. |
|
|
|
Args: |
|
prediction (Union[str, List[Dict[str, Any]]]): If this is the result of an _infer call, the string predicted by the model. |
|
If this is the results of an _infer_log_probs call, a list of dictionaries. The i'th dictionary represents |
|
the i'th token in the response. The entry "top_tokens" in the dictionary holds a sorted list of the top tokens |
|
for this position and their probabilities. |
|
For example: [ {.. "top_tokens": [ {"text": "a", 'logprob': }, {"text": "b", 'logprob': } ....]}, |
|
{.. "top_tokens": [ {"text": "c", 'logprob': }, {"text": "d", 'logprob': } ....]} |
|
] |
|
|
|
input_tokens (int) : number of input tokens to the model. |
|
output_tokens (int) : number of output tokens to the model. |
|
stop_reason (str): stop reason for text generation, for example "eos" (end of string). |
|
seed (int): seed used by the model during generation. |
|
input_text (str): input to the model. |
|
model_name (str): the model_name as kept in the InferenceEngine. |
|
inference_type (str): The label stating the type of the InferenceEngine. |
|
""" |
|
|
|
prediction: Union[str, List[Dict[str, Any]]] |
|
input_tokens: Optional[int] = None |
|
output_tokens: Optional[int] = None |
|
stop_reason: Optional[str] = None |
|
seed: Optional[int] = None |
|
input_text: Optional[str] = None |
|
model_name: Optional[str] = None |
|
inference_type: Optional[str] = None |
|
|
|
|
|
class InferenceEngine(Artifact): |
|
"""Abstract base class for inference.""" |
|
|
|
@abc.abstractmethod |
|
def _infer( |
|
self, |
|
dataset: Union[List[Dict[str, Any]], DatasetDict], |
|
return_meta_data: bool = False, |
|
) -> Union[List[str], List[TextGenerationInferenceOutput]]: |
|
"""Perform inference on the input dataset. |
|
|
|
If return_meta_data - returns a list of TextGenerationInferenceOutput, else returns a list of the string. |
|
return_meta_data is only supported for some InferenceEngines. |
|
predictions. |
|
""" |
|
pass |
|
|
|
@abc.abstractmethod |
|
def prepare_engine(self): |
|
"""Perform inference on the input dataset.""" |
|
pass |
|
|
|
def prepare(self): |
|
if not settings.mock_inference_mode: |
|
super().prepare() |
|
self.prepare_engine() |
|
|
|
def infer( |
|
self, |
|
dataset: Union[List[Dict[str, Any]], DatasetDict], |
|
return_meta_data: bool = False, |
|
) -> Union[List[str], List[TextGenerationInferenceOutput]]: |
|
"""Verifies instances of a dataset and perform inference on the input dataset. |
|
|
|
If return_meta_data - returns a list of TextGenerationInferenceOutput, else returns a list of the string |
|
predictions. |
|
""" |
|
if return_meta_data and not hasattr(self, "get_return_object"): |
|
raise NotImplementedError( |
|
f"Inference engine {self.__class__.__name__} does not support return_meta_data as it " |
|
f"does not contain a 'get_return_object' method. Please set return_meta_data=False." |
|
) |
|
|
|
[self.verify_instance(instance) for instance in dataset] |
|
if settings.mock_inference_mode: |
|
return self._mock_infer(dataset) |
|
return self._infer(dataset, return_meta_data) |
|
|
|
def _mock_infer( |
|
self, |
|
dataset: Union[List[Dict[str, Any]], DatasetDict], |
|
) -> Union[List[str], List[TextGenerationInferenceOutput]]: |
|
return [str(instance["source"]) for instance in dataset] |
|
|
|
def get_engine_id(self): |
|
raise NotImplementedError() |
|
|
|
@deprecation(version="2.0.0") |
|
def _set_inference_parameters(self): |
|
"""Sets inference parameters of an instance based on 'parameters' attribute (if given).""" |
|
if hasattr(self, "parameters") and self.parameters is not None: |
|
get_logger().warning( |
|
f"The 'parameters' attribute of '{self.get_pretty_print_name()}' " |
|
f"is deprecated. Please pass inference parameters directly to the " |
|
f"inference engine instance instead." |
|
) |
|
|
|
for param, param_dict_val in self.parameters.to_dict( |
|
[self.parameters] |
|
).items(): |
|
param_inst_val = getattr(self, param) |
|
if param_inst_val is None: |
|
setattr(self, param, param_dict_val) |
|
|
|
def get_model_details(self) -> Dict: |
|
"""Might not be possible to implement for all inference engines. Returns an empty dict by default.""" |
|
return {} |
|
|
|
def verify_not_chat_api(self, dataset): |
|
if isinstance(dataset[0]["source"], list): |
|
raise NotImplementedError( |
|
f"Inference engine {self.__class__.__name__} does not support chat api format." |
|
) |
|
|
|
def to_messages(self, instance): |
|
if isinstance(instance["source"], list): |
|
return instance["source"] |
|
return [ |
|
{ |
|
"role": "user", |
|
"content": instance["source"], |
|
} |
|
] |
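
# The infer() contract shared by all concrete subclasses of InferenceEngine
# (a minimal sketch; "MyEngine" is a placeholder, not a real class in this module):
#
#   engine = MyEngine(...)
#   predictions = engine.infer([{"source": "What is the capital of France?"}])
#   outputs = engine.infer([{"source": "..."}], return_meta_data=True)
#   # with return_meta_data=True, each element is a TextGenerationInferenceOutput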
|
|
|
|
|
class LogProbInferenceEngine(abc.ABC, Artifact): |
|
"""Abstract base class for inference with log probs.""" |
|
|
|
@abc.abstractmethod |
|
def _infer_log_probs( |
|
self, |
|
dataset: Union[List[Dict[str, Any]], DatasetDict], |
|
return_meta_data: bool = False, |
|
) -> Union[List[Dict], List[TextGenerationInferenceOutput]]: |
|
"""Perform inference on the input dataset that returns log probs. |
|
|
|
If return_meta_data - returns a list of TextGenerationInferenceOutput, else returns a list of the logprob dicts. |
|
return_meta_data is only supported for some InferenceEngines. |
|
predictions. |
|
""" |
|
pass |
|
|
|
def infer_log_probs( |
|
self, |
|
dataset: Union[List[Dict[str, Any]], DatasetDict], |
|
return_meta_data: bool = False, |
|
) -> Union[List[Dict], List[TextGenerationInferenceOutput]]: |
|
"""Verifies instances of a dataset and performs inference that returns log probabilities of top tokens. |
|
|
|
For each instance , generates a list of top tokens per position. |
|
[ "top_tokens": [ { "text": ..., "logprob": ...} , ... ] |
|
If return_meta_data - returns a list of TextGenerationInferenceOutput, else returns the list of the logprob dicts. |
|
return_meta_data is only supported for some InferenceEngines. |
|
""" |
|
if return_meta_data and not hasattr(self, "get_return_object"): |
|
raise NotImplementedError( |
|
f"Inference engine {self.__class__.__name__} does not support return_meta_data as it " |
|
f"does not contain a 'get_return_object' method. Please set return_meta_data=False." |
|
) |
|
|
|
[self.verify_instance(instance) for instance in dataset] |
|
return self._infer_log_probs(dataset, return_meta_data) |
|
|
|
|
|
class LazyLoadMixin(Artifact): |
|
lazy_load: bool = NonPositionalField(default=False) |
|
|
|
@abc.abstractmethod |
|
def _is_loaded(self): |
|
pass |
|
|
|
|
|
class HFGenerationParamsMixin(Artifact): |
|
max_new_tokens: int |
|
do_sample: bool = False |
|
temperature: Optional[float] = None |
|
top_p: Optional[float] = None |
|
top_k: Optional[int] = None |
|
num_beams: Optional[int] = None |
|
repetition_penalty: Optional[float] = None |
|
pad_token_id: Optional[int] = None |
|
eos_token_id: Optional[int] = None |
|
|
|
|
|
class HFInferenceEngineBase( |
|
InferenceEngine, |
|
LogProbInferenceEngine, |
|
PackageRequirementsMixin, |
|
LazyLoadMixin, |
|
HFGenerationParamsMixin, |
|
): |
|
model_name: str |
|
label: str |
|
|
|
n_top_tokens: int = 5 |
|
|
|
device: Any = None |
|
device_map: Any = None |
|
|
|
use_fast_tokenizer: bool = True |
|
low_cpu_mem_usage: bool = True |
|
torch_dtype: str = "torch.float16" |
|
|
|
model: Any = InternalField(default=None, name="Inference object") |
|
processor: Any = InternalField(default=None, name="Input processor (tokenizer)") |
|
|
|
    _requirements_list = {
        "transformers": "Install huggingface package using 'pip install --upgrade transformers'.",
        "torch": "Install torch; see the PyTorch website for more details.",
        "accelerate": "pip install accelerate",
    }
|
|
|
def _is_loaded(self): |
|
return hasattr(self, "model") and self.model is not None |
|
|
|
def _set_inference_device(self): |
|
if self.device is not None and self.device_map is not None: |
|
raise ValueError( |
|
f"You must specify either 'device' or 'device_map', however both " |
|
f"were given: 'device={self.device}', 'device_map={self.device_map}'." |
|
) |
|
|
|
if self.device is None and self.device_map is None: |
|
import torch |
|
|
|
self.device = torch.device( |
|
"mps" |
|
if torch.backends.mps.is_available() |
|
else 0 |
|
if torch.cuda.is_available() |
|
else "cpu" |
|
) |
|
|
|
@abc.abstractmethod |
|
def _init_processor(self): |
|
raise NotImplementedError |
|
|
|
@abc.abstractmethod |
|
def _init_model(self): |
|
raise NotImplementedError |
|
|
|
def _get_torch_dtype(self): |
|
import torch |
|
|
|
if not isinstance(self.torch_dtype, str) or not self.torch_dtype.startswith( |
|
"torch." |
|
): |
|
raise ValueError( |
|
f"'torch_dtype' must be a string representing torch data " |
|
f"type used for inference. The name should be an absolute " |
|
f"import, for example: 'torch.float16'. However, " |
|
f"'{self.torch_dtype}' was given instead." |
|
) |
|
|
|
try: |
|
dtype = eval(self.torch_dtype) |
|
except (AttributeError, TypeError) as e: |
|
raise ValueError( |
|
f"Incorrect value of 'torch_dtype' was given: '{self.torch_dtype}'." |
|
) from e |
|
|
|
if not isinstance(dtype, torch.dtype): |
|
raise ValueError( |
|
f"'torch_dtype' must be an instance of 'torch.dtype', however, " |
|
f"'{dtype}' is an instance of '{type(dtype)}'." |
|
) |
|
|
|
return dtype |
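
    # Example of the conversion performed by _get_torch_dtype (illustrative):
    #   torch_dtype="torch.bfloat16"  ->  torch.bfloat16 (an instance of torch.dtype)
    # Only strings prefixed with "torch." are accepted, so the eval above is limited
    # to resolving attributes of the torch module.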
|
|
|
def _prepare_engine(self): |
|
self._set_inference_device() |
|
self._init_processor() |
|
self._init_model() |
|
|
|
def prepare_engine(self): |
|
if not self.lazy_load: |
|
self._prepare_engine() |
|
|
|
def get_engine_id(self): |
|
return get_model_and_label_id(self.model_name, self.label) |
|
|
|
def decode_tokens(self, tokens: Sequence, inp_length: int) -> List[str]: |
|
return [ |
|
self.processor.decode(token, skip_special_tokens=True) |
|
for token in tokens[inp_length:] |
|
] |
|
|
|
@staticmethod |
|
def create_string_from_tokens(string_tokens: List[str]) -> str: |
|
return "".join(token for token in string_tokens) |
|
|
|
def make_predictions(self, prepared_inputs: Mapping) -> Mapping: |
|
return self.model.generate( |
|
**prepared_inputs, |
|
**self.to_dict([HFGenerationParamsMixin], keep_empty=False), |
|
output_scores=True, |
|
return_dict_in_generate=True, |
|
) |
|
|
|
def compute_transition_scores( |
|
self, sequences: Sequence, scores: Sequence, beam_indices: Optional[int] |
|
) -> Sequence: |
|
|
|
|
|
return self.model.compute_transition_scores( |
|
sequences, |
|
scores, |
|
normalize_logits=True, |
|
beam_indices=beam_indices, |
|
) |
|
|
|
def get_logprobs( |
|
self, predictions: Mapping, string_tokens: List[List[str]] |
|
) -> List[List[Dict[str, Any]]]: |
|
beam_indices = ( |
|
predictions.beam_indices |
|
if self.num_beams is not None and self.num_beams > 1 |
|
else None |
|
) |
|
|
|
transition_scores = self.compute_transition_scores( |
|
sequences=predictions.sequences, |
|
scores=predictions.scores, |
|
beam_indices=beam_indices, |
|
) |
|
|
|
logprobs: List[List[Dict[str, Any]]] = [] |
|
|
|
for sample_no, sample_scores in enumerate(transition_scores.detach().cpu()): |
|
sample_logprobs: List[Dict[str, Any]] = [] |
|
|
|
for n, score in enumerate(sample_scores): |
|
sample_logprobs.append( |
|
{ |
|
"text": string_tokens[sample_no][n], |
|
"logprob": float(score.cpu()), |
|
"top_tokens": [ |
|
{ |
|
"text": self.processor.decode(idx), |
|
"logprob": float( |
|
predictions.scores[n][sample_no][idx].cpu() |
|
), |
|
} |
|
for idx in predictions.scores[n][sample_no].argsort( |
|
dim=0, descending=True |
|
)[: self.n_top_tokens] |
|
], |
|
} |
|
) |
|
|
|
logprobs.append(sample_logprobs) |
|
|
|
return logprobs |
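
    # Shape of the structure returned by get_logprobs (illustrative):
    #   logprobs[sample_index][position] == {
    #       "text": "<generated token>",
    #       "logprob": -0.12,
    #       "top_tokens": [{"text": "...", "logprob": -0.12}, ...],  # n_top_tokens entries
    #   }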
|
|
|
@abc.abstractmethod |
|
def prepare_inputs(self, data: Iterable) -> Mapping: |
|
raise NotImplementedError |
|
|
|
def get_return_object( |
|
self, |
|
output: Union[str, List[Dict[str, Any]]], |
|
output_tokens: Optional[int], |
|
inp: Optional[str], |
|
inp_tokens: Optional[int], |
|
return_meta_data: bool, |
|
) -> Union[str, List[Dict[str, Any]], TextGenerationInferenceOutput]: |
|
if return_meta_data: |
|
return TextGenerationInferenceOutput( |
|
prediction=output, |
|
output_tokens=output_tokens if output_tokens is not None else None, |
|
input_text=inp, |
|
input_tokens=inp_tokens if inp_tokens is not None else None, |
|
model_name=self.model_name, |
|
inference_type=self.label, |
|
) |
|
return output |
|
|
|
def infer( |
|
self, |
|
dataset: Union[List[Dict[str, Any]], DatasetDict], |
|
return_meta_data: bool = False, |
|
) -> Union[List[str], List[TextGenerationInferenceOutput]]: |
|
if not self._is_loaded(): |
|
self._prepare_engine() |
|
return super().infer(dataset, return_meta_data) |
|
|
|
@abc.abstractmethod |
|
def _infer( |
|
self, |
|
dataset: Union[List[Dict[str, Any]], DatasetDict], |
|
return_meta_data: bool = False, |
|
) -> Union[List[str], List[TextGenerationInferenceOutput]]: |
|
raise NotImplementedError |
|
|
|
def infer_log_probs( |
|
self, |
|
dataset: Union[List[Dict[str, Any]], DatasetDict], |
|
return_meta_data: bool = False, |
|
) -> Union[List[Dict], List[TextGenerationInferenceOutput]]: |
|
if not self._is_loaded(): |
|
self._prepare_engine() |
|
return super().infer_log_probs(dataset, return_meta_data) |
|
|
|
@abc.abstractmethod |
|
def _infer_log_probs( |
|
self, |
|
dataset: Union[List[Dict[str, Any]], DatasetDict], |
|
return_meta_data: bool = False, |
|
) -> Union[List[Dict], List[TextGenerationInferenceOutput]]: |
|
raise NotImplementedError |
|
|
|
|
|
class HFAutoModelInferenceEngine(HFInferenceEngineBase): |
|
label: str = "hf_auto_model" |
|
|
|
def _init_processor(self): |
|
from transformers import AutoTokenizer |
|
|
|
self.processor = AutoTokenizer.from_pretrained( |
|
pretrained_model_name_or_path=self.model_name, |
|
use_fast=self.use_fast_tokenizer, |
|
padding=True, |
|
truncation=True, |
|
) |
|
|
|
def _init_model(self): |
|
from transformers import ( |
|
AutoConfig, |
|
AutoModelForCausalLM, |
|
AutoModelForSeq2SeqLM, |
|
) |
|
|
|
model_class = ( |
|
AutoModelForSeq2SeqLM |
|
if AutoConfig.from_pretrained(self.model_name).is_encoder_decoder |
|
else AutoModelForCausalLM |
|
) |
|
|
|
self.model = model_class.from_pretrained( |
|
pretrained_model_name_or_path=self.model_name, |
|
trust_remote_code=True, |
|
device_map=self.device_map, |
|
torch_dtype=self._get_torch_dtype(), |
|
) |
|
if self.device_map is None: |
|
self.model.to(self.device) |
|
|
|
def prepare_inputs(self, data: Iterable) -> Mapping: |
|
return self.processor( |
|
data, |
|
padding=True, |
|
truncation=True, |
|
return_tensors="pt", |
|
).to(self.device or self.device_map) |
|
|
|
def _infer_fn( |
|
self, |
|
dataset: Union[List[Dict[str, Any]], DatasetDict], |
|
return_meta_data: bool, |
|
return_logprobs: bool, |
|
) -> Union[List[str], List[Dict], List[TextGenerationInferenceOutput]]: |
|
tokenized_inputs = self.prepare_inputs( |
|
[instance["source"] for instance in dataset] |
|
) |
|
input_length = ( |
|
1 |
|
if self.model.config.is_encoder_decoder |
|
else tokenized_inputs.input_ids.shape[1] |
|
) |
|
|
|
predictions = self.make_predictions(tokenized_inputs) |
|
sequences = predictions.sequences |
|
|
|
string_tokens = [ |
|
self.decode_tokens(sequence, input_length) for sequence in sequences |
|
] |
|
|
|
final_outputs = ( |
|
self.get_logprobs(predictions, string_tokens) |
|
if return_logprobs |
|
else [self.create_string_from_tokens(strings) for strings in string_tokens] |
|
) |
|
|
|
return [ |
|
self.get_return_object( |
|
output=final_outputs[i], |
|
output_tokens=len(string_tokens[i]), |
|
inp=dataset[i]["source"], |
|
inp_tokens=len(tokenized_inputs.encodings[i].tokens) |
|
if tokenized_inputs.encodings is not None |
|
else None, |
|
return_meta_data=return_meta_data, |
|
) |
|
for i in range(len(sequences)) |
|
] |
|
|
|
def _infer( |
|
self, |
|
dataset: Union[List[Dict[str, Any]], DatasetDict], |
|
return_meta_data: bool = False, |
|
) -> Union[List[str], List[TextGenerationInferenceOutput]]: |
|
self.verify_not_chat_api(dataset) |
|
return self._infer_fn(dataset, return_meta_data, False) |
|
|
|
def _infer_log_probs( |
|
self, |
|
dataset: Union[List[Dict[str, Any]], DatasetDict], |
|
return_meta_data: bool = False, |
|
) -> Union[List[Dict], List[TextGenerationInferenceOutput]]: |
|
self.verify_not_chat_api(dataset) |
|
return self._infer_fn(dataset, return_meta_data, True) |
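
# Minimal usage sketch for HFAutoModelInferenceEngine (the model name and
# parameters below are illustrative assumptions, not recommendations):
#
#   engine = HFAutoModelInferenceEngine(model_name="gpt2", max_new_tokens=16)
#   texts = engine.infer([{"source": "Once upon a time"}])
#   logprobs = engine.infer_log_probs([{"source": "Once upon a time"}])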
|
|
|
|
|
class HFLlavaInferenceEngine(HFInferenceEngineBase): |
|
lazy_load: bool = True |
|
label: str = "hf_lava" |
|
image_token: str = "<image>" |
|
|
|
def compute_transition_scores( |
|
self, sequences: Sequence, scores: Sequence, beam_indices: Optional[int] |
|
) -> Sequence: |
|
if not hasattr(self.model.config, "vocab_size"): |
|
self.model.config.vocab_size = self.model.vocab_size |
|
|
|
return super().compute_transition_scores(sequences, scores, beam_indices) |
|
|
|
def _init_processor(self): |
|
from transformers import AutoProcessor |
|
|
|
self.processor = AutoProcessor.from_pretrained(self.model_name) |
|
|
|
if not self.pad_token_id and hasattr(self.processor, "eos_token_id"): |
|
self.pad_token_id = self.processor.eos_token_id |
|
|
|
def _init_model(self): |
|
from transformers import LlavaForConditionalGeneration |
|
|
|
self.model = LlavaForConditionalGeneration.from_pretrained( |
|
self.model_name, |
|
torch_dtype=self._get_torch_dtype(), |
|
low_cpu_mem_usage=self.low_cpu_mem_usage, |
|
device_map=self.device_map, |
|
) |
|
if self.device_map is None: |
|
self.model.to(self.device) |
|
|
|
@staticmethod |
|
def _get_input(instance): |
|
assert isinstance(instance["source"], list), "Must use format=formats.chat_api" |
|
images = [] |
|
conversation = [] |
|
for turn in instance["source"]: |
|
if isinstance(turn["content"], list): |
|
for content in turn["content"]: |
|
if content["type"] == "image_url": |
|
content["type"] = "image" |
|
image_url = content.pop("image_url")["url"] |
|
image = data_url_to_image(image_url) |
|
images.append(image) |
|
conversation.append(turn) |
|
return conversation, images |
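
    # Expected chat_api instance format handled by _get_input (an illustrative sketch;
    # the image is passed as a data URL inside an "image_url" content item):
    #   instance["source"] == [
    #       {"role": "user", "content": [
    #           {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}},
    #           {"type": "text", "text": "What is shown in this image?"},
    #       ]}
    #   ]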
|
|
|
def prepare_inputs(self, data: Iterable) -> Mapping: |
|
conversation, images = self._get_input(data) |
|
|
|
if len(images) == 1: |
|
images = images[0] |
|
|
|
text = self.processor.apply_chat_template( |
|
conversation, add_generation_prompt=True |
|
) |
|
|
|
inputs: Mapping = self.processor( |
|
images=images, text=text, return_tensors="pt" |
|
).to(self.device or self.device_map, self._get_torch_dtype()) |
|
|
|
return inputs |
|
|
|
def _infer_fn( |
|
self, |
|
dataset: Union[List[Dict[str, Any]], DatasetDict], |
|
return_meta_data: bool, |
|
return_logprobs: bool, |
|
) -> Union[List[str], List[Dict], List[TextGenerationInferenceOutput]]: |
|
results = [] |
|
|
|
for instance in tqdm(dataset): |
|
processed_inputs = self.prepare_inputs(instance) |
|
input_len = len(processed_inputs["input_ids"][0]) |
|
|
|
predictions = self.make_predictions(processed_inputs) |
|
|
|
string_tokens = self.decode_tokens(predictions.sequences[0], input_len) |
|
|
|
final_outputs = ( |
|
self.get_logprobs(predictions, [string_tokens])[0] |
|
if return_logprobs |
|
else self.create_string_from_tokens(string_tokens) |
|
) |
|
|
|
results.append( |
|
self.get_return_object( |
|
output=final_outputs, |
|
output_tokens=len(string_tokens), |
|
inp=instance["source"], |
|
inp_tokens=None, |
|
return_meta_data=return_meta_data, |
|
) |
|
) |
|
|
|
return results |
|
|
|
def _infer( |
|
self, |
|
dataset: Union[List[Dict[str, Any]], DatasetDict], |
|
return_meta_data: bool = False, |
|
) -> Union[List[str], List[TextGenerationInferenceOutput]]: |
|
return self._infer_fn(dataset, return_meta_data, False) |
|
|
|
def _infer_log_probs( |
|
self, |
|
dataset: Union[List[Dict[str, Any]], DatasetDict], |
|
return_meta_data: bool = False, |
|
) -> Union[List[Dict], List[TextGenerationInferenceOutput]]: |
|
return self._infer_fn(dataset, return_meta_data, True) |
|
|
|
|
|
class HFPeftInferenceEngine(HFAutoModelInferenceEngine): |
|
label: str = "hf_peft_auto_model" |
|
|
|
peft_config: Any = InternalField( |
|
default=None, |
|
name="PEFT config read from the directory or the Hub repository " |
|
"id specified in the 'model_name'.", |
|
) |
|
|
|
    _requirements_list = {
        "transformers": "Install huggingface package using 'pip install --upgrade transformers'.",
        "torch": "Install torch; see the PyTorch website for more details.",
        "accelerate": "pip install accelerate",
        "peft": "Install 'peft' package using: 'pip install peft'.",
    }
|
|
|
def _prepare_engine(self): |
|
self._read_peft_config() |
|
super()._prepare_engine() |
|
|
|
def _read_peft_config(self): |
|
from peft import PeftConfig |
|
|
|
try: |
|
config = PeftConfig.from_pretrained(self.model_name) |
|
assert isinstance(config.base_model_name_or_path, str) |
|
self.peft_config = config |
|
|
|
except ValueError as e: |
|
if "Can't find" in str(e): |
|
raise ValueError( |
|
f"Specified model '{self.model_name}' is not the PEFT model. " |
|
f"Use a regular instance of the `HFAutoModelInferenceEngine` " |
|
f"instead." |
|
) from e |
|
|
|
raise e |
|
|
|
def _init_processor(self): |
|
from transformers import AutoTokenizer |
|
|
|
self.processor = AutoTokenizer.from_pretrained( |
|
self.peft_config.base_model_name_or_path |
|
) |
|
|
|
def _init_model(self): |
|
from peft import AutoPeftModelForCausalLM, AutoPeftModelForSeq2SeqLM |
|
from transformers import AutoConfig |
|
|
|
model_class = ( |
|
AutoPeftModelForSeq2SeqLM |
|
if AutoConfig.from_pretrained(self.model_name).is_encoder_decoder |
|
else AutoPeftModelForCausalLM |
|
) |
|
|
|
self.model = model_class.from_pretrained( |
|
pretrained_model_name_or_path=self.peft_config.base_model_name_or_path, |
|
trust_remote_code=True, |
|
device_map=self.device_map, |
|
low_cpu_mem_usage=self.low_cpu_mem_usage, |
|
torch_dtype=self._get_torch_dtype(), |
|
) |
|
if self.device_map is None: |
|
self.model.to(self.device) |
|
|
|
|
|
@deprecation( |
|
version="2.0.0", msg=" Use non-pipeline-based 'HFInferenceEngine' instead." |
|
) |
|
class HFPipelineBasedInferenceEngine( |
|
InferenceEngine, PackageRequirementsMixin, LazyLoadMixin, HFGenerationParamsMixin |
|
): |
|
model_name: str |
|
label: str = "hf_pipeline_inference_engine" |
|
|
|
use_fast_tokenizer: bool = True |
|
use_fp16: bool = True |
|
load_in_8bit: bool = False |
|
|
|
task: Optional[str] = None |
|
|
|
device: Any = None |
|
device_map: Any = None |
|
|
|
pipe: Any = InternalField(default=None) |
|
|
|
    _requirements_list = {
        "transformers": "Install huggingface package using 'pip install --upgrade transformers'.",
        "torch": "Install torch; see the PyTorch website for more details.",
        "accelerate": "pip install accelerate",
    }
|
|
|
def _is_loaded(self): |
|
return hasattr(self, "model") and self.model is not None |
|
|
|
def get_engine_id(self): |
|
return get_model_and_label_id(self.model_name, "hf_pipeline") |
|
|
|
def _define_task(self): |
|
from transformers import AutoConfig |
|
|
|
self.task = ( |
|
"text2text-generation" |
|
if AutoConfig.from_pretrained( |
|
self.model_name, trust_remote_code=True |
|
).is_encoder_decoder |
|
else "text-generation" |
|
) |
|
|
|
def _get_model_args(self) -> Dict[str, Any]: |
|
import torch |
|
from transformers import BitsAndBytesConfig |
|
|
|
args = {} |
|
|
|
if self.load_in_8bit: |
|
quantization_config = BitsAndBytesConfig(load_in_8bit=self.load_in_8bit) |
|
args["quantization_config"] = quantization_config |
|
elif self.use_fp16: |
|
if self.device == torch.device("mps"): |
|
args["torch_dtype"] = torch.float16 |
|
else: |
|
args["torch_dtype"] = torch.bfloat16 |
|
|
|
|
|
|
|
|
|
if torch.cuda.device_count() > 1: |
|
assert self.device == torch.device(0) |
|
args["device_map"] = "auto" |
|
else: |
|
if not self.load_in_8bit: |
|
args["device"] = self.device |
|
|
|
if self.task == "text-generation": |
|
args["return_full_text"] = False |
|
|
|
return args |
|
|
|
def _create_pipeline(self, model_args: Dict[str, Any]): |
|
from transformers import pipeline |
|
|
|
self.model = pipeline( |
|
model=self.model_name, |
|
task=self.task, |
|
use_fast=self.use_fast_tokenizer, |
|
trust_remote_code=True, |
|
**model_args, |
|
**self.to_dict( |
|
[HFGenerationParamsMixin], |
|
keep_empty=False, |
|
), |
|
) |
|
|
|
def _set_inference_device(self): |
|
if self.device is not None and self.device_map is not None: |
|
raise ValueError( |
|
f"You must specify either 'device' or 'device_map', however both " |
|
f"were given: 'device={self.device}', 'device_map={self.device_map}'." |
|
) |
|
|
|
if self.device is None and self.device_map is None: |
|
import torch |
|
|
|
self.device = torch.device( |
|
"mps" |
|
if torch.backends.mps.is_available() |
|
else 0 |
|
if torch.cuda.is_available() |
|
else "cpu" |
|
) |
|
|
|
def _prepare_engine(self): |
|
self._set_inference_device() |
|
if self.task is None: |
|
self._define_task() |
|
model_args = self._get_model_args() |
|
self._create_pipeline(model_args) |
|
|
|
def prepare_engine(self): |
|
if not self.lazy_load: |
|
self._prepare_engine() |
|
|
|
def _infer( |
|
self, |
|
dataset: Union[List[Dict[str, Any]], DatasetDict], |
|
return_meta_data: bool = False, |
|
) -> Union[List[str], List[TextGenerationInferenceOutput]]: |
|
if not self._is_loaded(): |
|
self._prepare_engine() |
|
|
|
outputs = self.model([instance["source"] for instance in dataset]) |
|
|
|
return [ |
|
self.get_return_object(output[0], instance["source"], return_meta_data) |
|
if isinstance(output, list) |
|
else self.get_return_object(output, instance["source"], return_meta_data) |
|
for output, instance in zip(outputs, dataset) |
|
] |
|
|
|
def get_return_object(self, output, inp, return_meta_data): |
|
if return_meta_data: |
|
return TextGenerationInferenceOutput( |
|
prediction=output["generated_text"], |
|
model_name=self.model_name, |
|
inference_type=self.label, |
|
input_text=inp, |
|
) |
|
return output["generated_text"] |
|
|
|
|
|
def mock_logprobs_default_value_factory() -> List[Dict[str, Any]]: |
|
return [ |
|
{ |
|
"logprob": -1, |
|
"text": "[[10]]", |
|
"top_tokens": [ |
|
{"logprob": -1, "text": "[[10]]"}, |
|
], |
|
} |
|
] |
|
|
|
|
|
class MockInferenceEngine(InferenceEngine, LogProbInferenceEngine): |
|
model_name: str |
|
default_inference_value: str = "[[10]]" |
|
default_inference_value_logprob: List[Dict[str, Any]] = dataclasses.field( |
|
default_factory=mock_logprobs_default_value_factory, |
|
) |
|
label: str = "mock_inference_engine" |
|
|
|
def get_engine_id(self): |
|
return get_model_and_label_id(self.model_name, "mock") |
|
|
|
def prepare_engine(self): |
|
return |
|
|
|
def _mock_infer( |
|
self, |
|
dataset: Union[List[Dict[str, Any]], DatasetDict], |
|
) -> Union[List[str], List[TextGenerationInferenceOutput]]: |
|
return [self.default_inference_value for _ in dataset] |
|
|
|
def _infer( |
|
self, |
|
dataset: Union[List[Dict[str, Any]], DatasetDict], |
|
return_meta_data: bool = False, |
|
) -> Union[List[str], List[TextGenerationInferenceOutput]]: |
|
return [ |
|
self.get_return_object( |
|
self.default_inference_value, instance, return_meta_data |
|
) |
|
for instance in dataset |
|
] |
|
|
|
def _infer_log_probs( |
|
self, |
|
dataset: Union[List[Dict[str, Any]], DatasetDict], |
|
return_meta_data: bool = False, |
|
) -> Union[List[Dict], List[TextGenerationInferenceOutput]]: |
|
return [ |
|
self.get_return_object( |
|
self.default_inference_value_logprob, instance, return_meta_data |
|
) |
|
for instance in dataset |
|
] |
|
|
|
def get_return_object(self, predict_result, instance, return_meta_data): |
|
if return_meta_data: |
|
return TextGenerationInferenceOutput( |
|
prediction=predict_result, |
|
input_tokens=len(instance["source"]), |
|
output_tokens=len(predict_result), |
|
model_name=self.model_name, |
|
inference_type=self.label, |
|
input_text=instance["source"], |
|
seed=111, |
|
stop_reason="", |
|
) |
|
return predict_result |
|
|
|
|
|
class MockModeMixin(Artifact): |
|
mock_mode: bool = False |
|
|
|
|
|
class IbmGenAiInferenceEngineParamsMixin(Artifact): |
|
beam_width: Optional[int] = None |
|
decoding_method: Optional[Literal["greedy", "sample"]] = None |
|
include_stop_sequence: Optional[bool] = None |
|
length_penalty: Any = None |
|
max_new_tokens: Optional[int] = None |
|
min_new_tokens: Optional[int] = None |
|
random_seed: Optional[int] = None |
|
repetition_penalty: Optional[float] = None |
|
return_options: Any = None |
|
stop_sequences: Optional[List[str]] = None |
|
temperature: Optional[float] = None |
|
time_limit: Optional[int] = None |
|
top_k: Optional[int] = None |
|
top_p: Optional[float] = None |
|
truncate_input_tokens: Optional[int] = None |
|
typical_p: Optional[float] = None |
|
|
|
|
|
@deprecation(version="2.0.0", alternative=IbmGenAiInferenceEngineParamsMixin) |
|
class IbmGenAiInferenceEngineParams(Artifact): |
|
beam_width: Optional[int] = None |
|
decoding_method: Optional[Literal["greedy", "sample"]] = None |
|
include_stop_sequence: Optional[bool] = None |
|
length_penalty: Any = None |
|
max_new_tokens: Optional[int] = None |
|
min_new_tokens: Optional[int] = None |
|
random_seed: Optional[int] = None |
|
repetition_penalty: Optional[float] = None |
|
return_options: Any = None |
|
stop_sequences: Optional[List[str]] = None |
|
temperature: Optional[float] = None |
|
time_limit: Optional[int] = None |
|
top_k: Optional[int] = None |
|
top_p: Optional[float] = None |
|
truncate_input_tokens: Optional[int] = None |
|
typical_p: Optional[float] = None |
|
|
|
|
|
class GenericInferenceEngine( |
|
InferenceEngine, ArtifactFetcherMixin, LogProbInferenceEngine |
|
): |
|
default: Optional[str] = None |
|
|
|
def prepare_engine(self): |
|
if "UNITXT_INFERENCE_ENGINE" in os.environ: |
|
engine_reference = os.environ["UNITXT_INFERENCE_ENGINE"] |
|
else: |
|
assert self.default is not None, ( |
|
"GenericInferenceEngine could not be initialized" |
|
                '\nThis is because the "UNITXT_INFERENCE_ENGINE" environment variable is not set and no default engine was provided.'
                "\nFor example, you can fix it by adding"
                "\nexport UNITXT_INFERENCE_ENGINE=engines.ibm_gen_ai.llama_3_70b_instruct"
                "\nto your ~/.bashrc,"
                "\nor by passing the required engine in the 'default' argument."
|
) |
|
engine_reference = self.default |
|
self.engine = self.get_artifact(engine_reference) |
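
    # Minimal usage sketch (the engine reference below is illustrative):
    #   os.environ["UNITXT_INFERENCE_ENGINE"] = "engines.ibm_gen_ai.llama_3_70b_instruct"
    #   engine = GenericInferenceEngine()
    # or, equivalently:
    #   engine = GenericInferenceEngine(default="engines.ibm_gen_ai.llama_3_70b_instruct")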
|
|
|
def get_engine_id(self): |
|
|
|
if hasattr(self, "engine"): |
|
return f"generic_{self.engine.get_engine_id()}" |
|
return "generic_inference_engine" |
|
|
|
def _infer( |
|
self, |
|
dataset: Union[List[Dict[str, Any]], DatasetDict], |
|
return_meta_data: bool = False, |
|
) -> Union[List[str], List[TextGenerationInferenceOutput]]: |
|
return self.engine._infer(dataset) |
|
|
|
def _infer_log_probs( |
|
self, |
|
dataset: Union[List[Dict[str, Any]], DatasetDict], |
|
return_meta_data: bool = False, |
|
    ) -> Union[List[Dict], List[TextGenerationInferenceOutput]]:
|
if not isinstance(self.engine, LogProbInferenceEngine): |
|
raise NotImplementedError( |
|
f"Error in infer: inference engine used by the GenericInferenceEngine" |
|
f"({self.engine.__class__.__name__}) does not support logprobs." |
|
) |
|
return self.engine._infer_log_probs(dataset) |
|
|
|
|
|
class OllamaInferenceEngine( |
|
InferenceEngine, StandardAPIParamsMixin, PackageRequirementsMixin |
|
): |
|
label: str = "ollama" |
|
    _requirements_list = {
        "ollama": "Install ollama package using 'pip install --upgrade ollama'."
    }
|
data_classification_policy = ["public", "proprietary"] |
|
|
|
def get_engine_id(self): |
|
return get_model_and_label_id(self.model, self.label) |
|
|
|
def prepare_engine(self): |
|
pass |
|
|
|
def _infer( |
|
self, |
|
dataset: Union[List[Dict[str, Any]], DatasetDict], |
|
return_meta_data: bool = False, |
|
) -> Union[List[str], List[TextGenerationInferenceOutput]]: |
|
import ollama |
|
|
|
args = self.to_dict([StandardAPIParamsMixin]) |
|
|
|
results = [] |
|
|
|
for instance in dataset: |
|
messages = self.to_messages(instance) |
|
response = ollama.chat( |
|
model=self.model, |
|
messages=messages, |
|
**args, |
|
) |
|
results.append(response) |
|
|
|
return [element["message"]["content"] for element in results] |
|
|
|
|
|
class OptionSelectingByLogProbsInferenceEngine: |
|
"""OptionSelectingByLogProbsInferenceEngine inference engine is used to select an option based on the logprobs of an options list conditioned by a prompt. |
|
|
|
The inference engines that inherit from this class must implement `get_token_count` and `get_options_log_probs`. |
|
""" |
|
|
|
@abc.abstractmethod |
|
def get_token_count(self, dataset): |
|
"""Get the token count of the source key of each dict of the dataset. Add to each instance in the data a "token_count" field. |
|
|
|
Args: |
|
dataset (List[Dict[str, Any]]): A list of dictionaries, each representing a data instance. |
|
|
|
Returns: |
|
List[int]: The token count of the texts |
|
""" |
|
|
|
@abc.abstractmethod |
|
def get_options_log_probs(self, dataset): |
|
"""Get the token logprobs of the options of the key task_data.options of each dict of the dataset. |
|
|
|
Add to each instance in the data a "options_log_prob" field, which is a dict with str as key and a list of {text: str, logprob:float}. |
|
|
|
Args: |
|
dataset (List[Dict[str, Any]]): A list of dictionaries, each representing a data instance. |
|
|
|
Returns: |
|
List[int]: The token count of the texts |
|
""" |
|
|
|
def select(self, dataset: List[Dict[str, Any]]) -> List[Dict[str, Any]]: |
|
"""Calculate most likely labels based on log probabilities for a set of fixed completions.""" |
|
dataset_with_token_counts = self.get_token_count(dataset) |
|
token_counts = [d["token_count"] for d in dataset_with_token_counts] |
|
|
|
|
|
dataset_with_options = [ |
|
{ |
|
"source": instance["source"] + option, |
|
"task_data": {"token_count": token_count}, |
|
} |
|
for instance, token_count in zip(dataset, token_counts) |
|
for option in instance["task_data"]["options"] |
|
] |
|
|
|
dataset_with_options_logprobs: list[ |
|
list[dict[str, float | str]] |
|
] = self.get_options_log_probs(dataset_with_options) |
|
|
|
dataset_iterator = iter(dataset_with_options_logprobs) |
|
|
|
for instance in dataset: |
|
tokens_with_logprob_list = [] |
|
|
|
for _ in instance["task_data"]["options"]: |
|
tokens_with_logprob = next(dataset_iterator)["prediction"] |
|
tokens_with_logprob_list.append(tokens_with_logprob) |
|
|
|
to_compare_indexes = list(range(len(instance["task_data"]["options"]))) |
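
            # Walk the options token by token: at each position, keep only the options
            # whose token matches the highest-logprob token, and stop as soon as a
            # single option can be distinguished.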
|
|
|
|
|
for token_with_logprob_comp in zip(*tokens_with_logprob_list): |
|
tokens_comp = [t["text"] for t in token_with_logprob_comp] |
|
logprobs_comp = [t["logprob"] for t in token_with_logprob_comp] |
|
|
|
index_max = max( |
|
( |
|
(val, idx) |
|
for idx, val in enumerate(logprobs_comp) |
|
if idx in to_compare_indexes |
|
), |
|
key=lambda x: x[0], |
|
)[1] |
|
|
|
token_value_with_max_logprob = tokens_comp[index_max] |
|
|
|
count = tokens_comp.count(token_value_with_max_logprob) |
|
if count > 1: |
|
|
|
to_compare_indexes = [ |
|
index |
|
for index, token_value in enumerate(tokens_comp) |
|
if token_value == token_value_with_max_logprob |
|
] |
|
continue |
|
|
|
break |
|
|
|
if len(to_compare_indexes) > 1: |
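                # Several options are still tied after comparing all tokens;
                # fall back to the first remaining one.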
|
|
|
|
|
index_max = to_compare_indexes[0] |
|
|
|
instance["prediction"] = instance["task_data"]["options"][index_max] |
|
return dataset |
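
# Expected instance format for OptionSelectingByLogProbsInferenceEngine.select()
# (an illustrative sketch):
#   {"source": "The sentiment of the review is ",
#    "task_data": {"options": ["positive", "negative"]}}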
|
|
|
|
|
class IbmGenAiInferenceEngine( |
|
InferenceEngine, |
|
IbmGenAiInferenceEngineParamsMixin, |
|
PackageRequirementsMixin, |
|
LogProbInferenceEngine, |
|
OptionSelectingByLogProbsInferenceEngine, |
|
): |
|
label: str = "ibm_genai" |
|
model_name: str |
|
    _requirements_list = {
        "ibm-generative-ai": "Install ibm-genai package using 'pip install --upgrade ibm-generative-ai'."
    }
|
data_classification_policy = ["public", "proprietary"] |
|
parameters: Optional[IbmGenAiInferenceEngineParams] = None |
|
rate_limit: int = 10 |
|
|
|
def get_engine_id(self): |
|
return get_model_and_label_id(self.model_name, self.label) |
|
|
|
@staticmethod |
|
def _get_credentials(): |
|
from genai import Credentials |
|
|
|
api_key_env_var_name = "GENAI_KEY" |
|
api_key = os.environ.get(api_key_env_var_name) |
|
|
|
assert api_key is not None, ( |
|
f"Error while trying to run IbmGenAiInferenceEngine." |
|
f" Please set the environment param '{api_key_env_var_name}'." |
|
) |
|
|
|
return Credentials(api_key=api_key) |
|
|
|
def prepare_engine(self): |
|
self.check_missing_requirements() |
|
|
|
from genai import Client |
|
from genai.text.generation import CreateExecutionOptions |
|
|
|
credentials = self._get_credentials() |
|
self.client = Client(credentials=credentials) |
|
|
|
self.execution_options = CreateExecutionOptions( |
|
concurrency_limit=self.rate_limit |
|
) |
|
|
|
self._set_inference_parameters() |
|
|
|
def _infer( |
|
self, |
|
dataset: Union[List[Dict[str, Any]], DatasetDict], |
|
return_meta_data: bool = False, |
|
) -> Union[List[str], List[TextGenerationInferenceOutput]]: |
|
from genai.schema import TextGenerationParameters, TextGenerationResult |
|
|
|
self.verify_not_chat_api(dataset) |
|
|
|
genai_params = TextGenerationParameters( |
|
**self.to_dict([IbmGenAiInferenceEngineParamsMixin]) |
|
) |
|
|
|
responses = self.client.text.generation.create( |
|
model_id=self.model_name, |
|
inputs=[instance["source"] for instance in dataset], |
|
parameters=genai_params, |
|
execution_options=self.execution_options, |
|
) |
|
|
|
results = [] |
|
for response in responses: |
|
generation_result: TextGenerationResult = response.results[0] |
|
result = self.get_return_object( |
|
generation_result.generated_text, generation_result, return_meta_data |
|
) |
|
results.append(result) |
|
return results |
|
|
|
def _infer_log_probs( |
|
self, |
|
dataset: Union[List[Dict[str, Any]], DatasetDict], |
|
return_meta_data: bool = False, |
|
) -> Union[List[Dict], List[TextGenerationInferenceOutput]]: |
|
from genai.schema import TextGenerationParameters, TextGenerationResult |
|
|
|
self.verify_not_chat_api(dataset) |
|
|
|
logprobs_return_options = { |
|
"generated_tokens": True, |
|
"input_text": False, |
|
"input_tokens": False, |
|
"token_logprobs": True, |
|
"token_ranks": True, |
|
"top_n_tokens": 5, |
|
} |
|
genai_params = self.to_dict( |
|
[IbmGenAiInferenceEngineParamsMixin], keep_empty=False |
|
) |
|
genai_params = {**genai_params, "return_options": logprobs_return_options} |
|
genai_params = TextGenerationParameters(**genai_params) |
|
predictions = self.client.text.generation.create( |
|
model_id=self.model_name, |
|
inputs=[instance["source"] for instance in dataset], |
|
parameters=genai_params, |
|
execution_options=self.execution_options, |
|
) |
|
|
|
predict_results = [] |
|
for prediction in predictions: |
|
result: TextGenerationResult = prediction.results[0] |
|
assert isinstance( |
|
result.generated_tokens, list |
|
), "result.generated_tokens should be a list" |
|
|
|
predict_result = [] |
|
for base_token in result.generated_tokens: |
|
res = {**base_token.__dict__, **base_token.model_extra} |
|
res["top_tokens"] = [ |
|
{"logprob": top_token.logprob, "text": top_token.text} |
|
for top_token in res["top_tokens"] |
|
] |
|
predict_result.append(res) |
|
final_results = self.get_return_object( |
|
predict_result, result, return_meta_data |
|
) |
|
predict_results.append(final_results) |
|
return predict_results |
|
|
|
def get_return_object(self, predict_result, result, return_meta_data): |
|
if return_meta_data: |
|
return TextGenerationInferenceOutput( |
|
prediction=predict_result, |
|
input_tokens=result.input_token_count, |
|
output_tokens=result.generated_token_count, |
|
model_name=self.model_name, |
|
inference_type=self.label, |
|
input_text=result.input_text, |
|
seed=self.random_seed, |
|
stop_reason=result.stop_reason, |
|
) |
|
return predict_result |
|
|
|
def get_model_details(self) -> Dict: |
|
from genai import ApiClient |
|
from genai.model import ModelService |
|
|
|
api_client = ApiClient(credentials=self._get_credentials()) |
|
model_info = ( |
|
ModelService(api_client=api_client).retrieve(id=self.model_name).result |
|
) |
|
return model_info.dict() |
|
|
|
def get_token_count(self, dataset): |
|
texts = [instance["source"] for instance in dataset] |
|
token_counts = list( |
|
tqdm( |
|
[ |
|
result.token_count |
|
for response in self.client.text.tokenization.create( |
|
model_id=self.model_name, |
|
input=texts, |
|
execution_options={"ordered": True}, |
|
) |
|
for result in response.results |
|
], |
|
desc="Tokenizing", |
|
total=len(texts), |
|
) |
|
) |
|
for i, token_count in enumerate(token_counts): |
|
dataset[i]["token_count"] = token_count |
|
return dataset |
|
|
|
def get_options_log_probs(self, dataset): |
|
"""Add to each instance in the data a "options_log_prob" field, which is a dict with str as key and a list of {text: str, logprob:float}.""" |
|
from genai.schema import TextGenerationParameters, TextGenerationReturnOptions |
|
|
|
texts = [x["source"] for x in dataset] |
|
|
|
responses = tqdm( |
|
self.client.text.generation.create( |
|
model_id=self.model_name, |
|
inputs=texts, |
|
execution_options={"ordered": True}, |
|
parameters=TextGenerationParameters( |
|
max_new_tokens=1, |
|
return_options=TextGenerationReturnOptions( |
|
input_tokens=True, token_logprobs=True |
|
), |
|
|
|
), |
|
), |
|
total=len(texts), |
|
desc="Completions", |
|
) |
|
|
|
scores = [ |
|
[ |
|
{"text": token.text, "logprob": token.logprob} |
|
for token in response.results[0].input_tokens |
|
] |
|
for response in responses |
|
] |
|
|
|
for instance, score in zip(dataset, scores): |
|
instance["prediction"] = score[instance["task_data"]["token_count"] - 1 :] |
|
return dataset |
|
|
|
|
|
class CredentialsOpenAi(TypedDict, total=False): |
|
api_key: str |
|
api_url: str |
|
|
|
|
|
class OpenAiInferenceEngineParamsMixin(Artifact): |
|
frequency_penalty: Optional[float] = None |
|
presence_penalty: Optional[float] = None |
|
max_tokens: Optional[int] = None |
|
seed: Optional[int] = None |
|
stop: Union[Optional[str], List[str]] = None |
|
temperature: Optional[float] = None |
|
top_p: Optional[float] = None |
|
top_logprobs: Optional[int] = 20 |
|
logit_bias: Optional[Dict[str, int]] = None |
|
logprobs: Optional[bool] = True |
|
n: Optional[int] = None |
|
parallel_tool_calls: Optional[bool] = None |
|
service_tier: Optional[Literal["auto", "default"]] = None |
|
|
|
|
|
@deprecation(version="2.0.0", alternative=OpenAiInferenceEngineParamsMixin) |
|
class OpenAiInferenceEngineParams(Artifact): |
|
frequency_penalty: Optional[float] = None |
|
presence_penalty: Optional[float] = None |
|
max_tokens: Optional[int] = None |
|
seed: Optional[int] = None |
|
stop: Union[Optional[str], List[str]] = None |
|
temperature: Optional[float] = None |
|
top_p: Optional[float] = None |
|
top_logprobs: Optional[int] = 20 |
|
logit_bias: Optional[Dict[str, int]] = None |
|
logprobs: Optional[bool] = True |
|
n: Optional[int] = None |
|
parallel_tool_calls: Optional[bool] = None |
|
service_tier: Optional[Literal["auto", "default"]] = None |
|
|
|
|
|
class OpenAiInferenceEngine( |
|
InferenceEngine, |
|
LogProbInferenceEngine, |
|
OpenAiInferenceEngineParamsMixin, |
|
PackageRequirementsMixin, |
|
): |
|
label: str = "openai" |
|
model_name: str |
|
    _requirements_list = {
        "openai": "Install openai package using 'pip install --upgrade openai'."
    }
|
data_classification_policy = ["public"] |
|
parameters: Optional[OpenAiInferenceEngineParams] = None |
|
base_url: Optional[str] = None |
|
default_headers: Dict[str, str] = {} |
|
credentials: CredentialsOpenAi = {} |
|
|
|
def get_engine_id(self) -> str: |
|
return get_model_and_label_id(self.model_name, self.label) |
|
|
|
def _prepare_credentials(self) -> CredentialsOpenAi: |
|
api_key = self.credentials.get( |
|
"api_key", os.environ.get(f"{self.label.upper()}_API_KEY", None) |
|
) |
|
assert api_key, ( |
|
f"Error while trying to run {self.label}. " |
|
f"Please set the env variable: '{self.label.upper()}_API_KEY'" |
|
) |
|
|
|
api_url = self.credentials.get( |
|
"api_url", os.environ.get(f"{self.label.upper()}_API_URL", None) |
|
) |
|
|
|
return {"api_key": api_key, "api_url": api_url} |
|
|
|
def get_default_headers(self) -> Dict[str, str]: |
|
return self.default_headers |
|
|
|
def create_client(self): |
|
from openai import OpenAI |
|
|
|
self.credentials = self._prepare_credentials() |
|
return OpenAI( |
|
api_key=self.credentials["api_key"], |
|
base_url=self.base_url or self.credentials["api_url"], |
|
default_headers=self.get_default_headers(), |
|
) |
|
|
|
def prepare_engine(self): |
|
self.client = self.create_client() |
|
self._set_inference_parameters() |
|
|
|
def _get_completion_kwargs(self): |
|
return { |
|
k: v |
|
for k, v in self.to_dict([OpenAiInferenceEngineParamsMixin]).items() |
|
if v is not None |
|
} |
|
|
|
def _infer( |
|
self, |
|
dataset: Union[List[Dict[str, Any]], DatasetDict], |
|
return_meta_data: bool = False, |
|
) -> Union[List[str], List[TextGenerationInferenceOutput]]: |
|
outputs = [] |
|
for instance in tqdm(dataset, desc="Inferring with openAI API"): |
|
messages = self.to_messages(instance) |
|
response = self.client.chat.completions.create( |
|
messages=messages, |
|
model=self.model_name, |
|
**self._get_completion_kwargs(), |
|
) |
|
prediction = response.choices[0].message.content |
|
output = self.get_return_object(prediction, response, return_meta_data) |
|
|
|
outputs.append(output) |
|
|
|
return outputs |
|
|
|
def _infer_log_probs( |
|
self, |
|
dataset: Union[List[Dict[str, Any]], DatasetDict], |
|
return_meta_data: bool = False, |
|
) -> Union[List[Dict], List[TextGenerationInferenceOutput]]: |
|
outputs = [] |
|
for instance in tqdm(dataset, desc="Inferring with openAI API"): |
|
response = self.client.chat.completions.create( |
|
messages=[ |
|
|
|
|
|
|
|
|
|
{ |
|
"role": "user", |
|
"content": instance["source"], |
|
} |
|
], |
|
model=self.model_name, |
|
**self._get_completion_kwargs(), |
|
) |
|
top_logprobs_response = response.choices[0].logprobs.content |
|
pred_output = [ |
|
{ |
|
"top_tokens": [ |
|
{"text": obj.token, "logprob": obj.logprob} |
|
for obj in generated_token.top_logprobs |
|
] |
|
} |
|
for generated_token in top_logprobs_response |
|
] |
|
output = self.get_return_object(pred_output, response, return_meta_data) |
|
outputs.append(output) |
|
return outputs |
|
|
|
def get_return_object(self, predict_result, response, return_meta_data): |
|
if return_meta_data: |
|
return TextGenerationInferenceOutput( |
|
prediction=predict_result, |
|
input_tokens=response.usage.prompt_tokens, |
|
output_tokens=response.usage.completion_tokens, |
|
model_name=self.model_name, |
|
inference_type=self.label, |
|
) |
|
return predict_result |
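
# Minimal usage sketch for OpenAiInferenceEngine (requires the OPENAI_API_KEY
# environment variable; the model name is illustrative):
#
#   engine = OpenAiInferenceEngine(model_name="gpt-4o-mini", max_tokens=128)
#   predictions = engine.infer([{"source": "What is 2 + 2?"}])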
|
|
|
|
|
class VLLMRemoteInferenceEngine(OpenAiInferenceEngine): |
|
label: str = "vllm" |
|
|
|
|
|
class RITSInferenceEngine(OpenAiInferenceEngine): |
|
label: str = "rits" |
|
|
|
def get_default_headers(self): |
|
return {"RITS_API_KEY": self.credentials["api_key"]} |
|
|
|
def prepare_engine(self): |
|
base_url_template = "https://inference-3scale-apicast-production.apps.rits.fmaas.res.ibm.com/{}/v1" |
|
self.base_url = base_url_template.format(self._get_model_name_for_endpoint()) |
|
logger.info(f"Created RITS inference engine with endpoint: {self.base_url}") |
|
super().prepare_engine() |
|
|
|
def _get_model_name_for_endpoint(self): |
|
return ( |
|
self.model_name.split("/")[-1] |
|
.lower() |
|
.replace("v0.1", "v01") |
|
.replace("vision-", "") |
|
.replace(".", "-") |
|
) |
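
    # For example (illustrative), _get_model_name_for_endpoint maps
    #   "meta-llama/Llama-3.1-70B-Instruct"  ->  "llama-3-1-70b-instruct"
    # which is then substituted into the RITS endpoint template in prepare_engine.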
|
|
|
|
|
class TogetherAiInferenceEngineParamsMixin(Artifact): |
|
max_tokens: Optional[int] = None |
|
stop: Optional[List[str]] = None |
|
temperature: Optional[float] = None |
|
top_p: Optional[float] = None |
|
top_k: Optional[int] = None |
|
repetition_penalty: Optional[float] = None |
|
logprobs: Optional[int] = None |
|
echo: Optional[bool] = None |
|
n: Optional[int] = None |
|
min_p: Optional[float] = None |
|
presence_penalty: Optional[float] = None |
|
frequency_penalty: Optional[float] = None |
|
|
|
|
|
class TogetherAiInferenceEngine( |
|
InferenceEngine, TogetherAiInferenceEngineParamsMixin, PackageRequirementsMixin |
|
): |
|
label: str = "together" |
|
model_name: str |
|
    _requirements_list = {
        "together": "Install together package using 'pip install --upgrade together'."
    }
|
data_classification_policy = ["public"] |
|
parameters: Optional[TogetherAiInferenceEngineParamsMixin] = None |
|
|
|
def get_engine_id(self): |
|
return get_model_and_label_id(self.model_name, self.label) |
|
|
|
def prepare_engine(self): |
|
from together import Together |
|
from together.types.models import ModelType |
|
|
|
api_key_env_var_name = "TOGETHER_API_KEY" |
|
api_key = os.environ.get(api_key_env_var_name) |
|
assert api_key is not None, ( |
|
f"Error while trying to run TogetherAiInferenceEngine." |
|
f" Please set the environment param '{api_key_env_var_name}'." |
|
) |
|
self.client = Together(api_key=api_key) |
|
self._set_inference_parameters() |
|
|
|
|
|
together_models = self.client.models.list() |
|
together_model_id_to_type = { |
|
together_model.id: together_model.type for together_model in together_models |
|
} |
|
model_type = together_model_id_to_type.get(self.model_name) |
|
assert model_type is not None, ( |
|
f"Could not find model {self.model_name} " "in Together AI model list" |
|
) |
|
assert model_type in [ModelType.CHAT, ModelType.LANGUAGE, ModelType.CODE], ( |
|
f"Together AI model type {model_type} is not supported; " |
|
"supported types are 'chat', 'language' and 'code'." |
|
) |
|
self.model_type = model_type |
|
|
|
def _get_infer_kwargs(self): |
|
return { |
|
k: v |
|
for k, v in self.to_dict([TogetherAiInferenceEngineParamsMixin]).items() |
|
if v is not None |
|
} |
|
|
|
def _infer_chat(self, instance: Dict[str, Any]) -> str: |
|
messages = self.to_messages(instance) |
|
response = self.client.chat.completions.create( |
|
model=self.model_name, |
|
messages=messages, |
|
**self._get_infer_kwargs(), |
|
) |
|
return response.choices[0].message.content |
|
|
|
def _infer_text(self, instance: Dict[str, Any]) -> str: |
|
response = self.client.completions.create( |
|
model=self.model_name, |
|
prompt=instance["source"], |
|
**self._get_infer_kwargs(), |
|
) |
|
return response.choices[0].text |
|
|
|
def _infer( |
|
self, |
|
dataset: Union[List[Dict[str, Any]], DatasetDict], |
|
return_meta_data: bool = False, |
|
) -> Union[List[str], List[TextGenerationInferenceOutput]]: |
|
from together.types.models import ModelType |
|
|
|
outputs = [] |
|
if self.model_type == ModelType.CHAT: |
|
for instance in tqdm(dataset, desc="Inferring with Together AI Chat API"): |
|
outputs.append(self._infer_chat(instance)) |
|
else: |
|
self.verify_not_chat_api(dataset) |
|
for instance in tqdm(dataset, desc="Inferring with Together AI Text API"): |
|
outputs.append(self._infer_text(instance)) |
|
return outputs |
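
# Minimal usage sketch for TogetherAiInferenceEngine (requires the TOGETHER_API_KEY
# environment variable; the model name is illustrative):
#
#   engine = TogetherAiInferenceEngine(
#       model_name="meta-llama/Llama-3-8b-chat-hf", max_tokens=64
#   )
#   predictions = engine.infer([{"source": "Write a haiku about the sea."}])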
|
|
|
|
|
@deprecation( |
|
version="2.0.0", |
|
msg=" You can specify inference parameters directly when initializing an inference engine.", |
|
) |
|
class WMLInferenceEngineParamsMixin(Artifact): |
|
decoding_method: Optional[Literal["greedy", "sample"]] = None |
|
length_penalty: Optional[Dict[str, Union[int, float]]] = None |
|
temperature: Optional[float] = None |
|
top_p: Optional[float] = None |
|
top_k: Optional[int] = None |
|
random_seed: Optional[int] = None |
|
repetition_penalty: Optional[float] = None |
|
min_new_tokens: Optional[int] = None |
|
max_new_tokens: Optional[int] = None |
|
stop_sequences: Optional[List[str]] = None |
|
time_limit: Optional[int] = None |
|
truncate_input_tokens: Optional[int] = None |
|
prompt_variables: Optional[Dict[str, Any]] = None |
|
return_options: Optional[Dict[str, bool]] = None |
|
|
|
|
|
@deprecation(version="2.0.0", alternative=WMLInferenceEngineParamsMixin) |
|
class WMLInferenceEngineParams(Artifact): |
|
decoding_method: Optional[Literal["greedy", "sample"]] = None |
|
length_penalty: Optional[Dict[str, Union[int, float]]] = None |
|
temperature: Optional[float] = None |
|
top_p: Optional[float] = None |
|
top_k: Optional[int] = None |
|
random_seed: Optional[int] = None |
|
repetition_penalty: Optional[float] = None |
|
min_new_tokens: Optional[int] = None |
|
max_new_tokens: Optional[int] = None |
|
stop_sequences: Optional[List[str]] = None |
|
time_limit: Optional[int] = None |
|
truncate_input_tokens: Optional[int] = None |
|
prompt_variables: Optional[Dict[str, Any]] = None |
|
return_options: Optional[Dict[str, bool]] = None |
|
|
|
|
|
class WMLGenerationParamsMixin(Artifact): |
|
decoding_method: Optional[Literal["greedy", "sample"]] = None |
|
length_penalty: Optional[Dict[str, Union[int, float]]] = None |
|
temperature: Optional[float] = None |
|
top_p: Optional[float] = None |
|
top_k: Optional[int] = None |
|
random_seed: Optional[int] = None |
|
repetition_penalty: Optional[float] = None |
|
min_new_tokens: Optional[int] = None |
|
max_new_tokens: Optional[int] = None |
|
stop_sequences: Optional[List[str]] = None |
|
time_limit: Optional[int] = None |
|
truncate_input_tokens: Optional[int] = None |
|
prompt_variables: Optional[Dict[str, Any]] = None |
|
return_options: Optional[Dict[str, bool]] = None |
|
|
|
|
|
class WMLChatParamsMixin(Artifact): |
|
frequency_penalty: Optional[float] = None |
|
top_logprobs: Optional[int] = 5 |
|
presence_penalty: Optional[float] = None |
|
response_format: Optional[Dict[str, Any]] = None |
|
temperature: Optional[float] = None |
|
max_tokens: Optional[int] = None |
|
time_limit: Optional[int] = None |
|
top_p: Optional[float] = None |
|
n: Optional[int] = None |
|
|
|
|
|
CredentialsWML = Dict[ |
|
Literal["url", "username", "password", "apikey", "project_id", "space_id"], str |
|
] |
|
|
|
|
|
class WMLInferenceEngineBase( |
|
InferenceEngine, |
|
PackageRequirementsMixin, |
|
LogProbInferenceEngine, |
|
OptionSelectingByLogProbsInferenceEngine, |
|
): |
|
"""Base for classes running inference using ibm-watsonx-ai. |
|
|
|
Attributes: |
|
        credentials (Dict[str, str], optional): By default, it is created by the class
            instance, which tries to retrieve the proper environment variables
            ("WML_URL", "WML_PROJECT_ID", "WML_SPACE_ID", "WML_APIKEY", "WML_USERNAME", "WML_PASSWORD").
            However, a dictionary with the following keys can be provided directly instead:
            "url", "apikey", "project_id", "space_id", "username", "password".
|
model_name (str, optional): ID of a model to be used for inference. Mutually |
|
exclusive with 'deployment_id'. |
|
deployment_id (str, optional): Deployment ID of a tuned model to be used for |
|
inference. Mutually exclusive with 'model_name'. |
|
parameters (Union[WMLInferenceEngineParams, WMLGenerationParamsMixin, WMLChatParamsMixin], optional): |
|
Defines inference parameters and their values. Deprecated attribute, please pass respective |
|
parameters directly to the respective class instead. |
|
""" |
|
|
|
credentials: Optional[CredentialsWML] = None |
|
model_name: Optional[str] = None |
|
deployment_id: Optional[str] = None |
|
label: str = "wml" |
|
_requirements_list = { |
|
"ibm_watsonx_ai": "Install ibm-watsonx-ai package using 'pip install --upgrade ibm-watsonx-ai'. " |
|
"It is advised to have Python version >=3.10 installed, as at lower version this package " |
|
"may cause conflicts with other installed packages." |
|
} |
|
data_classification_policy = ["public", "proprietary"] |
|
parameters: Optional[ |
|
Union[WMLInferenceEngineParams, WMLGenerationParamsMixin, WMLChatParamsMixin] |
|
] = None |
|
|
|
_client: Any = InternalField(default=None, name="WML client") |
|
_model: Any = InternalField(default=None, name="WML model") |
|
|
|
def get_engine_id(self): |
|
return get_model_and_label_id(self.model_name or self.deployment_id, self.label) |
|
|
|
def verify(self): |
|
super().verify() |
|
|
|
        assert (

            (self.model_name or self.deployment_id)

            and not (self.model_name and self.deployment_id)

        ), "Either 'model_name' or 'deployment_id' must be specified, but not both at the same time."
|
|
|
def process_data_before_dump(self, data): |
|
if "credentials" in data: |
|
for key, value in data["credentials"].items(): |
|
if key != "url": |
|
data["credentials"][key] = "<hidden>" |
|
else: |
|
data["credentials"][key] = value |
|
return data |
|
|
|
def _initialize_wml_client(self): |
|
from ibm_watsonx_ai.client import APIClient |
|
|
|
if self.credentials is None: |
|
self.credentials = self._read_wml_credentials_from_env() |
|
self._verify_wml_credentials(self.credentials) |
|
|
|
client = APIClient(credentials=self.credentials) |
|
if "space_id" in self.credentials: |
|
client.set.default_space(self.credentials["space_id"]) |
|
else: |
|
client.set.default_project(self.credentials["project_id"]) |
|
return client |
|
|
|
@staticmethod |
|
def _read_wml_credentials_from_env() -> CredentialsWML: |
|
credentials: CredentialsWML = {} |
|
|
|
url = os.environ.get("WML_URL") |
|
assert url, ( |
|
"Error while trying to run 'WMLInferenceEngine'. " |
|
"Please set the env variable: 'WML_URL'" |
|
) |
|
credentials["url"] = url |
|
|
|
space_id = os.environ.get("WML_SPACE_ID") |
|
project_id = os.environ.get("WML_PROJECT_ID") |
|
        if space_id and project_id:

            get_logger().warning(

                "Only one of 'WML_SPACE_ID' and 'WML_PROJECT_ID' needs to be "

                "specified, however, both were found. 'WMLInferenceEngine' "

                "will use the space by default. If that is not desired, define "

                "only one of them in the env."

            )

        # Prefer the space when both are set; accept a space-only setup as well,
        # consistent with '_verify_wml_credentials'.
        if space_id:

            credentials["space_id"] = space_id

        elif project_id:

            credentials["project_id"] = project_id
|
else: |
|
raise AssertionError( |
|
"Error while trying to run 'WMLInferenceEngine'. " |
|
"Please set either 'WML_SPACE_ID' or 'WML_PROJECT_ID' env " |
|
"variable." |
|
) |
|
|
|
apikey = os.environ.get("WML_APIKEY") |
|
username = os.environ.get("WML_USERNAME") |
|
password = os.environ.get("WML_PASSWORD") |
|
|
|
if apikey and username and password: |
|
get_logger().warning( |
|
"Either 'WML_APIKEY' or both 'WML_USERNAME' and 'WML_PASSWORD' " |
|
"need to be specified, however, all of them were found. " |
|
"'WMLInferenceEngine' will use api key only by default. If it is not " |
|
"desired, then have only one of those options defined in the env." |
|
) |
|
|
|
if apikey: |
|
credentials["apikey"] = apikey |
|
elif username and password: |
|
credentials["username"] = username |
|
credentials["password"] = password |
|
else: |
|
raise AssertionError( |
|
"Error while trying to run 'WMLInferenceEngine'. " |
|
"Please set either 'WML_APIKEY' or both 'WML_USERNAME' and " |
|
"'WML_PASSWORD' env variables." |
|
) |
|
|
|
return credentials |
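
    # A minimal sketch of the environment this helper expects; the values are
    # placeholders. Either 'WML_SPACE_ID' or 'WML_PROJECT_ID', and either
    # 'WML_APIKEY' or the 'WML_USERNAME'/'WML_PASSWORD' pair, must be set:
    #
    #     os.environ["WML_URL"] = "https://us-south.ml.cloud.ibm.com"
    #     os.environ["WML_PROJECT_ID"] = "<your-project-id>"
    #     os.environ["WML_APIKEY"] = "<your-api-key>"
    #     credentials = WMLInferenceEngineBase._read_wml_credentials_from_env()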
|
|
|
@staticmethod |
|
def _verify_wml_credentials(credentials: CredentialsWML) -> None: |
|
assert isoftype(credentials, CredentialsWML), ( |
|
"WML credentials object must be a dictionary which may " |
|
"contain only the following keys: " |
|
"['url', 'apikey', 'username', 'password']." |
|
) |
|
|
|
assert credentials.get( |
|
"url" |
|
), "'url' is a mandatory key for WML credentials dict." |
|
assert "space_id" in credentials or "project_id" in credentials, ( |
|
"Either 'space_id' or 'project_id' must be provided " |
|
"as keys for WML credentials dict." |
|
) |
|
assert "apikey" in credentials or ( |
|
"username" in credentials and "password" in credentials |
|
), ( |
|
"Either 'apikey' or both 'username' and 'password' must be provided " |
|
"as keys for WML credentials dict." |
|
) |
|
|
|
def prepare_engine(self): |
|
self.check_missing_requirements() |
|
|
|
self._client = self._initialize_wml_client() |
|
|
|
self._set_inference_parameters() |
|
|
|
def _load_model(self): |
|
from ibm_watsonx_ai.foundation_models.inference import ModelInference |
|
|
|
self._model = ModelInference( |
|
model_id=self.model_name, |
|
deployment_id=self.deployment_id, |
|
api_client=self._client, |
|
) |
|
|
|
@abc.abstractmethod |
|
def _send_requests( |
|
self, |
|
dataset: Union[List[Dict[str, Any]], DatasetDict], |
|
return_logprobs: bool, |
|
return_meta_data: bool, |
|
) -> Union[List[str], List[Dict], List[TextGenerationInferenceOutput]]: |
|
raise NotImplementedError( |
|
f"The class '{self.get_pretty_print_name()}' is an abstract class. " |
|
f"Please used either 'WMLInferenceEngineGeneration' or " |
|
f"'WMLInferenceEngineChat' instead, depending on your task." |
|
) |
|
|
|
def _infer( |
|
self, |
|
dataset: Union[List[Dict[str, Any]], DatasetDict], |
|
return_meta_data: bool = False, |
|
) -> Union[List[str], List[TextGenerationInferenceOutput]]: |
|
if self._model is None: |
|
self._load_model() |
|
|
|
return self._send_requests( |
|
dataset=dataset, |
|
return_logprobs=False, |
|
return_meta_data=return_meta_data, |
|
) |
|
|
|
def _infer_log_probs( |
|
self, |
|
dataset: Union[List[Dict[str, Any]], DatasetDict], |
|
return_meta_data: bool = False, |
|
) -> Union[List[Dict], List[TextGenerationInferenceOutput]]: |
|
if self._model is None: |
|
self._load_model() |
|
|
|
return self._send_requests( |
|
dataset=dataset, |
|
return_logprobs=True, |
|
return_meta_data=return_meta_data, |
|
) |
|
|
|
@abc.abstractmethod |
|
def get_return_object(self, predict_result, result, input_text, return_meta_data): |
|
raise NotImplementedError |
|
|
|
def get_model_details(self) -> Dict: |
|
return self._model.get_details() |
|
|
|
def get_token_count(self, dataset): |
|
if self._model is None: |
|
self._load_model() |
|
|
|
texts = [instance["source"] for instance in dataset] |
|
|
|
for i in trange(len(texts), desc="Tokenizing"): |
|
response = self._model.tokenize(prompt=texts[i], return_tokens=True)[ |
|
"result" |
|
] |
|
dataset[i]["token_count"] = response["token_count"] |
|
|
|
return dataset |
|
|
|
def get_options_log_probs(self, dataset): |
|
"""Add to each instance in the data a "options_log_prob" field, which is a dict with str as key and a list of {text: str, logprob:float}.""" |
|
if self._model is None: |
|
self._load_model() |
|
|
|
texts = [x["source"] for x in dataset] |
|
|
|
responses = list( |
|
tqdm( |
|
self._model.generate( |
|
prompt=texts, |
|
params={ |
|
"decoding_method": "greedy", |
|
"max_new_tokens": 1, |
|
"return_options": { |
|
"input_tokens": True, |
|
"token_logprobs": True, |
|
}, |
|
}, |
|
), |
|
total=len(texts), |
|
desc="Completions", |
|
) |
|
) |
|
|
|
scores = [ |
|
[ |
|
{ |
|
"text": token["text"], |
|
"logprob": token["logprob"] if "logprob" in token else 1, |
|
} |
|
for token in response["results"][0]["input_tokens"] |
|
] |
|
for response in responses |
|
] |
|
|
|
for instance, score in zip(dataset, scores): |
|
instance["prediction"] = score[instance["task_data"]["token_count"] - 1 :] |
|
return dataset |
|
|
|
|
|
class WMLInferenceEngineGeneration(WMLInferenceEngineBase, WMLGenerationParamsMixin): |
|
"""Generates text for textual inputs. |
|
|
|
If you want to include images in your input, please use 'WMLInferenceEngineChat' instead. |
|
|
|
Attributes: |
|
concurrency_limit (int): Number of concurrent requests sent to a model. Default is 10, |
|
which is also the maximum value. |
|
|
|
Examples: |
|
from .api import load_dataset |
|
|
|
wml_credentials = { |
|
"url": "some_url", "project_id": "some_id", "api_key": "some_key" |
|
} |
|
model_name = "google/flan-t5-xxl" |
|
wml_inference = WMLInferenceEngineGeneration( |
|
credentials=wml_credentials, |
|
model_name=model_name, |
|
data_classification_policy=["public"], |
|
top_p=0.5, |
|
random_seed=123, |
|
) |
|
|
|
dataset = load_dataset( |
|
dataset_query="card=cards.argument_topic,template_card_index=0,loader_limit=5" |
|
) |
|
results = wml_inference.infer(dataset["test"]) |
|
""" |
|
|
|
concurrency_limit: int = 10 |
|
|
|
def verify(self): |
|
super().verify() |
|
|
|
assert ( |
|
isinstance(self.concurrency_limit, int) |
|
and 1 <= self.concurrency_limit <= 10 |
|
), ( |
|
f"'concurrency_limit' must be a positive integer not greater than 10. " |
|
f"However, '{self.concurrency_limit}' was given." |
|
) |
|
|
|
def _set_logprobs_params(self, params: Dict[str, Any]) -> Dict[str, Any]: |
|
user_return_options = params.pop("return_options", {}) |
|
|
|
|
|
logprobs_return_options = { |
|
"input_tokens": True, |
|
"generated_tokens": True, |
|
"token_logprobs": True, |
|
"top_n_tokens": user_return_options.get("top_n_tokens", 5), |
|
} |
|
|
|
for key, value in logprobs_return_options.items(): |
|
if key in user_return_options and user_return_options[key] != value: |
|
raise ValueError( |
|
f"'{key}={user_return_options[key]}' is not supported for the 'infer_log_probs' " |
|
f"method of {self.__class__.__name__}. For obtaining the logprobs of generated tokens " |
|
f"please use '{key}={value}'." |
|
) |
|
|
|
return { |
|
**params, |
|
"return_options": logprobs_return_options, |
|
} |
|
|
|
def _send_requests( |
|
self, |
|
dataset: Union[List[Dict[str, Any]], DatasetDict], |
|
return_logprobs: bool, |
|
return_meta_data: bool, |
|
) -> Union[List[str], List[Dict], List[TextGenerationInferenceOutput]]: |
|
self.verify_not_chat_api(dataset) |
|
|
|
params = self.to_dict([WMLGenerationParamsMixin], keep_empty=False) |
|
|
|
if return_logprobs: |
|
generation_type = "generated_tokens" |
|
params = self._set_logprobs_params(params) |
|
else: |
|
generation_type = "generated_text" |
|
|
|
inputs: List[str] = [instance["source"] for instance in dataset] |
|
|
|
results = self._model.generate( |
|
prompt=inputs, |
|
params=params, |
|
concurrency_limit=self.concurrency_limit, |
|
) |
|
|
|
final_results = [] |
|
for result, inp in zip(results, inputs): |
|
result_metadata = result["results"][0] |
|
generated_content = result_metadata[generation_type] |
|
final_results.append( |
|
self.get_return_object( |
|
generated_content, result_metadata, inp, return_meta_data |
|
) |
|
) |
|
return final_results |
|
|
|
def get_return_object(self, predict_result, result, input_text, return_meta_data): |
|
if return_meta_data: |
|
return TextGenerationInferenceOutput( |
|
prediction=predict_result, |
|
input_tokens=result["input_token_count"], |
|
output_tokens=result["generated_token_count"], |
|
model_name=self.model_name or self.deployment_id, |
|
inference_type=self.label, |
|
stop_reason=result["stop_reason"], |
|
seed=self.random_seed, |
|
input_text=input_text, |
|
) |
|
return predict_result |
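
# A minimal sketch (not executed) of requesting token log probabilities from the
# generation engine; the model name and parameter values are illustrative, and
# credentials are assumed to be available in the environment:
#
#     wml_gen = WMLInferenceEngineGeneration(
#         model_name="google/flan-t5-xxl",
#         max_new_tokens=5,
#     )
#     log_prob_results = wml_gen.infer_log_probs(dataset["test"])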
|
|
|
|
|
class WMLInferenceEngineChat(WMLInferenceEngineBase, WMLChatParamsMixin): |
|
"""Creates chat session and returns a model's response. |
|
|
|
You can also include images in your inputs. If you use only textual input, it is |
|
recommended to use 'WMLInferenceEngineGeneration' instead as it is faster, and allows |
|
more parameters for text generation. |
|
|
|
You can provide either already formatted messages, or a raw dataset as an input. |
|
In case of the former, all passed images should be base64-encoded strings given as |
|
    an 'image_url' within a message. Moreover, only one image per list of messages
|
may be sent. |
|
As for the latter, if there are multiple images per one instance, they will be sent |
|
separately with the same query. If that could possibly affect expected responses, |
|
concatenate images within an instance into a single image and adjust your query |
|
accordingly (if necessary). |
|
|
|
Attributes: |
|
image_encoder (EncodeImageToString, optional): operator which encodes images in |
|
given format to base64 strings required by service. You should specify it when |
|
you are using images in your inputs. |
|
|
|
Example: |
|
from .api import load_dataset |
|
        from .image_operators import EncodeImageToString
|
|
|
image_encoder = EncodeImageToString(image_format="JPEG") |
|
|
|
wml_credentials = { |
|
"url": "some_url", "project_id": "some_id", "api_key": "some_key" |
|
} |
|
model_name = "meta-llama/llama-3-2-11b-vision-instruct" |
|
wml_inference = WMLInferenceEngineChat( |
|
credentials=wml_credentials, |
|
model_name=model_name, |
|
image_encoder=image_encoder, |
|
data_classification_policy=["public"], |
|
max_tokens=1024, |
|
) |
|
|
|
dataset = load_dataset( |
|
dataset_query="card=cards.doc_vqa.en,template=templates.qa.with_context.with_type,loader_limit=30" |
|
) |
|
results = wml_inference.infer(dataset["test"]) |
|
""" |
|
|
|
image_encoder: Optional[EncodeImageToString] = None |
|
|
|
@staticmethod |
|
def _extract_queries(instance: Dict[str, Any]) -> Tuple[Optional[str], List]: |
|
task_data = instance["task_data"] |
|
if isinstance(task_data, str): |
|
task_data = json.loads(task_data) |
|
question = task_data.get("question") |
|
|
|
images = [None] |
|
if "images" in instance["media"]: |
|
images = extract_images(instance["source"], instance) |
|
|
|
return question or instance["source"], images |
|
|
|
def _create_messages_from_instance( |
|
self, instance: Dict[str, Any] |
|
) -> List[List[Dict[str, Any]]]: |
|
"""Method creates chat messages to be sent to a watsonx.ai model based on a given instance from a dataset.""" |
|
text, images = self._extract_queries(instance) |
|
|
|
messages: List[List[Dict[str, Any]]] = [] |
|
base_message = { |
|
"role": "user", |
|
"content": [ |
|
{ |
|
"type": "text", |
|
"text": text, |
|
} |
|
], |
|
} |
|
|
|
|
|
|
|
for image in images: |
|
            # Copy the content list too, so that appending an image to one
            # message does not mutate the list shared via 'base_message'.
            message = {**base_message, "content": list(base_message["content"])}
|
|
|
if image is not None: |
|
encoded_image = image |
|
if not isinstance(encoded_image, str): |
|
if self.image_encoder is None: |
|
raise ValueError( |
|
"If sending image queries as well, and they are not " |
|
"already encoded to base64 strings, you must specify " |
|
"the 'image_encoder' to be used." |
|
) |
|
encoded_image = self.image_encoder.encode_image_to_base64(image) |
|
|
|
message["content"].append( |
|
{ |
|
"type": "image_url", |
|
"image_url": { |
|
"url": "data:image/jpeg;base64," + encoded_image, |
|
}, |
|
} |
|
) |
|
|
|
messages.append([message]) |
|
|
|
return messages |
|
|
|
@staticmethod |
|
def verify_messages(messages: List[Dict[str, Any]]): |
|
"""Method verifies if externally provided messages containing images are compatible with the format required by ibm-watsonx-ai.""" |
|
n_images = 0 |
|
for message in messages: |
|
if isinstance(message["content"], str): |
|
continue |
|
|
|
for content in message["content"]: |
|
if isinstance(content, dict): |
|
if "image" in content["type"] and content["type"] != "image_url": |
|
raise ValueError( |
|
f"ibm-watsonx-ai only supports sending images as base64-encoded " |
|
f"strings, which should be given as 'image_url' in a message. " |
|
f"However, '{content['type']}' was given." |
|
) |
|
|
|
if content["type"] == "image_url": |
|
n_images += 1 |
|
if n_images > 1: |
|
raise ValueError( |
|
"ibm-watsonx-ai only supports sending one image per a list " |
|
"of messages." |
|
) |
|
|
|
def to_messages(self, instance: Union[Dict, List]) -> List[List[Dict[str, Any]]]: |
|
if isinstance(instance["source"], str) and "media" in instance: |
|
return self._create_messages_from_instance(instance) |
|
|
|
messages = super().to_messages(instance) |
|
self.verify_messages(messages) |
|
|
|
|
|
return [messages] |
|
|
|
def _send_requests( |
|
self, |
|
dataset: Union[List[Dict[str, Any]], DatasetDict], |
|
return_logprobs: bool, |
|
return_meta_data: bool, |
|
) -> Union[List[str], List[Dict], List[TextGenerationInferenceOutput]]: |
|
params = self.to_dict([WMLChatParamsMixin], keep_empty=False) |
|
|
|
if return_logprobs: |
|
output_type = "logprobs" |
|
params["logprobs"] = True |
|
else: |
|
output_type = "message" |
|
params["logprobs"] = False |
|
|
|
final_results = [] |
|
|
|
for instance in dataset: |
|
messages = self.to_messages(instance) |
|
|
|
for message in messages: |
|
result = self._model.chat( |
|
messages=message, |
|
params=params, |
|
) |
|
|
|
final_results.append( |
|
self.get_return_object( |
|
result["choices"][0][output_type]["content"], |
|
result, |
|
instance["source"], |
|
return_meta_data, |
|
) |
|
) |
|
|
|
return final_results |
|
|
|
def get_return_object(self, predict_result, result, input_text, return_meta_data): |
|
if return_meta_data: |
|
return TextGenerationInferenceOutput( |
|
prediction=predict_result, |
|
input_tokens=result["usage"]["prompt_tokens"], |
|
output_tokens=len(predict_result) |
|
if isinstance(predict_result, list) |
|
else None, |
|
model_name=self.model_name or self.deployment_id, |
|
inference_type=self.label, |
|
stop_reason=result["choices"][0]["finish_reason"], |
|
input_text=input_text, |
|
) |
|
return predict_result |
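
# A minimal sketch (not executed) of passing pre-formatted chat messages with a
# single base64-encoded image, in the shape verify_messages() accepts; the
# encoded string is a placeholder:
#
#     messages = [
#         {
#             "role": "user",
#             "content": [
#                 {"type": "text", "text": "What is shown in the image?"},
#                 {
#                     "type": "image_url",
#                     "image_url": {"url": "data:image/jpeg;base64,<encoded-image>"},
#                 },
#             ],
#         }
#     ]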
|
|
|
|
|
@deprecation( |
|
version="2.0.0", |
|
msg=" Please use either 'WMLInferenceEngineGeneration' or 'WMLInferenceEngineChat'" |
|
" depending on your task.", |
|
) |
|
class WMLInferenceEngine(WMLInferenceEngineGeneration): |
|
def prepare_engine(self): |
|
super().prepare_engine() |
|
get_logger().warning("'WMLInferenceEngine' is deprecated") |
|
|
|
|
|
def get_images_without_text(instance): |
|
return extract_images(instance["source"], instance) |
|
|
|
|
|
def get_text_without_images(instance, image_token="<image>"): |
|
regex = r"<" + f"{constants.image_tag}" + r'\s+src=["\'](.*?)["\']\s*/?>' |
|
return re.sub(regex, image_token, instance["source"]) |
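
# For illustration only, assuming constants.image_tag is "img": a source such as
# 'Describe <img src="cat.png"/> briefly.' would be returned as
# 'Describe <image> briefly.' by get_text_without_images(instance).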
|
|
|
|
|
class LMMSEvalBaseInferenceEngine( |
|
InferenceEngine, PackageRequirementsMixin, LazyLoadMixin |
|
): |
|
model_type: str |
|
model_args: Dict[str, str] |
|
batch_size: int = 1 |
|
image_token = "<image>" |
|
|
|
_requirements_list = { |
|
"lmms_eval": "Install llms-eval package using 'pip install lmms-eval==0.2.4'", |
|
} |
|
|
|
def prepare_engine(self): |
|
if not self.lazy_load: |
|
self._prepare_engine() |
|
|
|
def _prepare_engine(self): |
|
import torch |
|
from lmms_eval.api.instance import Instance |
|
from lmms_eval.models import get_model |
|
|
|
self.new_instance = Instance |
|
|
|
self.device = torch.device( |
|
"mps" |
|
if torch.backends.mps.is_available() |
|
else "cuda" |
|
if torch.cuda.is_available() |
|
else "cpu" |
|
) |
|
|
|
if isinstance(self.model_args, dict): |
|
self.model_args = ",".join(f"{k}={v}" for k, v in self.model_args.items()) |
|
|
|
self.model = get_model(self.model_type).create_from_arg_string( |
|
self.model_args, |
|
{ |
|
"batch_size": self.batch_size, |
|
"device": self.device, |
|
}, |
|
) |
|
|
|
def _is_loaded(self): |
|
return hasattr(self, "model") and self.model is not None |
|
|
|
|
|
class LMMSEvalInferenceEngine(LMMSEvalBaseInferenceEngine): |
|
max_new_tokens: int = 32 |
|
temperature: float = 0.0 |
|
do_sample: bool = False |
|
generate_until: List[str] = ["\n\n"] |
|
|
|
def _infer( |
|
self, |
|
dataset: Union[List[Dict[str, Any]], DatasetDict], |
|
return_meta_data: bool = False, |
|
) -> Union[List[str], List[TextGenerationInferenceOutput]]: |
|
if not self._is_loaded(): |
|
self._prepare_engine() |
|
|
|
from lmms_eval.api.instance import Instance |
|
|
|
temp_task_name = str(uuid.uuid4()) |
|
|
|
requests = [] |
|
for i, instance in enumerate(dataset): |
|
requests.append( |
|
Instance( |
|
request_type="generate_until", |
|
arguments=( |
|
get_text_without_images(instance, image_token=self.image_token), |
|
{ |
|
"max_new_tokens": self.max_new_tokens, |
|
"temperature": self.temperature, |
|
"do_sample": self.do_sample, |
|
"until": self.generate_until, |
|
}, |
|
get_images_without_text, |
|
i, |
|
temp_task_name, |
|
"test", |
|
), |
|
idx=i, |
|
metadata={ |
|
"task": temp_task_name, |
|
"doc_id": i, |
|
"repeats": 1, |
|
}, |
|
) |
|
) |
|
|
|
self.model.task_dict[temp_task_name] = DatasetDict({"test": dataset}) |
|
|
|
responses = self.model.generate_until(requests) |
|
|
|
self.model.task_dict.pop(temp_task_name) |
|
|
|
return responses |
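
# A minimal sketch (not executed) of building an lmms-eval based engine; the
# 'model_type' and 'model_args' values are illustrative and depend on the models
# shipped with the installed lmms-eval version:
#
#     lmms_engine = LMMSEvalInferenceEngine(
#         model_type="llava",
#         model_args={"pretrained": "liuhaotian/llava-v1.5-7b"},
#         max_new_tokens=32,
#     )
#     predictions = lmms_engine.infer(multimodal_dataset["test"])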
|
|
|
|
|
class LMMSEvalLoglikelihoodInferenceEngine(LMMSEvalBaseInferenceEngine): |
|
request_type: Literal["loglikelihood"] = "loglikelihood" |
|
|
|
def make_instance(self, instance, special_args, index, task_name): |
|
from lmms_eval.api.instance import Instance |
|
|
|
return Instance( |
|
request_type=self.request_type, |
|
arguments=( |
|
get_text_without_images(instance, image_token=self.image_token), |
|
special_args, |
|
get_images_without_text, |
|
index, |
|
task_name, |
|
"test", |
|
), |
|
idx=index, |
|
metadata={ |
|
"task": task_name, |
|
"doc_id": index, |
|
"repeats": 1, |
|
}, |
|
) |
|
|
|
def _infer( |
|
self, |
|
dataset: Union[List[Dict[str, Any]], DatasetDict], |
|
return_meta_data: bool = False, |
|
) -> Union[List[str], List[TextGenerationInferenceOutput]]: |
|
if not self._is_loaded(): |
|
self._prepare_engine() |
|
|
|
temp_task_name = str(uuid.uuid4()) |
|
|
|
requests = [] |
|
for i, instance in enumerate(dataset): |
|
task_data = instance["task_data"] |
|
|
|
if isinstance(task_data, str): |
|
task_data = json.loads(task_data) |
|
|
|
for option in task_data["options"]: |
|
requests.append( |
|
self.make_instance( |
|
instance, |
|
option, |
|
i, |
|
temp_task_name, |
|
) |
|
) |
|
|
|
self.model.task_dict[temp_task_name] = DatasetDict({"test": dataset}) |
|
self.model.metadata = {} |
|
|
|
responses = self.model.loglikelihood(requests) |
|
|
|
self.model.task_dict.pop(temp_task_name) |
|
|
|
optimal_scores = [sys.float_info.max] * len(dataset) |
|
optimal_responses = [None] * len(dataset) |
|
|
|
for request, (score, _) in zip(requests, responses): |
|
if score < optimal_scores[request.idx]: |
|
optimal_scores[request.idx] = score |
|
optimal_responses[request.idx] = request.arguments[1] |
|
|
|
return optimal_responses |
|
|
|
|
|
class VLLMInferenceEngine( |
|
InferenceEngine, PackageRequirementsMixin, StandardAPIParamsMixin |
|
): |
|
def prepare_engine(self): |
|
from vllm import LLM, SamplingParams |
|
|
|
args = self.to_dict([StandardAPIParamsMixin]) |
|
self.sampling_params = SamplingParams(**args) |
|
self.llm = LLM(model=self.model) |
|
|
|
def _infer( |
|
self, |
|
dataset: Union[List[Dict[str, Any]], DatasetDict], |
|
return_meta_data: bool = False, |
|
) -> Union[List[str], List[TextGenerationInferenceOutput]]: |
|
inputs = [] |
|
for instance in dataset: |
|
inputs.append(instance["source"]) |
|
|
|
if isinstance(inputs[0], list): |
|
outputs = self.llm.chat(inputs, self.sampling_params) |
|
else: |
|
outputs = self.llm.generate(inputs, self.sampling_params) |
|
|
|
predictions = [] |
|
for output in outputs: |
|
predictions.append(output.outputs[0].text) |
|
|
|
return predictions |
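
# A minimal sketch (not executed) of local inference with vLLM; the model
# identifier and sampling values are illustrative:
#
#     vllm_engine = VLLMInferenceEngine(
#         model="meta-llama/Llama-3.2-1B-Instruct",
#         temperature=0.0,
#         max_tokens=128,
#     )
#     predictions = vllm_engine.infer(dataset["test"])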
|
|
|
|
|
class AsyncTokenBucket: |
|
def __init__(self, rate, capacity): |
|
self.rate = rate |
|
self.capacity = capacity |
|
self.tokens = capacity |
|
self.timestamp = time.perf_counter() |
|
self.lock = asyncio.Lock() |
|
self.interval = 1.0 / self.rate |
|
|
|
async def acquire(self, tokens=1): |
|
while True: |
|
async with self.lock: |
|
now = time.perf_counter() |
|
delta = now - self.timestamp |
|
|
|
|
|
token_intervals = int(delta / self.interval) |
|
if token_intervals > 0: |
|
self.tokens = min(self.capacity, self.tokens + token_intervals) |
|
self.timestamp += token_intervals * self.interval |
|
logging.debug( |
|
f"Added {token_intervals} tokens. Tokens now: {self.tokens}" |
|
) |
|
|
|
if self.tokens >= tokens: |
|
self.tokens -= tokens |
|
logging.debug(f"Token acquired. Tokens left: {self.tokens}") |
|
return |
|
|
|
time_until_next_token = self.interval - (now - self.timestamp) |
|
logging.debug( |
|
f"Not enough tokens. Need to wait {time_until_next_token:.4f} seconds." |
|
) |
|
|
|
await asyncio.sleep(time_until_next_token) |
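
# A minimal sketch (not executed) of using the token bucket to cap the request
# rate at roughly two calls per second across concurrent tasks:
#
#     bucket = AsyncTokenBucket(rate=2, capacity=2)
#
#     async def call_api():
#         await bucket.acquire()
#         ...  # issue the rate-limited request here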
|
|
|
|
|
class LiteLLMInferenceEngine( |
|
InferenceEngine, StandardAPIParamsMixin, PackageRequirementsMixin |
|
): |
|
max_requests_per_second: float = 6 |
|
max_retries: int = 5 |
|
|
|
_requirements_list: list = ["litellm", "tenacity", "tqdm", "diskcache"] |
|
|
|
def prepare_engine(self): |
|
|
|
self._rate_limiter = AsyncTokenBucket( |
|
rate=self.max_requests_per_second, |
|
capacity=self.max_requests_per_second, |
|
) |
|
self.inference_type = "litellm" |
|
import litellm |
|
from litellm import acompletion |
|
from litellm.caching.caching import Cache |
|
|
|
litellm.cache = Cache(type="disk") |
|
|
|
self._completion = acompletion |
|
|
|
self._semaphore = asyncio.Semaphore(self.max_requests_per_second) |
|
|
|
async def _infer_instance( |
|
self, index: int, instance: Dict[str, Any] |
|
) -> TextGenerationInferenceOutput: |
|
"""Process a single inference request.""" |
|
async with self._semaphore: |
|
await self._rate_limiter.acquire() |
|
|
|
await asyncio.sleep(0.01) |
|
messages = self.to_messages(instance) |
|
kwargs = self.to_dict([StandardAPIParamsMixin]) |
|
try: |
|
response = await self._completion( |
|
messages=messages, |
|
max_retries=self.max_retries, |
|
caching=True, |
|
**kwargs, |
|
) |
|
except Exception as e: |
|
raise RuntimeError( |
|
f"Error inferring the following instance:\n{instance}" |
|
) from e |
|
|
|
usage = response.get("usage", {}) |
|
return TextGenerationInferenceOutput( |
|
prediction=response["choices"][0]["message"]["content"], |
|
input_tokens=usage.get("prompt_tokens"), |
|
output_tokens=usage.get("completion_tokens"), |
|
model_name=response.get("model", self.model), |
|
inference_type=self.inference_type, |
|
) |
|
|
|
async def _infer_async( |
|
self, dataset: List[Dict[str, Any]] |
|
) -> List[TextGenerationInferenceOutput]: |
|
"""Process multiple inference requests concurrently with a progress bar.""" |
|
tasks = [ |
|
self._infer_instance(i, instance) for i, instance in enumerate(dataset) |
|
] |
|
|
|
return await tqdm_asyncio.gather( |
|
*tasks, desc=f"LiteLLM Inference ({self.model})", total=len(tasks) |
|
) |
|
|
|
def _infer( |
|
self, |
|
dataset: Union[List[Dict[str, Any]], "DatasetDict"], |
|
return_meta_data: bool = False, |
|
) -> Union[List[str], List[TextGenerationInferenceOutput]]: |
|
"""Main inference entry point.""" |
|
loop = asyncio.get_event_loop() |
|
responses = loop.run_until_complete(self._infer_async(dataset)) |
|
|
|
if return_meta_data: |
|
return responses |
|
|
|
return [response.prediction for response in responses] |
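
# A minimal sketch (not executed) of the LiteLLM engine; the model string is
# illustrative, must be one litellm recognizes, and the provider credentials are
# assumed to be available in the environment:
#
#     litellm_engine = LiteLLMInferenceEngine(
#         model="watsonx/meta-llama/llama-3-8b-instruct",
#         max_tokens=256,
#         max_requests_per_second=4,
#     )
#     predictions = litellm_engine.infer(dataset["test"])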
|
|
|
|
|
_supported_apis = Literal[ |
|
"watsonx", "together-ai", "open-ai", "aws", "ollama", "bam", "watsonx-sdk", "rits" |
|
] |
|
|
|
|
|
class CrossProviderInferenceEngine(InferenceEngine, StandardAPIParamsMixin): |
|
"""Inference engine capable of dynamically switching between multiple providers APIs. |
|
|
|
    This class extends InferenceEngine and StandardAPIParamsMixin

    to enable seamless integration with various API providers. The supported APIs are

    specified in `_supported_apis`, allowing users to interact with multiple models

    from different sources. The `provider_model_map` dictionary maps each API to

    specific model identifiers, enabling automatic configuration based on

    user requests.
|
|
|
Attributes: |
|
provider: Optional; Specifies the current API in use. Must be one of the |
|
literals in `_supported_apis`. |
|
provider_model_map: Dictionary mapping each supported API to a corresponding |
|
model identifier string. This mapping allows consistent access to models |
|
across different API backends. |
|
""" |
|
|
|
provider: Optional[_supported_apis] = None |
|
|
|
provider_model_map: Dict[_supported_apis, Dict[str, str]] = { |
|
"watsonx": { |
|
"llama-3-8b-instruct": "watsonx/meta-llama/llama-3-8b-instruct", |
|
"llama-3-70b-instruct": "watsonx/meta-llama/llama-3-70b-instruct", |
|
"granite-3-8b-instruct": "watsonx/ibm/granite-3-8b-instruct", |
|
"flan-t5-xxl": "watsonx/google/flan-t5-xxl", |
|
"llama-3-2-1b-instruct": "watsonx/meta-llama/llama-3-2-1b-instruct", |
|
"llama-3-2-11b-vision-instruct": "watsonx/meta-llama/llama-3-2-11b-vision-instruct", |
|
"llama-3-2-90b-vision-instruct": "watsonx/meta-llama/llama-3-2-90b-vision-instruct", |
|
}, |
|
"watsonx-sdk": { |
|
"llama-3-8b-instruct": "meta-llama/llama-3-8b-instruct", |
|
"llama-3-70b-instruct": "meta-llama/llama-3-70b-instruct", |
|
"granite-3-8b-instruct": "ibm/granite-3-8b-instruct", |
|
}, |
|
"together-ai": { |
|
"llama-3-8b-instruct": "together_ai/togethercomputer/llama-3-8b-instruct", |
|
"llama-3-70b-instruct": "together_ai/togethercomputer/llama-3-70b-instruct", |
|
"llama-3-2-1b-instruct": "together_ai/togethercomputer/llama-3-2-1b-instruct", |
|
}, |
|
"aws": { |
|
"llama-3-8b-instruct": "bedrock/meta.llama3-8b-instruct-v1:0", |
|
"llama-3-70b-instruct": "bedrock/meta.llama3-70b-instruct-v1:0", |
|
}, |
|
"ollama": { |
|
"llama-3-8b-instruct": "llama3:8b", |
|
"llama-3-70b-instruct": "llama3:70b", |
|
}, |
|
"bam": { |
|
"granite-3-8b-instruct": "ibm/granite-8b-instruct-preview-4k", |
|
"llama-3-8b-instruct": "meta-llama/llama-3-8b-instruct", |
|
"llama-3-2-1b-instruct": "meta-llama/llama-3-2-1b-instruct", |
|
"flan-t5-xxl": "google/flan-t5-xxl", |
|
}, |
|
"rits": { |
|
"granite-3-8b-instruct": "ibm-granite/granite-3.0-8b-instruct", |
|
"llama-3-1-8b-instruct": "meta-llama/llama-3-1-8b-instruct", |
|
"llama-3-1-70b-instruct": "meta-llama/llama-3-1-70b-instruct", |
|
"llama-3-2-11b-vision-instruct": "meta-llama/Llama-3.2-11B-Vision-Instruct", |
|
"llama-3-2-90b-vision-instruct": "meta-llama/Llama-3.2-90B-Vision-Instruct", |
|
"mistral-large-instruct": "mistralai/mistral-large-instruct-2407", |
|
"mixtral-8x7b-instruct": "mistralai/mixtral-8x7B-instruct-v0.1", |
|
}, |
|
} |
|
|
|
_provider_to_base_class = { |
|
"watsonx": LiteLLMInferenceEngine, |
|
"open-ai": LiteLLMInferenceEngine, |
|
"together-ai": LiteLLMInferenceEngine, |
|
"aws": LiteLLMInferenceEngine, |
|
"ollama": OllamaInferenceEngine, |
|
"bam": IbmGenAiInferenceEngine, |
|
"watsonx-sdk": WMLInferenceEngine, |
|
"rits": RITSInferenceEngine, |
|
} |
|
|
|
_provider_param_renaming = { |
|
"bam": {"max_tokens": "max_new_tokens", "model": "model_name"}, |
|
"watsonx-sdk": {"max_tokens": "max_new_tokens", "model": "model_name"}, |
|
"rits": {"model": "model_name"}, |
|
} |
|
|
|
def get_provider_name(self): |
|
return self.provider if self.provider is not None else settings.default_provider |
|
|
|
def prepare_engine(self): |
|
provider = self.get_provider_name() |
|
if provider not in self._provider_to_base_class: |
|
raise UnitxtError( |
|
f"{provider} is not a configured API for CrossProviderInferenceEngine. Supported apis: {','.join(self.provider_model_map.keys())}" |
|
) |
|
if self.model not in self.provider_model_map[provider]: |
|
raise UnitxtError( |
|
f"{self.model} is not configured for provider {provider}. Supported models: {','.join(self.provider_model_map[provider].keys())}" |
|
) |
|
cls = self.__class__._provider_to_base_class[provider] |
|
args = self.to_dict([StandardAPIParamsMixin]) |
|
args["model"] = self.provider_model_map[provider][self.model] |
|
params = list(args.keys()) |
|
if provider in self._provider_param_renaming: |
|
for param in params: |
|
if args[param] is not None: |
|
if param in self._provider_param_renaming[provider]: |
|
args[self._provider_param_renaming[provider][param]] = args[ |
|
param |
|
] |
|
del args[param] |
|
else: |
|
del args[param] |
|
self.engine = cls(**args) |
|
|
|
def _infer( |
|
self, |
|
dataset: Union[List[Dict[str, Any]], DatasetDict], |
|
return_meta_data: bool = False, |
|
) -> Union[List[str], List[TextGenerationInferenceOutput]]: |
|
return self.engine._infer(dataset, return_meta_data) |
|
|
|
def get_engine_id(self): |
|
api = self.get_provider_name() |
|
return get_model_and_label_id(self.provider_model_map[api][self.model], api) |
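
# A minimal sketch (not executed) of switching providers while keeping the same
# model alias; both calls resolve the alias through provider_model_map:
#
#     engine = CrossProviderInferenceEngine(
#         model="llama-3-8b-instruct", provider="watsonx", max_tokens=256
#     )
#     predictions = engine.infer(dataset["test"])
#
#     # The same alias served by a local Ollama deployment instead:
#     engine = CrossProviderInferenceEngine(
#         model="llama-3-8b-instruct", provider="ollama", max_tokens=256
#     )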
|
|
|
|
|
class HFOptionSelectingInferenceEngine(InferenceEngine): |
|
"""HuggingFace based class for inference engines that calculate log probabilities. |
|
|
|
This class uses models from the HuggingFace Transformers library to calculate log probabilities for text inputs. |
|
""" |
|
|
|
model_name: str |
|
batch_size: int |
|
|
|
_requirements_list = { |
|
"transformers": "Install huggingface package using 'pip install --upgrade transformers" |
|
} |
|
|
|
def prepare_engine(self): |
|
import torch |
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
|
self.device = torch.device( |
|
"mps" |
|
if torch.backends.mps.is_available() |
|
else "cuda" |
|
if torch.cuda.is_available() |
|
else "cpu" |
|
) |
|
|
|
|
|
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name) |
|
self.model = AutoModelForCausalLM.from_pretrained(self.model_name).to( |
|
self.device |
|
) |
|
|
|
if self.tokenizer.pad_token is None: |
|
self.tokenizer.pad_token = self.tokenizer.eos_token |
|
|
|
def get_log_probs(self, texts): |
|
|
|
import torch |
|
from tqdm import tqdm |
|
|
|
log_probs = [] |
|
|
|
|
|
for i in tqdm(range(0, len(texts), self.batch_size)): |
|
batch = texts[i : i + self.batch_size] |
|
|
|
|
|
if isinstance(texts[0], list): |
|
batch = self.tokenizer.apply_chat_template(batch, tokenize=False) |
|
|
|
inputs = self.tokenizer( |
|
batch, return_tensors="pt", padding=True, truncation=True |
|
).to(self.device) |
|
|
|
|
|
with torch.no_grad(): |
|
predictions = self.model(**inputs) |
|
logits = predictions.logits |
|
|
|
for j in range(len(batch)): |
|
input_ids = inputs.input_ids[j] |
|
text_logits = logits[j, :-1, :] |
|
text_log_probs = torch.log_softmax(text_logits, dim=-1) |
|
|
|
|
|
token_log_probs = text_log_probs[ |
|
torch.arange(text_logits.shape[0]), input_ids[1:] |
|
] |
|
|
|
|
|
sequence_log_prob = token_log_probs.sum().item() |
|
log_probs.append(sequence_log_prob) |
|
|
|
return log_probs |
|
|
|
def _infer( |
|
self, |
|
dataset: Union[List[Dict[str, Any]], DatasetDict], |
|
return_meta_data: bool = False, |
|
) -> Union[List[str], List[TextGenerationInferenceOutput]]: |
|
inputs = [] |
|
|
|
for instance in dataset: |
|
for option in instance["task_data"]["options"]: |
|
if isinstance(instance["source"], list): |
|
inputs.append( |
|
instance["source"] + [{"role": "assistant", "content": option}] |
|
) |
|
else: |
|
inputs.append(instance["source"] + option) |
|
|
|
scores = self.get_log_probs(inputs) |
|
|
|
scores_iterator = iter(scores) |
|
|
|
predictions = [] |
|
for instance in dataset: |
|
options_scores = Counter() |
|
for option in instance["task_data"]["options"]: |
|
score = next(scores_iterator) |
|
options_scores[option] = score |
|
predictions.append(options_scores.most_common(1)[0][0]) |
|
|
|
return predictions |
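
# A minimal sketch (not executed) of option selection with a HuggingFace model;
# the model name is illustrative, and every dataset instance is expected to
# carry its candidate options under task_data["options"]:
#
#     selector = HFOptionSelectingInferenceEngine(
#         model_name="gpt2",
#         batch_size=8,
#     )
#     predictions = selector.infer(dataset["test"])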
|
|