File size: 7,338 Bytes
a3ffd31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
"""
This module is responsible for the VectorDB API. It currently supports:
* DELETE api/v1/clear
    - Clears the whole DB.
* POST api/v1/add
    - Adds a corpus to the DB. Metadata can optionally be stored alongside it.
* POST api/v1/delete
    - Deletes the records that match the given metadata.
* POST api/v1/get
    - Queries ChromaDB with the given search strings and returns the results.

Every endpoint is also reachable without the version segment (e.g. api/add).
"""

import json
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from urllib.parse import urlparse, parse_qs
from threading import Thread

from modules import shared
from modules.logging_colors import logger

from .chromadb import ChromaCollector
from .data_processor import process_and_add_to_collector

import extensions.superboogav2.parameters as parameters


class CustomThreadingHTTPServer(ThreadingHTTPServer):
    """ThreadingHTTPServer that passes a shared ChromaCollector to each handler.

    The stock server builds handlers as
    ``RequestHandlerClass(request, client_address, server)``; this subclass
    appends ``collector`` as a fourth positional argument.
    """

    def __init__(self, server_address, RequestHandlerClass, collector: ChromaCollector, bind_and_activate=True):
        # Stored first so the attribute exists even while the base initializer
        # binds and activates the listening socket.
        self.collector = collector
        super().__init__(server_address, RequestHandlerClass, bind_and_activate=bind_and_activate)

    def finish_request(self, request, client_address):
        """Create the handler for one connection, injecting the shared collector."""
        handler_cls = self.RequestHandlerClass
        handler_cls(request, client_address, self, self.collector)


