dairy-reports-analysis / assistant.py
umairahmad89
Add list files function
ebe102a
raw
history blame
7.02 kB
import os
from typing import List, Dict
import time
from openai import OpenAI
from assistant_file_handler import FileHandler
from openai.types.beta.thread import Thread
from openai.types.beta.threads.message import Message
from openai.types.beta.vector_stores.vector_store_file import VectorStoreFile
import structlog
from openai.pagination import SyncCursorPage
class OAIAssistant:
    """Wrapper around an existing OpenAI Assistant and one vector store.

    The instance keeps an ``OpenAI`` client, the retrieved assistant object,
    a ``FileHandler`` used for uploading/removing files, and a structlog
    logger. All chat traffic goes through OpenAI threads and runs.
    """

    # Run statuses from which a run can never reach "completed"; polling
    # must stop on these, otherwise the wait loop would spin forever.
    _FAILED_RUN_STATUSES = frozenset({"failed", "cancelled", "expired", "incomplete"})

    def __init__(self, assistant_id, vectorstore_id) -> None:
        """Bind to the assistant and vector store identified by the given ids.

        Retrieves the assistant from the API immediately, so construction
        performs a network call.
        """
        self.file_handler = FileHandler()
        self.assistant_id = assistant_id
        self.vectorstore_id = vectorstore_id
        self.client = OpenAI()
        self.openai_assistant = self.client.beta.assistants.retrieve(
            assistant_id=self.assistant_id
        )
        self.log = structlog.get_logger()

    def create(self):
        """Placeholder; currently does nothing."""
        pass

    def add_file(self, file_path: str):
        """Upload ``file_path`` via the file handler and attach it to the vector store."""
        file_id = self.file_handler.add(file_path=file_path).id
        self.client.beta.vector_stores.files.create(
            file_id=file_id, vector_store_id=self.vectorstore_id
        )

    def remove_file(self, file_id: str):
        """Detach ``file_id`` from the vector store, then remove it via the file handler."""
        self.client.beta.vector_stores.files.delete(
            file_id=file_id, vector_store_id=self.vectorstore_id
        )
        self.log.info(
            f"OAIAssistant: Deleted file with id {file_id} from vector database"
        )
        self.file_handler.remove(file_id=file_id)
        self.log.info(f"OAIAssistant: Deleted file with id {file_id} from file storage")

    def chat(self, query: str, thread_id: str):
        """Send ``query`` to the assistant on ``thread_id`` and return its reply text.

        Args:
            query: The user message to post.
            thread_id: Existing thread id. When falsy, a new thread is created;
                its id is only logged, not returned.

        Returns:
            The assistant's reply text, or a fixed error string if anything
            raised along the way (errors are logged, never propagated).
        """
        try:
            if not thread_id:
                # create_thread() returns the Thread object; read its id once.
                thread_id = self.create_thread().id

            self.client.beta.threads.messages.create(
                thread_id=thread_id,
                role="user",
                content=query,
            )
            self.log.info(
                "OAIAssistant: Message added to thread",
                thread_id=thread_id,
                query=query,
            )

            new_message, message_file_ids = self.__run_assistant(thread_id=thread_id)

            # Save every image the run produced to ./tmp and upload it again
            # so it can be attached to the assistant message below.
            file_paths = []
            for msg_file_id in message_file_ids:
                png_file_path = f"./tmp/{msg_file_id}.png"
                self.__convert_file_to_png(
                    file_id=msg_file_id, write_path=png_file_path
                )
                file_paths.append(png_file_path)
            file_ids = self.__add_files(file_paths=file_paths)

            attachments = None
            if file_ids:
                attachments = [
                    {"file_id": file_id, "tools": [{"type": "file_search"}]}
                    for file_id in file_ids.values()
                ]
            self.client.beta.threads.messages.create(
                thread_id=thread_id,
                role="assistant",
                content=new_message,
                attachments=attachments,
            )
            self.log.info(
                "OAIAssistant: Assistant response generated", response=new_message
            )
            return new_message
        except Exception as e:
            # Top-level boundary: callers always get a string back.
            self.log.error("OAIAssistant: Error generating response", error=str(e))
            return "OAIAssistant: An error occurred while generating the response."

    def create_thread(self) -> "Thread":
        """Create and return a new thread wired to this instance's vector store."""
        return self.client.beta.threads.create(
            tool_resources={"file_search": {"vector_store_ids": [self.vectorstore_id]}}
        )

    def delete_thread(self, thread_id: str):
        """Delete the thread identified by ``thread_id``."""
        self.client.beta.threads.delete(thread_id=thread_id)
        self.log.info(f"OAIAssistant: Deleted thread with id: {thread_id}")

    def __convert_file_to_png(self, file_id, write_path):
        """Download the content of ``file_id`` and write it to ``write_path``.

        The bytes are written unchanged; no image re-encoding is performed.
        Errors are logged and re-raised.
        """
        try:
            data_bytes = self.client.files.content(file_id).read()
            # The target directory (e.g. ./tmp) may not exist on a fresh checkout.
            os.makedirs(os.path.dirname(write_path) or ".", exist_ok=True)
            with open(write_path, "wb") as file:
                file.write(data_bytes)
            self.log.info("OAIAssistant: File converted to PNG", file_path=write_path)
        except Exception as e:
            self.log.error("OAIAssistant: Error converting file to PNG", error=str(e))
            raise

    def __add_files(self, file_paths: List[str]) -> Dict[str, str]:
        """Upload each path via the file handler.

        Returns:
            Mapping of file base name to the id of the uploaded file.
            Errors are logged and re-raised.
        """
        try:
            files = {}
            for path in file_paths:
                uploaded = self.file_handler.add(path)
                files[os.path.basename(path)] = uploaded.id
            self.log.info("OAIAssistant: Files added", files=files)
            return files
        except Exception as e:
            self.log.error("OAIAssistant: Error adding files", error=str(e))
            raise

    def __run_assistant(self, thread_id: str):
        """Start a run on ``thread_id`` and poll once per second until it ends.

        Returns:
            ``(text, image_file_ids)`` extracted from the run's messages, or
            ``("OAIAssistant: Error in generating response", [])`` when the
            run ends in a failed/cancelled/expired/incomplete state.
            Exceptions are logged and re-raised.
        """
        try:
            run = self.client.beta.threads.runs.create(
                thread_id=thread_id,
                assistant_id=self.assistant_id,
            )
            self.log.info("OAIAssistant: Assistant run started", run_id=run.id)

            while run.status != "completed":
                # Check before sleeping so an already-dead run returns at once.
                if run.status in self._FAILED_RUN_STATUSES:
                    self.log.error(
                        "OAIAssistant: Assistant run failed",
                        run_id=run.id,
                        status=run.status,
                    )
                    self.log.info(run)
                    return "OAIAssistant: Error in generating response", []
                time.sleep(1)
                run = self.client.beta.threads.runs.retrieve(
                    thread_id=thread_id, run_id=run.id
                )

            messages = self.client.beta.threads.messages.list(
                thread_id=thread_id, run_id=run.id
            )
            return self.__extract_messages(messages)
        except Exception as e:
            self.log.error("OAIAssistant: Error running assistant", error=str(e))
            raise

    def __extract_messages(self, messages: "SyncCursorPage[Message]"):
        """Flatten a page of messages into reply text plus image file ids.

        Every content part of every message in ``messages.data`` is visited:
        ``text`` parts contribute their value, ``image_file`` parts contribute
        an ``"Image File:\\n<file_id>"`` marker and their file id; other part
        types are ignored. Parts of one message are joined with a newline and
        each message is followed by a blank line.

        Returns:
            ``(text, image_file_ids)``. Errors are logged and re-raised.
        """
        try:
            new_message = ""
            file_ids = []
            for message in messages.data:
                parts = []
                # A single message may mix parts (e.g. an image followed by
                # text), so walk all of them rather than only the first.
                for part in message.content:
                    if part.type == "text":
                        parts.append(part.text.value)
                    elif part.type == "image_file":
                        image_id = part.image_file.file_id
                        parts.append("Image File:\n" + image_id)
                        file_ids.append(image_id)
                new_message += "\n".join(parts) + "\n\n"
            self.log.info("OAIAssistant: Messages extracted", message=new_message)
            return new_message, file_ids
        except Exception as e:
            self.log.error("OAIAssistant: Error extracting messages", error=str(e))
            raise

    def get_files_list(self):
        """Return the ids of the files listed for this instance's vector store."""
        files = self.client.beta.vector_stores.files.list(
            vector_store_id=self.vectorstore_id
        )
        return [file.id for file in files]