Spaces:
Runtime error
Runtime error
import math | |
import time | |
import pandas as pd | |
from ._logging import Logger | |
def parse_wait_time(err):
    """Return the retry delay, in seconds, advertised by a rate-limit error.

    The error's message is split into sentences on ``'. '``; the last word
    of the sentence starting with ``'Please try again in'`` is parsed as a
    duration with ``pandas.to_timedelta``.

    Args:
        err: Object exposing ``code`` and ``message`` attributes
            (presumably a run's ``last_error`` -- confirm against callers).
            ``None`` or an object without those attributes is tolerated.

    Returns:
        The advertised delay as a float number of seconds.

    Raises:
        TypeError: If ``err`` is not a ``'rate_limit_exceeded'`` error or
            its message carries no retry hint. The exception argument is
            the error code (``None`` when unavailable).
        ValueError: Propagated from ``pandas.to_timedelta`` when the hint
            is present but is not a parsable duration.
    """
    code = getattr(err, 'code', None)
    if code == 'rate_limit_exceeded':
        message = getattr(err, 'message', None) or ''
        for sentence in message.split('. '):
            if sentence.startswith('Please try again in'):
                # When the hint is the final sentence the split leaves its
                # full stop glued to the duration ("2s."); drop it so the
                # duration parser sees a clean token.
                wait = sentence.split()[-1].rstrip('.')
                return pd.to_timedelta(wait).total_seconds()
    raise TypeError(code)
class ChatController:
    """Run one assistant conversation with file-search over a vector store.

    On construction an assistant (with the ``file_search`` tool) and a
    thread are created through ``client.beta``. The vector store taken from
    ``database.vector_store_id`` is attached to the assistant lazily, on the
    first call. Calling the instance posts a user message and runs the
    assistant on the thread, retrying after rate-limit failures.
    """

    # Defaults for assistant creation; any of them can be overridden through
    # the constructor's **kwargs.
    _assistant_kwargs = {
        'model': 'gpt-4o',
        'temperature': 1e-4,
    }
    # Extra keyword arguments passed to every run started by send().
    _threads_kwargs = {
        'max_completion_tokens': 2 ** 12,
    }

    def __init__(self, client, database, instructions, retries=10, **kwargs):
        """Create the assistant and the conversation thread.

        Args:
            client: API client exposing ``beta.assistants`` and
                ``beta.threads`` (presumably an OpenAI client -- confirm).
            database: Object exposing ``vector_store_id``; it is only read
                on the first call, not here.
            instructions: Object with a ``read_text()`` method returning the
                assistant instructions (e.g. a ``pathlib.Path``).
            retries: Maximum number of runs attempted per message.
            **kwargs: Forwarded to ``client.beta.assistants.create``; they
                take precedence over ``_assistant_kwargs``.
        """
        self.client = client
        self.database = database
        self.retries = retries

        # Caller-supplied options win over the class-level defaults.
        options = {**self._assistant_kwargs, **kwargs}
        self.assistant = self.client.beta.assistants.create(
            instructions=instructions.read_text(),
            tools=[{
                'type': 'file_search',
            }],
            **options,
        )
        self.thread = self.client.beta.threads.create()
        self.attached = False

    def __call__(self, prompt):
        """Send ``prompt`` and return the result of :meth:`send`.

        On the first call (or the first one after :meth:`cleanup`) the
        assistant is updated so that its ``file_search`` tool points at
        ``database.vector_store_id``.
        """
        if not self.attached:
            self.client.beta.assistants.update(
                assistant_id=self.assistant.id,
                tool_resources={
                    'file_search': {
                        'vector_store_ids': [
                            self.database.vector_store_id,
                        ],
                    },
                },
            )
            self.attached = True

        return self.send(prompt)

    def cleanup(self):
        """Delete the thread and the assistant created by this controller.

        The assistant deletion is attempted even when deleting the thread
        raises, so a failure on the first call does not leave the assistant
        behind. Any exception still propagates to the caller.
        """
        try:
            self.client.beta.threads.delete(self.thread.id)
        finally:
            self.client.beta.assistants.delete(self.assistant.id)
            self.attached = False

    def send(self, content):
        """Post ``content`` as a user message and run the assistant on it.

        The run is attempted up to ``self.retries`` times. After a failed
        run the wait advertised by the run's ``last_error`` is slept
        (rounded up to whole seconds) before the next attempt.

        Args:
            content: Message content passed to ``threads.messages.create``.

        Returns:
            Whatever ``client.beta.threads.messages.list`` returns for the
            completed run.

        Raises:
            TimeoutError: If no run completed within ``self.retries``
                attempts.
            Exception: Whatever ``parse_wait_time`` raises when the run's
                ``last_error`` is not a usable rate-limit error.
        """
        self.client.beta.threads.messages.create(
            self.thread.id,
            role='user',
            content=content,
        )

        for attempt in range(1, self.retries + 1):
            run = self.client.beta.threads.runs.create_and_poll(
                thread_id=self.thread.id,
                assistant_id=self.assistant.id,
                **self._threads_kwargs,
            )
            if run.status == 'completed':
                return self.client.beta.threads.messages.list(
                    thread_id=self.thread.id,
                    run_id=run.id,
                )

            Logger.error('%s (%d): %s', run.status, attempt, run.last_error)
            # Parse on every failure, including the last one, so that errors
            # which are not rate limits keep surfacing from parse_wait_time
            # instead of being masked by the TimeoutError below.
            rest = math.ceil(parse_wait_time(run.last_error))
            if attempt < self.retries:
                # No point in waiting when there is no further attempt.
                Logger.warning('Sleeping %ds', rest)
                time.sleep(rest)

        raise TimeoutError('Message retries exceeded')