"""Controller for a single assistant/thread pair on the OpenAI Assistants (beta) API."""

import math
import time

import pandas as pd

from ._logging import Logger


def parse_wait_time(err):
    """Return the number of seconds a rate-limited request asks us to wait.

    The duration is taken from the sentence ``"Please try again in <duration>"``
    inside ``err.message`` and converted with :func:`pandas.to_timedelta`.

    Parameters
    ----------
    err
        Error object exposing ``code`` and ``message`` attributes (for example
        a run's ``last_error``). May be ``None``.

    Returns
    -------
    float
        The requested wait, in seconds.

    Raises
    ------
    TypeError
        If ``err`` is not a ``'rate_limit_exceeded'`` error (including when
        ``err`` is ``None``), or if no wait hint is found in its message. The
        exception argument is the error code (``None`` when absent).
    """
    # getattr: a run may finish without an error object (last_error is None);
    # treat that as "not retryable" rather than crashing with AttributeError.
    code = getattr(err, 'code', None)
    if code == 'rate_limit_exceeded':
        for sentence in err.message.split('. '):
            if sentence.startswith('Please try again in'):
                *_, wait = sentence.split()
                # If the hint is the final sentence, the split above leaves its
                # terminating period attached (e.g. "20ms."); strip it so the
                # duration string is clean for to_timedelta.
                return pd.to_timedelta(wait.rstrip('.')).total_seconds()
    raise TypeError(code)


class ChatController:
    """Drive one assistant and one conversation thread via ``client.beta``.

    The assistant is created with the ``file_search`` tool at construction
    time. The vector store of ``database`` is attached lazily, on the first
    call to the instance.
    """

    # Defaults merged into the ``assistants.create`` kwargs unless the caller
    # overrides them.
    _assistant_kwargs = {
        'model': 'gpt-4o',
        'temperature': 1e-4,
    }
    # Extra kwargs forwarded to every ``runs.create_and_poll`` call.
    _threads_kwargs = {
        'max_completion_tokens': 2 ** 12,
    }

    def __init__(self, client, database, instructions, retries=10, **kwargs):
        """Create the assistant and a fresh thread.

        Parameters
        ----------
        client
            API client exposing ``beta.assistants`` and ``beta.threads``.
        database
            Object exposing ``vector_store_id`` (read on the first call).
        instructions
            Object with a ``read_text()`` method, presumably a
            ``pathlib.Path``; its text becomes the assistant's instructions.
        retries
            Maximum number of run attempts per message in :meth:`send`.
        **kwargs
            Forwarded to ``assistants.create``. Missing keys are filled from
            ``_assistant_kwargs``.
        """
        self.client = client
        self.database = database
        self.retries = retries

        for key, value in self._assistant_kwargs.items():
            kwargs.setdefault(key, value)

        self.assistant = self.client.beta.assistants.create(
            instructions=instructions.read_text(),
            tools=[{
                'type': 'file_search',
            }],
            **kwargs,
        )
        self.thread = self.client.beta.threads.create()
        # Whether the database's vector store has been attached yet.
        self.attached = False

    def __call__(self, prompt):
        """Attach the vector store if needed, then :meth:`send` ``prompt``."""
        if not self.attached:
            self.client.beta.assistants.update(
                assistant_id=self.assistant.id,
                tool_resources={
                    'file_search': {
                        'vector_store_ids': [
                            self.database.vector_store_id,
                        ],
                    },
                },
            )
            self.attached = True

        return self.send(prompt)

    def cleanup(self):
        """Delete the thread and the assistant on the server side.

        NOTE(review): ``self.thread``/``self.assistant`` still hold the deleted
        ids afterwards, so further calls will likely fail server-side.
        """
        self.client.beta.threads.delete(self.thread.id)
        self.client.beta.assistants.delete(self.assistant.id)
        self.attached = False

    def send(self, content):
        """Post ``content`` as a user message and run the assistant on it.

        Runs that fail with a rate-limit error are retried up to
        ``self.retries`` times. Between attempts the method sleeps for the
        delay suggested in the error message, rounded up to whole seconds.

        Returns
        -------
        The result of ``threads.messages.list`` filtered to the completed run.

        Raises
        ------
        TypeError
            From :func:`parse_wait_time` when a run does not complete and its
            error is not a rate limit. Such runs are not retried.
        TimeoutError
            When every attempt was rate-limited.
        """
        self.client.beta.threads.messages.create(
            self.thread.id,
            role='user',
            content=content,
        )

        for attempt in range(1, self.retries + 1):
            run = self.client.beta.threads.runs.create_and_poll(
                thread_id=self.thread.id,
                assistant_id=self.assistant.id,
                **self._threads_kwargs,
            )
            if run.status == 'completed':
                return self.client.beta.threads.messages.list(
                    thread_id=self.thread.id,
                    run_id=run.id,
                )

            Logger.error('%s (%d): %s', run.status, attempt, run.last_error)
            # Parse even on the last attempt so non-rate-limit errors still
            # surface as TypeError instead of a generic TimeoutError.
            rest = math.ceil(parse_wait_time(run.last_error))
            if attempt == self.retries:
                # No attempts left: don't sleep only to give up afterwards.
                break
            Logger.warning('Sleeping %ds', rest)
            time.sleep(rest)

        raise TimeoutError('Message retries exceeded')