import copy
import dataclasses
import io
import json
import logging
import math
import os
import sys
import time
from typing import Optional, Sequence, Union

import openai
import tqdm
from openai import openai_object

StrOrOpenAIObject = Union[str, openai_object.OpenAIObject]

openai_org = os.getenv("OPENAI_ORG")
if openai_org is not None:
    openai.organization = openai_org
    logging.warning(f"Switching to organization: {openai_org} for OpenAI API requests.")


@dataclasses.dataclass
class OpenAIDecodingArguments(object):
    max_tokens: int = 1800
    temperature: float = 0.2
    top_p: float = 1.0
    n: int = 1
    stream: bool = False
    stop: Optional[Sequence[str]] = None
    presence_penalty: float = 0.0
    frequency_penalty: float = 0.0
    suffix: Optional[str] = None
    logprobs: Optional[int] = None
    echo: bool = False
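# A minimal usage sketch (hypothetical values): any subset of fields can be
# overridden at construction time, and the rest keep the defaults declared above.
#
#   decoding_args = OpenAIDecodingArguments(temperature=0.7, max_tokens=512, n=2)
#
# `openai_completion` below forwards these fields verbatim via
# `batch_decoding_args.__dict__`, so field names must match the API's parameters.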


def openai_completion(
    prompts: Union[str, Sequence[str], Sequence[dict[str, str]], dict[str, str]],
    decoding_args: OpenAIDecodingArguments,
    model_name="text-davinci-003",
    sleep_time=2,
    batch_size=1,
    max_instances=sys.maxsize,
    max_batches=sys.maxsize,
    return_text=False,
    **decoding_kwargs,
) -> Union[StrOrOpenAIObject, Sequence[StrOrOpenAIObject], Sequence[Sequence[StrOrOpenAIObject]]]:
    """Decode with OpenAI API.

    Args:
        prompts: A string or a list of strings to complete. If it is a chat model, the strings should be formatted
            as explained here: https://github.com/openai/openai-python/blob/main/chatml.md. If it is a chat model,
            it can also be a dictionary (or list thereof) as explained here:
            https://github.com/openai/openai-cookbook/blob/main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb
        decoding_args: Decoding arguments.
        model_name: Model name. Can be either in the format of "org/model" or just "model".
        sleep_time: Time to sleep once the rate limit is hit.
        batch_size: Number of prompts to send in a single request. Only for non-chat models.
        max_instances: Maximum number of prompts to decode.
        max_batches: Maximum number of batches to decode. This argument will be deprecated in the future.
        return_text: If True, return text instead of the full completion object (which contains things like logprobs).
        decoding_kwargs: Additional decoding arguments. Pass in `best_of` and `logit_bias` if you need them.

    Returns:
        A completion or a list of completions. Depending on return_text and decoding_args.n, each completion is
        - a string (if return_text is True),
        - an openai_object.OpenAIObject (if return_text is False), or
        - a list of objects of the above types (if decoding_args.n > 1).
    """
    is_single_prompt = isinstance(prompts, (str, dict))
    if is_single_prompt:
        prompts = [prompts]

    if max_batches < sys.maxsize:
        logging.warning(
            "`max_batches` will be deprecated in the future, please use `max_instances` instead. "
            "Setting `max_instances` to `max_batches * batch_size` for now."
        )
        max_instances = max_batches * batch_size

    prompts = prompts[:max_instances]
    num_prompts = len(prompts)
    prompt_batches = [
        prompts[batch_id * batch_size : (batch_id + 1) * batch_size]
        for batch_id in range(int(math.ceil(num_prompts / batch_size)))
    ]

    completions = []
    for batch_id, prompt_batch in tqdm.tqdm(
        enumerate(prompt_batches),
        desc="prompt_batches",
        total=len(prompt_batches),
    ):
        # Copy the decoding args so per-batch adjustments (e.g. shrinking
        # max_tokens after a context-length error) do not leak across batches.
        batch_decoding_args = copy.deepcopy(decoding_args)

        while True:
            try:
                shared_kwargs = dict(
                    model=model_name,
                    **batch_decoding_args.__dict__,
                    **decoding_kwargs,
                )
                completion_batch = openai.Completion.create(prompt=prompt_batch, **shared_kwargs)
                choices = completion_batch.choices

                for choice in choices:
                    choice["total_tokens"] = completion_batch.usage.total_tokens
                completions.extend(choices)
                break
            except openai.error.OpenAIError as e:
                logging.warning(f"OpenAIError: {e}.")
                if "Please reduce your prompt" in str(e):
                    batch_decoding_args.max_tokens = int(batch_decoding_args.max_tokens * 0.8)
                    logging.warning(f"Reducing target length to {batch_decoding_args.max_tokens}, retrying...")
                else:
                    logging.warning("Hit request rate limit; retrying...")
                    time.sleep(sleep_time)

    if return_text:
        completions = [completion.text for completion in completions]
    if decoding_args.n > 1:
        # Group each prompt's n completions into one sub-list per prompt.
        completions = [completions[i : i + decoding_args.n] for i in range(0, len(completions), decoding_args.n)]
    if is_single_prompt:
        # Unwrap the one-element list so a single prompt yields a single completion.
        (completions,) = completions
    return completions
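# A minimal usage sketch (untested, assuming OPENAI_API_KEY is set and the
# pre-1.0 `openai` SDK this module targets):
#
#   decoding_args = OpenAIDecodingArguments(max_tokens=256, temperature=0.7)
#   texts = openai_completion(
#       prompts=["Write a haiku about the sea.", "Name three prime numbers."],
#       decoding_args=decoding_args,
#       model_name="text-davinci-003",
#       batch_size=2,
#       return_text=True,
#   )
#   # With return_text=True and n=1, `texts` is a list of two strings.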


def _make_w_io_base(f, mode: str):
    """Open `f` for writing if it is a path, creating parent directories as needed."""
    if not isinstance(f, io.IOBase):
        f_dirname = os.path.dirname(f)
        if f_dirname != "":
            os.makedirs(f_dirname, exist_ok=True)
        f = open(f, mode=mode)
    return f


def _make_r_io_base(f, mode: str):
    """Open `f` for reading if it is a path; pass it through if it is already a file object."""
    if not isinstance(f, io.IOBase):
        f = open(f, mode=mode)
    return f


def jdump(obj, f, mode="w", indent=4, default=str):
    """Dump a string, dict, or list to a file in json format.

    Args:
        obj: An object to be written.
        f: A string path to the location on disk, or an open file object.
        mode: Mode for opening the file.
        indent: Indent for storing json dictionaries.
        default: A function to handle non-serializable entries; defaults to `str`.
    """
    f = _make_w_io_base(f, mode)
    if isinstance(obj, (dict, list)):
        json.dump(obj, f, indent=indent, default=default)
    elif isinstance(obj, str):
        f.write(obj)
    else:
        raise ValueError(f"Unexpected type: {type(obj)}")
    f.close()
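# A minimal usage sketch (hypothetical path): parent directories are created
# automatically by `_make_w_io_base`, and non-serializable values fall back to
# `str` via the `default` argument.
#
#   jdump({"instruction": "say hi", "output": "hi"}, "outputs/example.json")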


def jload(f, mode="r"):
    """Load a .json file into a dictionary (or list, depending on the file's contents)."""
    f = _make_r_io_base(f, mode)
    jdict = json.load(f)
    f.close()
    return jdict
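# A minimal usage sketch (hypothetical path), round-tripping the `jdump` example:
#
#   data = jload("outputs/example.json")
#   assert data["output"] == "hi"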