class Handler(BaseHTTPRequestHandler):
    """HTTP request handler for the VectorDB JSON API.

    Each instance serves one request using the ``collector`` handed over by
    ``CustomThreadingHTTPServer.finish_request``. All response bodies are JSON.
    """

    def __init__(self, request, client_address, server, collector: ChromaCollector):
        # Assigned before the base initializer, because BaseRequestHandler.__init__
        # handles the request (dispatching to do_*) before it returns.
        self.collector = collector
        super().__init__(request, client_address, server)

    def _send_json(self, status: int, payload) -> None:
        """Write ``payload`` as a JSON response with HTTP status ``status``.

        The payload is serialized before the status line and headers are
        emitted. A serialization error (e.g. a non-JSON-serializable result)
        therefore raises before anything is written, and the caller can still
        answer with a clean error response.
        """
        encoded = json.dumps(payload).encode('utf-8')
        self.send_response(status)
        self.send_header("Content-type", "application/json")
        self.send_header("Content-Length", str(len(encoded)))
        self.end_headers()
        self.wfile.write(encoded)

    def _send_412_error(self, message):
        """Respond 412 with ``{"error": message}`` (used for missing parameters)."""
        self._send_json(412, {"error": message})

    def _send_404_error(self):
        """Respond 404 for unknown paths or methods."""
        self._send_json(404, {"error": "Resource not found"})

    def _send_400_error(self, error_message: str):
        """Respond 400 with ``{"error": error_message}``."""
        self._send_json(400, {"error": error_message})

    def _send_200_response(self, message: "str | dict"):
        """Respond 200.

        A plain string is wrapped as ``{"message": ...}``. Any other value is
        serialized as-is.
        """
        payload = {"message": message} if isinstance(message, str) else message
        self._send_json(200, payload)

    def _handle_get(self, search_strings: list[str], n_results: int, max_token_count: int, sort_param: str):
        """Query the collector and wrap the hits as ``{"results": ...}``.

        ``parameters.SORT_ID`` selects id ordering. ``parameters.SORT_DISTANCE``
        and any unrecognised value select distance ordering.
        """
        if sort_param == parameters.SORT_DISTANCE:
            results = self.collector.get_sorted_by_dist(search_strings, n_results, max_token_count)
        elif sort_param == parameters.SORT_ID:
            results = self.collector.get_sorted_by_id(search_strings, n_results, max_token_count)
        else:  # Default is distance ordering.
            results = self.collector.get_sorted_by_dist(search_strings, n_results, max_token_count)

        return {"results": results}

    def do_GET(self):
        """No GET endpoints are exposed; always respond 404."""
        self._send_404_error()

    def do_POST(self):
        """Dispatch POST requests to the add / delete / get endpoints.

        The request body is parsed as JSON. A missing required field yields
        412 and an unknown path yields 404. Any exception raised while handling
        the request is reported as 400 with the exception text.
        """
        try:
            content_length = int(self.headers['Content-Length'])
            body = json.loads(self.rfile.read(content_length).decode('utf-8'))

            parsed_path = urlparse(self.path)
            path = parsed_path.path
            query_params = parse_qs(parsed_path.query)

            if path in ('/api/v1/add', '/api/add'):
                corpus = body.get('corpus')
                if corpus is None:
                    self._send_412_error("Missing parameter 'corpus'")
                    return

                clear_before_adding = body.get('clear_before_adding', False)
                metadata = body.get('metadata')
                process_and_add_to_collector(corpus, self.collector, clear_before_adding, metadata)
                self._send_200_response("Data successfully added")

            elif path in ('/api/v1/delete', '/api/delete'):
                # The metadata filter is what selects the records to delete.
                metadata = body.get('metadata')
                if metadata is None:
                    self._send_412_error("Missing parameter 'metadata'")
                    return

                self.collector.delete(ids_to_delete=None, where=metadata)
                self._send_200_response("Data successfully deleted")

            elif path in ('/api/v1/get', '/api/get'):
                search_strings = body.get('search_strings')
                if search_strings is None:
                    self._send_412_error("Missing parameter 'search_strings'")
                    return

                n_results = body.get('n_results')
                if n_results is None:
                    n_results = parameters.get_chunk_count()

                max_token_count = body.get('max_token_count')
                if max_token_count is None:
                    max_token_count = parameters.get_max_token_count()

                # Ordering comes from the query string (?sort=...), not the JSON body.
                sort_param = query_params.get('sort', ['distance'])[0]

                results = self._handle_get(search_strings, n_results, max_token_count, sort_param)
                self._send_200_response(results)

            else:
                self._send_404_error()
        except Exception as e:
            self._send_400_error(str(e))

    def do_DELETE(self):
        """Handle ``DELETE /api/v1/clear`` (or ``/api/clear``) by clearing the collector."""
        try:
            path = urlparse(self.path).path

            if path in ('/api/v1/clear', '/api/clear'):
                self.collector.clear()
                self._send_200_response("Data successfully cleared")
            else:
                self._send_404_error()
        except Exception as e:
            self._send_400_error(str(e))

    def do_OPTIONS(self):
        """Answer preflight requests with 200; CORS headers are added by end_headers."""
        self.send_response(200)
        self.end_headers()

    def end_headers(self):
        """Append permissive CORS and no-cache headers to every response."""
        self.send_header('Access-Control-Allow-Origin', '*')
        self.send_header('Access-Control-Allow-Methods', '*')
        self.send_header('Access-Control-Allow-Headers', '*')
        self.send_header('Cache-Control', 'no-store, no-cache, must-revalidate')
        super().end_headers()


class APIManager:
    """Start and stop the VectorDB HTTP API on a background thread."""

    def __init__(self, collector: ChromaCollector):
        self.server = None  # CustomThreadingHTTPServer while running, otherwise None
        self.collector = collector  # passed to every request handler
        self.is_running = False

    def start_server(self, port: int):
        """Start the API server on ``port`` unless one is already running.

        The server listens on 0.0.0.0 when ``shared.args.listen`` is set and on
        127.0.0.1 otherwise. ``serve_forever`` runs in a daemon thread, so the
        server does not keep the process alive on exit.
        """
        if self.server is not None:
            # Reported through the logger, like the other messages in this class.
            logger.warning("Server already running.")
            return

        address = '0.0.0.0' if shared.args.listen else '127.0.0.1'
        self.server = CustomThreadingHTTPServer((address, port), Handler, self.collector)

        logger.info(f'Starting chromaDB API at http://{address}:{port}/api')

        Thread(target=self.server.serve_forever, daemon=True).start()

        self.is_running = True

    def stop_server(self):
        """Shut down and close the server if one is running; otherwise do nothing."""
        if self.server is None:
            return

        logger.info('Stopping chromaDB API.')
        # shutdown() stops the serve_forever loop; server_close() releases the socket.
        self.server.shutdown()
        self.server.server_close()
        self.server = None
        self.is_running = False

    def is_server_running(self):
        """Return True between a successful start_server() and the next stop_server()."""
        return self.is_running