File size: 2,719 Bytes
ef3d4ad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2a70816
ef3d4ad
8336dcb
2a70816
 
ef3d4ad
 
 
8336dcb
ef3d4ad
416107c
8336dcb
ef3d4ad
2a70816
8336dcb
ef3d4ad
 
 
 
 
 
8336dcb
ef3d4ad
 
 
 
416107c
ef3d4ad
 
 
 
 
 
416107c
ef3d4ad
 
 
 
 
 
 
 
e946c57
 
 
 
 
ef3d4ad
 
 
 
 
 
 
8336dcb
ef3d4ad
 
 
2a70816
ef3d4ad
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import math
import time

import pandas as pd

from ._logging import Logger

def parse_wait_time(err):
    """Return the wait, in seconds, requested by a rate-limit error.

    The error message is split into sentences on ``'. '``; the last
    whitespace-separated token of the sentence starting with
    ``'Please try again in'`` is parsed as a duration by
    ``pandas.to_timedelta``.

    Args:
        err: Object exposing ``code`` and ``message`` attributes (this
            file passes a run's ``last_error``). ``None`` is accepted and
            treated like any other non-rate-limit error.

    Returns:
        float: The suggested wait in seconds.

    Raises:
        TypeError: If ``err`` is not a ``rate_limit_exceeded`` error or its
            message carries no retry hint. The exception argument is the
            error code (``None`` when ``err`` has none).
    """
    # getattr keeps a missing error (None) on the TypeError path below
    # instead of failing with an unrelated AttributeError.
    code = getattr(err, 'code', None)
    if code == 'rate_limit_exceeded':
        message = getattr(err, 'message', None) or ''
        for sentence in message.split('. '):
            if sentence.startswith('Please try again in'):
                (*_, wait) = sentence.split()
                # When the hint is the final sentence, splitting on '. '
                # leaves its full stop attached to the duration token.
                return (pd
                        .to_timedelta(wait.rstrip('.'))
                        .total_seconds())

    raise TypeError(code)

class ChatController:
    """Drive a single assistant/thread pair through the client's beta API.

    Construction creates an assistant (with the ``file_search`` tool) and a
    thread. Calling the instance attaches the database's vector store to
    the assistant on first use and then sends the prompt; ``cleanup``
    deletes the thread and the assistant.
    """

    # Defaults for ``assistants.create``; keyword arguments given to
    # ``__init__`` take precedence over these.
    _assistant_kwargs = {
        'model': 'gpt-4o',
        'temperature': 1e-4,
    }
    # Extra keyword arguments passed to every ``runs.create_and_poll`` call.
    _threads_kwargs = {
        'max_completion_tokens': 2 ** 12,
    }

    def __init__(self, client, database, instructions, retries=10, **kwargs):
        """Create the assistant and the thread used by this controller.

        Args:
            client: API client exposing ``beta.assistants`` and
                ``beta.threads``.
            database: Object exposing ``vector_store_id``; only read when
                the instance is first called.
            instructions: Object with a ``read_text()`` method whose result
                becomes the assistant's instructions.
            retries: Maximum number of runs attempted per ``send`` call.
            **kwargs: Forwarded to ``assistants.create`` on top of the
                class defaults in ``_assistant_kwargs``.
        """
        self.client = client
        self.database = database
        self.retries = retries

        # Caller-supplied values override the class-level defaults.
        assistant_kwargs = {**self._assistant_kwargs, **kwargs}

        self.assistant = self.client.beta.assistants.create(
            instructions=instructions.read_text(),
            tools=[{
                'type': 'file_search',
            }],
            **assistant_kwargs,
        )
        self.thread = self.client.beta.threads.create()
        # The vector store is attached lazily, on the first call.
        self.attached = False

    def __call__(self, prompt):
        """Send ``prompt``, attaching the vector store first if needed.

        Returns:
            The result of ``send(prompt)``.
        """
        if not self.attached:
            self.client.beta.assistants.update(
                assistant_id=self.assistant.id,
                tool_resources={
                    'file_search': {
                        'vector_store_ids': [
                            self.database.vector_store_id,
                        ],
                    },
                },
            )
            self.attached = True

        return self.send(prompt)

    def cleanup(self):
        """Delete the thread and the assistant, then mark as unattached.

        The assistant is deleted even when deleting the thread raises, so a
        failure on the first call does not leave the assistant behind; the
        original exception still propagates.
        """
        try:
            self.client.beta.threads.delete(self.thread.id)
        finally:
            self.client.beta.assistants.delete(self.assistant.id)
            self.attached = False

    def send(self, content):
        """Post ``content`` as a user message and run the assistant.

        The message is posted once; up to ``self.retries`` runs are then
        attempted on the thread. After a failed run the wait suggested by
        the run's error is slept out before the next attempt.

        Args:
            content: Content of the user message.

        Returns:
            The result of ``threads.messages.list`` for this thread and the
            completed run.

        Raises:
            TypeError: If a run fails without error details, or with an
                error that ``parse_wait_time`` does not accept as a
                rate-limit error.
            TimeoutError: If no run completed within ``self.retries``
                attempts.
        """
        self.client.beta.threads.messages.create(
            self.thread.id,
            role='user',
            content=content,
        )

        for attempt in range(1, self.retries + 1):
            run = self.client.beta.threads.runs.create_and_poll(
                thread_id=self.thread.id,
                assistant_id=self.assistant.id,
                **self._threads_kwargs,
            )
            if run.status == 'completed':
                return self.client.beta.threads.messages.list(
                    thread_id=self.thread.id,
                    run_id=run.id,
                )
            Logger.error('%s (%d): %s', run.status, attempt, run.last_error)

            # A run can end unsuccessfully without any error object; there
            # is no wait hint to parse then, so fail like any other
            # non-retryable error instead of with an AttributeError.
            if run.last_error is None:
                raise TypeError(run.status)

            # Parsed on every attempt so non-rate-limit errors still raise.
            rest = math.ceil(parse_wait_time(run.last_error))
            # No point sleeping when there is no attempt left to wait for.
            if attempt < self.retries:
                Logger.warning('Sleeping %ds', rest)
                time.sleep(rest)

        raise TimeoutError('Message retries exceeded')