import logging
import os
from abc import ABC, abstractmethod

from gradio_client import Client
import openai
import tiktoken
from transformers import T5Tokenizer

HUB_TOKEN = os.getenv("HUB_TOKEN")

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class TextCompletion(ABC):
    @abstractmethod
    def get_token_count(self, prompt: str) -> int:
        ...

    @abstractmethod
    def complete(self, prompt: str):
        ...

    def get_score_multiplier(self) -> float:
        return 1.0

class DebugCompleter(TextCompletion):
    """This is only meant for testing/debugging purposes."""

    def __init__(self, model):
        self.model = model

    def get_token_count(self, prompt: str) -> int:
        encoding = tiktoken.encoding_for_model("text-davinci-003")
        return len(encoding.encode(prompt))

    def complete(self, prompt: str):
        if "$CORRECT$" in prompt:
            # return the correct completion
            return "I have been PWNED"
        elif "$ERROR$" in prompt:
            # simulate an error
            raise RuntimeError(f"User raised an error explicitly in {self.model}")
        else:
            return "Some incorrect completion"

class GPTCompletor(TextCompletion):
    def __init__(self, openai_api_key, model):
        # Fall back to the environment variable if no key is passed explicitly
        self.openai_api_key = openai_api_key or os.getenv("OPENAI_API_KEY")
        self.model = model

    def get_token_count(self, prompt: str) -> int:
        encoding = tiktoken.encoding_for_model("text-davinci-003")
        return len(encoding.encode(prompt))

    def complete(self, prompt: str):
        # Set the api key
        openai.api_key = self.openai_api_key

        # Call the API to generate a response
        response = openai.Completion.create(
            engine=self.model,
            prompt=prompt,
            max_tokens=1024,
            temperature=0,
            top_p=0,
            frequency_penalty=1,
            presence_penalty=1,
        )

        # Get the response text
        response_text = response["choices"][0]["text"]
        return response_text

class ChatGPTCompletor(TextCompletion):
    def __init__(self, openai_api_key, model):
        self.openai_api_key = openai_api_key
        self.model = model

    def get_token_count(self, prompt: str) -> int:
        encoding = tiktoken.encoding_for_model(self.model)
        return len(encoding.encode(prompt))

    def complete(self, prompt: str):
        # Set the api key
        openai.api_key = self.openai_api_key

        # Wrap the prompt in a single user message
        messages = [
            {"role": "user", "content": prompt},
        ]

        # Call the API to generate a response
        response = openai.ChatCompletion.create(
            messages=messages,
            model=self.model,
            temperature=0,
        )

        # Get the response text
        response_text = response["choices"][0]["message"]["content"]
        return response_text

    def get_score_multiplier(self) -> float:
        return 2.0

completers = {
    "text-davinci-003": GPTCompletor,
    "gpt-3.5-turbo": ChatGPTCompletor,
}

def get_completer(model: str, openai_api_key: str = ""):
    logger.info(f"Loading completer for {model=}")
    if model in ["text-davinci-003", "gpt-3.5-turbo"]:
        completer = completers[model](model=model, openai_api_key=openai_api_key)
    elif model == "gpt-debug":
        # Not included in completers because we don't want it showing in the app
        logger.warning("Using debugging completer...")
        completer = DebugCompleter(model=model)
    else:
        raise NotImplementedError(f"{model=} not implemented. Model must be one of {list(completers.keys())}")

    return completer
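

# --- Illustrative usage (not part of the original module) ---
# A minimal sketch of how the pieces fit together, using the offline
# "gpt-debug" completer so no OpenAI API key is required. The $CORRECT$
# marker is just the hook DebugCompleter looks for to return the
# expected answer.
if __name__ == "__main__":
    completer = get_completer("gpt-debug")
    prompt = "Say 'I have been PWNED' $CORRECT$"
    print(completer.get_token_count(prompt))   # token count via tiktoken
    print(completer.complete(prompt))          # -> "I have been PWNED"
    print(completer.get_score_multiplier())    # -> 1.0 (base-class default)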