import copy
import dataclasses
import io
import json
import logging
import math
import os
import sys
import time
from typing import Optional, Sequence, Union

import openai
import tqdm
from openai import openai_object

StrOrOpenAIObject = Union[str, openai_object.OpenAIObject]

openai_org = os.getenv("OPENAI_ORG")
if openai_org is not None:
    openai.organization = openai_org
    logging.warning(f"Switching to organization: {openai_org} for OAI API key.")


@dataclasses.dataclass
class OpenAIDecodingArguments(object):
    max_tokens: int = 1800
    temperature: float = 0.2
    top_p: float = 1.0
    n: int = 1
    stream: bool = False
    stop: Optional[Sequence[str]] = None
    presence_penalty: float = 0.0
    frequency_penalty: float = 0.0
    suffix: Optional[str] = None
    logprobs: Optional[int] = None
    echo: bool = False


def openai_completion(
    prompts: Union[str, Sequence[str], Sequence[dict[str, str]], dict[str, str]],
    decoding_args: OpenAIDecodingArguments,
    model_name="text-davinci-003",
    sleep_time=2,
    batch_size=1,
    max_instances=sys.maxsize,
    max_batches=sys.maxsize,
    return_text=False,
    **decoding_kwargs,
) -> Union[StrOrOpenAIObject, Sequence[StrOrOpenAIObject], Sequence[Sequence[StrOrOpenAIObject]]]:
    """Decode with OpenAI API.

    Args:
        prompts: A string or a list of strings to complete. If it is a chat model, the strings
            should be formatted as explained here:
            https://github.com/openai/openai-python/blob/main/chatml.md.
            If it is a chat model it can also be a dictionary (or list thereof) as explained here:
            https://github.com/openai/openai-cookbook/blob/main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb
        decoding_args: Decoding arguments.
        model_name: Model name. Can be either in the format of "org/model" or just "model".
        sleep_time: Time to sleep once the rate limit is hit.
        batch_size: Number of prompts to send in a single request. Only for non-chat models.
        max_instances: Maximum number of prompts to decode.
        max_batches: Maximum number of batches to decode. This argument will be deprecated in the future.
        return_text: If True, return text instead of the full completion object (which contains
            things like logprobs).
        decoding_kwargs: Additional decoding arguments. Pass in `best_of` and `logit_bias` if you need them.

    Returns:
        A completion or a list of completions. Depending on return_text and decoding_args.n,
        the completion type can be one of:
            - a string (if return_text is True)
            - an openai_object.OpenAIObject object (if return_text is False)
            - a list of objects of the above types (if decoding_args.n > 1)
    """
    is_single_prompt = isinstance(prompts, (str, dict))
    if is_single_prompt:
        prompts = [prompts]

    if max_batches < sys.maxsize:
        logging.warning(
            "`max_batches` will be deprecated in the future, please use `max_instances` instead. "
            "Setting `max_instances` to `max_batches * batch_size` for now."
        )
        max_instances = max_batches * batch_size

    prompts = prompts[:max_instances]
    num_prompts = len(prompts)
    prompt_batches = [
        prompts[batch_id * batch_size : (batch_id + 1) * batch_size]
        for batch_id in range(int(math.ceil(num_prompts / batch_size)))
    ]

    completions = []
    for batch_id, prompt_batch in tqdm.tqdm(
        enumerate(prompt_batches),
        desc="prompt_batches",
        total=len(prompt_batches),
    ):
        batch_decoding_args = copy.deepcopy(decoding_args)  # cloning the decoding_args

        while True:
            try:
                shared_kwargs = dict(
                    model=model_name,
                    **batch_decoding_args.__dict__,
                    **decoding_kwargs,
                )
                completion_batch = openai.Completion.create(prompt=prompt_batch, **shared_kwargs)
                choices = completion_batch.choices

                for choice in choices:
                    choice["total_tokens"] = completion_batch.usage.total_tokens
                completions.extend(choices)
                break
            except openai.error.OpenAIError as e:
                logging.warning(f"OpenAIError: {e}.")
                if "Please reduce your prompt" in str(e):
                    batch_decoding_args.max_tokens = int(batch_decoding_args.max_tokens * 0.8)
                    logging.warning(f"Reducing target length to {batch_decoding_args.max_tokens}, Retrying...")
                else:
                    logging.warning("Hit request rate limit; retrying...")
                    time.sleep(sleep_time)  # Annoying rate limit on requests.

    if return_text:
        completions = [completion.text for completion in completions]
    if decoding_args.n > 1:
        # Make completions a nested list, where each entry is a consecutive decoding_args.n of original entries.
        completions = [
            completions[i : i + decoding_args.n] for i in range(0, len(completions), decoding_args.n)
        ]
    if is_single_prompt:
        # Return non-tuple if only 1 input and 1 generation.
        (completions,) = completions
    return completions


def _make_w_io_base(f, mode: str):
    if not isinstance(f, io.IOBase):
        f_dirname = os.path.dirname(f)
        if f_dirname != "":
            os.makedirs(f_dirname, exist_ok=True)
        f = open(f, mode=mode)
    return f


def _make_r_io_base(f, mode: str):
    if not isinstance(f, io.IOBase):
        f = open(f, mode=mode)
    return f


def jdump(obj, f, mode="w", indent=4, default=str):
    """Dump a str or dictionary to a file in json format.

    Args:
        obj: An object to be written.
        f: A string path to the location on disk.
        mode: Mode for opening the file.
        indent: Indent for storing json dictionaries.
        default: A function to handle non-serializable entries; defaults to `str`.
    """
    f = _make_w_io_base(f, mode)
    if isinstance(obj, (dict, list)):
        json.dump(obj, f, indent=indent, default=default)
    elif isinstance(obj, str):
        f.write(obj)
    else:
        raise ValueError(f"Unexpected type: {type(obj)}")
    f.close()


def jload(f, mode="r"):
    """Load a .json file into a dictionary."""
    f = _make_r_io_base(f, mode)
    jdict = json.load(f)
    f.close()
    return jdict
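

# Minimal usage sketch (not part of the original module): one way to drive
# `openai_completion` and round-trip the results with `jdump`/`jload`.
# The prompt texts, the output path "outputs/completions.json", and the
# assumption that an OpenAI API key is already configured in the environment
# are illustrative choices, not requirements of this file.
if __name__ == "__main__":
    decoding_args = OpenAIDecodingArguments(max_tokens=256, temperature=0.7, n=1)
    outputs = openai_completion(
        prompts=["List three uses for a paperclip.", "Name two prime numbers."],
        decoding_args=decoding_args,
        model_name="text-davinci-003",
        batch_size=2,      # send both prompts in a single request
        return_text=True,  # plain strings instead of OpenAIObject completions
    )
    # Persist the generations as JSON and read them back.
    jdump({"completions": outputs}, "outputs/completions.json")
    print(jload("outputs/completions.json"))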