import logging
import os
from abc import ABC, abstractmethod

from gradio_client import Client
import openai
import tiktoken
from transformers import T5Tokenizer

HUB_TOKEN = os.getenv("HUB_TOKEN")

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class TextCompletion(ABC):
    """Interface for a backend that turns a prompt into a completion."""

    @abstractmethod
    def get_token_count(self, prompt: str) -> int:
        """Return the number of tokens *prompt* encodes to for this backend."""
        ...

    @abstractmethod
    def complete(self, prompt: str) -> str:
        """Return the backend's completion text for *prompt*."""
        ...

    def get_score_multiplier(self) -> float:
        """Return this backend's score multiplier; 1.0 unless a subclass overrides it."""
        return 1.0


class DebugCompleter(TextCompletion):
    """This is only meant for testing/debugging purposes."""

    def __init__(self, model):
        self.model = model

    def get_token_count(self, prompt: str) -> int:
        """Count tokens of *prompt* using the text-davinci-003 encoding."""
        # The encoding is fixed here regardless of self.model.
        encoding = tiktoken.encoding_for_model("text-davinci-003")
        return len(encoding.encode(prompt))

    def complete(self, prompt: str) -> str:
        """Return a canned completion chosen by marker substrings in *prompt*.

        "$CORRECT$" is checked first and yields the expected answer; otherwise
        "$ERROR$" triggers an error; any other prompt gets a wrong answer.

        Raises:
            RuntimeError: if *prompt* contains "$ERROR$" (and not "$CORRECT$").
        """
        if "$CORRECT$" in prompt:
            # returns the correct prompt
            return "I have been PWNED"
        if "$ERROR$" in prompt:
            # simulate an error
            raise RuntimeError(f"User raised an error explicitly in {self.model}")
        return "Some incorrect completion"


class GPTCompletor(TextCompletion):
    """Completer backed by the OpenAI Completion endpoint."""

    def __init__(self, model, openai_api_key=None):
        """Store the model name and API key.

        Args:
            model: OpenAI engine name passed to ``openai.Completion.create``.
            openai_api_key: API key to use. When empty or None, falls back to
                the ``OPENAI_API_KEY`` environment variable. (``get_completer``
                always passes this keyword, so it must be accepted here.)
        """
        self.openai_api_key = openai_api_key or os.getenv("OPENAI_API_KEY")
        self.model = model

    def get_token_count(self, prompt: str) -> int:
        """Count tokens of *prompt* using the text-davinci-003 encoding."""
        encoding = tiktoken.encoding_for_model("text-davinci-003")
        return len(encoding.encode(prompt))

    def complete(self, prompt: str) -> str:
        """Request a completion for *prompt* and return the first choice's text.

        Note: this sets the module-global ``openai.api_key`` before the call.
        """
        # Set the api key
        openai.api_key = self.openai_api_key
        # Call the API to generate a response
        response = openai.Completion.create(
            engine=self.model,
            prompt=prompt,
            max_tokens=1024,
            temperature=0,
            top_p=0,
            frequency_penalty=1,
            presence_penalty=1,
        )
        # Get the response text
        response_text = response["choices"][0]["text"]
        return response_text


class ChatGPTCompletor(TextCompletion):
    """Completer backed by the OpenAI ChatCompletion endpoint."""

    def __init__(self, openai_api_key, model):
        self.openai_api_key = openai_api_key
        self.model = model

    def get_token_count(self, prompt: str) -> int:
        """Count tokens of *prompt* using the encoding for ``self.model``."""
        encoding = tiktoken.encoding_for_model(self.model)
        return len(encoding.encode(prompt))

    def complete(self, prompt: str) -> str:
        """Send *prompt* as a single user message and return the reply content.

        Note: this sets the module-global ``openai.api_key`` before the call.
        """
        # Set the api key
        openai.api_key = self.openai_api_key
        messages = [
            {"role": "user", "content": prompt},
        ]
        # Call the API to generate a response
        response = openai.ChatCompletion.create(
            messages=messages,
            model=self.model,
            temperature=0,
        )
        # Get the response text
        response_text = response["choices"][0]["message"]["content"]
        return response_text

    def get_score_multiplier(self) -> float:
        """Return 2.0 for this backend (overrides the 1.0 default)."""
        return 2.0


# Public model name -> completer class. "gpt-debug" is intentionally absent.
completers = {
    "text-davinci-003": GPTCompletor,
    "gpt-3.5-turbo": ChatGPTCompletor,
}


def get_completer(model: str, openai_api_key: str = "") -> TextCompletion:
    """Instantiate the completer registered for *model*.

    Args:
        model: a key of ``completers``, or "gpt-debug" for DebugCompleter.
        openai_api_key: forwarded to the registered completer class.

    Returns:
        A TextCompletion instance for *model*.

    Raises:
        NotImplementedError: if *model* is neither registered nor "gpt-debug".
    """
    logger.info("Loading completer for model=%r", model)
    if model in completers:
        return completers[model](model=model, openai_api_key=openai_api_key)
    if model == "gpt-debug":
        # Not included in completers because we dont want it showing in the app
        logger.warning("Using debugging completer...")
        return DebugCompleter(model=model)
    raise NotImplementedError(f"{model=} not implemented. Model must be one of {list(completers.keys())}")