import os
import time
from typing import Dict, List, Optional, Tuple

import structlog
from openai import OpenAI
from openai.pagination import SyncCursorPage
from openai.types.beta.thread import Thread
from openai.types.beta.threads.message import Message
from openai.types.beta.vector_stores.vector_store_file import VectorStoreFile

from assistant_file_handler import FileHandler


class OAIAssistant:
    """Wrapper around an existing OpenAI Assistant and its vector store.

    Offers helpers to add/remove files in the assistant's vector store,
    create/delete threads, and run a single chat turn against the assistant.
    """

    # Seconds to sleep between polls of a run's status.
    _POLL_INTERVAL_SECONDS = 1
    # Run statuses that mean the run is still being processed. Any other
    # status (completed, failed, cancelled, expired, incomplete,
    # requires_action, ...) ends polling.
    _PENDING_RUN_STATUSES = frozenset({"queued", "in_progress", "cancelling"})
    # Local directory where image files produced by the assistant are saved.
    _TMP_DIR = "./tmp"

    def __init__(self, assistant_id: str, vectorstore_id: str) -> None:
        """Bind to an existing assistant and vector store.

        Args:
            assistant_id: Id of the OpenAI assistant to run.
            vectorstore_id: Id of the vector store used for file search.
        """
        self.file_handler = FileHandler()
        self.assistant_id = assistant_id
        self.vectorstore_id = vectorstore_id
        self.client = OpenAI()
        # Fetch the assistant object up front (network call in the constructor).
        self.openai_assistant = self.client.beta.assistants.retrieve(
            assistant_id=self.assistant_id
        )
        self.log = structlog.get_logger()

    def create(self):
        """Placeholder; currently does nothing."""
        pass

    def add_file(self, file_path: str):
        """Upload a local file and attach it to the vector store.

        Args:
            file_path: Path of the local file to upload via ``FileHandler``.
        """
        file_id = self.file_handler.add(file_path=file_path).id
        self.client.beta.vector_stores.files.create(
            file_id=file_id, vector_store_id=self.vectorstore_id
        )

    def remove_file(self, file_id: str):
        """Detach a file from the vector store, then delete the stored file.

        Args:
            file_id: Id of the file to remove.
        """
        self.client.beta.vector_stores.files.delete(
            file_id=file_id, vector_store_id=self.vectorstore_id
        )
        self.log.info(
            f"OAIAssistant: Deleted file with id {file_id} from vector database"
        )
        self.file_handler.remove(file_id=file_id)
        self.log.info(f"OAIAssistant: Deleted file with id {file_id} from file storage")

    def chat(self, query: str, thread_id: Optional[str] = None) -> str:
        """Send ``query`` to the assistant and return its reply text.

        Args:
            query: The user's message.
            thread_id: Existing thread to continue. If falsy, a new thread is
                created; its id is only logged, not returned.

        Returns:
            The assistant's reply text. If any step raises, the error is
            logged and a fixed error string is returned instead of raising.
        """
        try:
            if not thread_id:
                # create_thread() returns a Thread; keep only its id.
                thread_id = self.create_thread().id
                self.log.info("OAIAssistant: Thread created", thread_id=thread_id)

            self.client.beta.threads.messages.create(
                thread_id=thread_id,
                role="user",
                content=query,
            )
            self.log.info(
                "OAIAssistant: Message added to thread",
                thread_id=thread_id,
                query=query,
            )

            new_message, message_file_ids = self.__run_assistant(thread_id=thread_id)

            # Download every image the assistant produced, then upload the
            # local copies through the file handler so they can be attached.
            file_paths = []
            for msg_file_id in message_file_ids:
                png_file_path = f"{self._TMP_DIR}/{msg_file_id}.png"
                self.__convert_file_to_png(
                    file_id=msg_file_id, write_path=png_file_path
                )
                file_paths.append(png_file_path)

            file_ids = self.__add_files(file_paths=file_paths)

            # NOTE(review): the run itself already appends an assistant message
            # to the thread, so this posts the reply a second time (now with
            # attachments). Also confirm that file_search accepts .png files.
            self.client.beta.threads.messages.create(
                thread_id=thread_id,
                role="assistant",
                content=new_message,
                attachments=[
                    {"file_id": file_id, "tools": [{"type": "file_search"}]}
                    for file_id in file_ids.values()
                ]
                if file_ids
                else None,
            )
            self.log.info(
                "OAIAssistant: Assistant response generated", response=new_message
            )
            return new_message
        except Exception as e:
            self.log.error("OAIAssistant: Error generating response", error=str(e))
            return "OAIAssistant: An error occurred while generating the response."

    def create_thread(self) -> Thread:
        """Create a new thread wired to this assistant's vector store."""
        thread: Thread = self.client.beta.threads.create(
            tool_resources={"file_search": {"vector_store_ids": [self.vectorstore_id]}}
        )
        return thread

    def delete_thread(self, thread_id: str):
        """Delete the thread with the given id."""
        self.client.beta.threads.delete(thread_id=thread_id)
        self.log.info(f"OAIAssistant: Deleted thread with id: {thread_id}")

    def __convert_file_to_png(self, file_id, write_path):
        """Download an OpenAI file's raw bytes and write them to ``write_path``.

        No image conversion is performed; the bytes are written unchanged.
        The parent directory of ``write_path`` is created if missing.

        Raises:
            Exception: Any download or write error is logged and re-raised.
        """
        try:
            data = self.client.files.content(file_id)
            data_bytes = data.read()
            directory = os.path.dirname(write_path)
            if directory:
                # Without this, open() fails when ./tmp does not exist yet.
                os.makedirs(directory, exist_ok=True)
            with open(write_path, "wb") as file:
                file.write(data_bytes)
            self.log.info("OAIAssistant: File converted to PNG", file_path=write_path)
        except Exception as e:
            self.log.error("OAIAssistant: Error converting file to PNG", error=str(e))
            raise

    def __add_files(self, file_paths: List[str]) -> Dict[str, str]:
        """Upload local files via the file handler.

        Returns:
            Mapping of each file's basename to the id returned by the handler.

        Raises:
            Exception: Any upload error is logged and re-raised.
        """
        try:
            files: Dict[str, str] = {}
            for path in file_paths:
                files[os.path.basename(path)] = self.file_handler.add(path).id
            self.log.info("OAIAssistant: Files added", files=files)
            return files
        except Exception as e:
            self.log.error("OAIAssistant: Error adding files", error=str(e))
            raise

    def __run_assistant(self, thread_id: str) -> Tuple[str, List[str]]:
        """Start a run on ``thread_id``, wait for it, and extract its output.

        Returns:
            ``(text, image_file_ids)`` from the run's messages. If the run
            ends in any status other than ``"completed"``, returns a fixed
            error string and an empty list.

        Raises:
            Exception: API errors are logged and re-raised.
        """
        try:
            run = self.client.beta.threads.runs.create(
                thread_id=thread_id,
                assistant_id=self.assistant_id,
            )
            self.log.info("OAIAssistant: Assistant run started", run_id=run.id)

            # Poll only while the run is still pending. Any other status ends
            # the loop, so cancelled/expired/incomplete/requires_action runs
            # do not poll forever.
            while run.status in self._PENDING_RUN_STATUSES:
                time.sleep(self._POLL_INTERVAL_SECONDS)
                run = self.client.beta.threads.runs.retrieve(
                    thread_id=thread_id, run_id=run.id
                )

            if run.status != "completed":
                self.log.error(
                    "OAIAssistant: Assistant run failed",
                    run_id=run.id,
                    status=run.status,
                )
                self.log.info(run)
                return "OAIAssistant: Error in generating response", []

            # order="asc" so multi-message replies concatenate in the order
            # they were produced (the API's default is newest-first).
            messages: SyncCursorPage[Message] = self.client.beta.threads.messages.list(
                thread_id=thread_id, run_id=run.id, order="asc"
            )
            return self.__extract_messages(messages)
        except Exception as e:
            self.log.error("OAIAssistant: Error running assistant", error=str(e))
            raise

    def __extract_messages(
        self, messages: SyncCursorPage[Message]
    ) -> Tuple[str, List[str]]:
        """Flatten a page of messages into text plus image file ids.

        Every content part of every message is visited. Text parts are
        concatenated as-is. Each image part adds an ``"Image File:"`` marker
        with its file id and records that id. Other part types are skipped.

        Returns:
            ``(text, image_file_ids)``.

        Raises:
            Exception: Any error is logged and re-raised.
        """
        try:
            parts: List[str] = []
            file_ids: List[str] = []
            for message in messages.data:
                for content in message.content:
                    if content.type == "text":
                        parts.append(content.text.value)
                    elif content.type == "image_file":
                        image_file_id = content.image_file.file_id
                        parts.append(f"Image File:\n{image_file_id}\n\n")
                        file_ids.append(image_file_id)
            new_message = "".join(parts)
            self.log.info("OAIAssistant: Messages extracted", message=new_message)
            return new_message, file_ids
        except Exception as e:
            self.log.error("OAIAssistant: Error extracting messages", error=str(e))
            raise

    def get_files_list(self):
        """Return the ids of the files attached to the vector store."""
        files: SyncCursorPage[VectorStoreFile] = (
            self.client.beta.vector_stores.files.list(
                vector_store_id=self.vectorstore_id
            )
        )
        return [file.id for file in files]