|
import os |
|
from io import BytesIO |
|
import pandas as pd |
|
from dotenv import load_dotenv |
|
load_dotenv() |
|
import openai |
|
import streamlit as st |
|
|
|
|
|
openai.api_key = st.secrets["OPENAI_API_KEY"] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class OpenAIChatCompletions:
    """Thin wrapper around the OpenAI chat-completions API.

    Supports an optional system message and optional few-shot prompting,
    with examples drawn from a local JSONL file of prompt/completion pairs.
    """

    def __init__(self, model="gpt-4", system_message=None):
        """
        Parameters
        ----------
        model : str
            Chat model name passed to the OpenAI API.
        system_message : str or None
            Optional system prompt prepended to every conversation.
        """
        self.model = model
        self.system_message = system_message

    def openai_chat_completion(self, prompt, n_shot=None):
        """Send ``prompt`` to the chat API and return the raw API response.

        Parameters
        ----------
        prompt : str
            The user message to complete.
        n_shot : int or None
            When given, prepend that many (user, assistant) example pairs
            sampled from the few-shot JSONL file before the prompt.

        Returns
        -------
        The unmodified response object from ``openai.ChatCompletion.create``.
        """
        messages = []
        if self.system_message:
            messages.append({"role": "system", "content": self.system_message})
        if n_shot is not None:
            messages = self._add_samples(messages, n_samples=n_shot)
        messages.append({"role": "user", "content": prompt})

        # NOTE(review): this is the legacy (<1.0) openai SDK interface —
        # confirm the pinned SDK version before upgrading.
        return openai.ChatCompletion.create(model=self.model, messages=messages)

    def predict_jsonl(
        self,
        path_or_buf='../data/cookies_train.jsonl',
        n_samples=None,
        n_shot=None
    ):
        """Run the model over every prompt in a JSONL file.

        Parameters
        ----------
        path_or_buf : str or file-like
            Passed straight to ``pd.read_json(..., lines=True)``.
        n_samples : int or None
            If given, evaluate only a reproducible random subset
            (``random_state=42``) of that size.
        n_shot : int or None
            Forwarded to ``openai_chat_completion`` for few-shot prompting.

        Returns
        -------
        tuple
            ``(prompts, completions, predictions)`` — the input prompts,
            their reference completions, and the raw API responses.
        """
        data = pd.read_json(path_or_buf=path_or_buf, lines=True)
        if n_samples is not None:
            data = data.sample(n_samples, random_state=42)

        prompts = data['prompt'].tolist()
        completions = data['completion'].tolist()
        predictions = [self.openai_chat_completion(p, n_shot=n_shot) for p in prompts]

        return prompts, completions, predictions

    @staticmethod
    def _add_samples(messages, n_samples=None):
        """Append ``n_samples`` few-shot (user, assistant) pairs to ``messages``.

        Mutates and returns ``messages``. A ``None`` count is a no-op so
        callers can pass ``n_shot`` through unconditionally.
        """
        if n_samples is None:
            return messages

        samples = OpenAIChatCompletions._sample_jsonl(n_samples=n_samples)
        for row in samples.itertuples():
            messages.append({"role": "user", "content": row.prompt})
            messages.append({"role": "assistant", "content": row.completion})

        return messages

    @staticmethod
    def _sample_jsonl(
        path_or_buf='data/cookies_train.jsonl',
        n_samples=5
    ):
        """Load the few-shot examples file and return a reproducible sample.

        Parameters
        ----------
        path_or_buf : str
            Path to the JSONL file, relative to the resolved project root.
        n_samples : int
            Number of rows to sample (``random_state=42`` for determinism).

        Raises
        ------
        FileNotFoundError
            If the file does not exist (the path is also surfaced in the
            Streamlit UI first).
        """
        # Resolve relative to the project root: when running from inside a
        # "Kaleidoscope Data" directory, the data lives one level up.
        cwd = os.getcwd()
        if "Kaleidoscope Data" in cwd:
            file_path = os.path.join(os.path.dirname(cwd), path_or_buf)
        else:
            file_path = os.path.join(cwd, path_or_buf)

        try:
            # pd.read_json opens and decodes the file itself — no need for a
            # manual read + BytesIO round-trip, and dropping engine="pyarrow"
            # removes a hard dependency on pyarrow.
            jsonObj = pd.read_json(file_path, lines=True)
        except FileNotFoundError:
            # BUG FIX: the original only logged here and then fell through to
            # the return below with `jsonObj` unbound, masking the real error
            # with an UnboundLocalError. Surface the path, then re-raise.
            st.write(f"File not found: {file_path}")
            raise

        return jsonObj.sample(n_samples, random_state=42)
|
|