File size: 6,972 Bytes
87337b1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
from ten import (
    Extension,
    TenEnv,
    Cmd,
    StatusCode,
    CmdResult,
)

import json
from typing import Generator, List
from http import HTTPStatus
import threading, queue
from datetime import datetime

# Cmd names this extension handles (see on_cmd).
CMD_EMBED = "embed"
CMD_EMBED_BATCH = "embed_batch"

# Property keys set on the returned CmdResult objects.
FIELD_KEY_EMBEDDING = "embedding"
FIELD_KEY_EMBEDDINGS = "embeddings"
FIELD_KEY_MESSAGE = "message"
FIELD_KEY_CODE = "code"

# Max number of inputs per DashScope TextEmbedding call; larger input lists
# are split into chunks of this size (see EmbeddingExtension.batched).
DASHSCOPE_MAX_BATCH_SIZE = 6


class EmbeddingExtension(Extension):
    """TEN extension that computes text embeddings via the DashScope API.

    Incoming "embed" / "embed_batch" cmds are queued in on_cmd and consumed
    by a pool of worker threads, each of which calls DashScope synchronously
    and returns a CmdResult to the caller.
    """

    def __init__(self, name: str):
        super().__init__(name)
        self.api_key = ""  # DashScope API key, read from properties in on_start
        self.model = ""  # embedding model name, read from properties in on_start

        self.stop = False  # set by on_stop to terminate the worker loops
        self.queue = queue.Queue()  # pending cmds; None acts as a stop sentinel
        self.threads = []  # worker threads started in on_start

        # workaround to speed up the embedding process,
        # should be replaced by https://help.aliyun.com/zh/model-studio/developer-reference/text-embedding-batch-api?spm=a2c4g.11186623.0.0.24cb7453KSjdhC
        # once v3 models supported
        self.parallel = 10  # number of worker threads

    def on_start(self, ten: TenEnv) -> None:
        """Read config, import dashscope lazily, and spawn the worker pool."""
        ten.log_info("on_start")
        self.api_key = self.get_property_string(ten, "api_key", self.api_key)
        # FIX: the fallback for "model" was self.api_key; it must be self.model.
        self.model = self.get_property_string(ten, "model", self.model)

        # lazy import packages which requires long time to load
        global dashscope  # pylint: disable=global-statement
        import dashscope

        dashscope.api_key = self.api_key

        for i in range(self.parallel):
            thread = threading.Thread(target=self.async_handler, args=[i, ten])
            thread.start()
            self.threads.append(thread)

        ten.on_start_done()

    def async_handler(self, index: int, ten: TenEnv) -> None:
        """Worker loop: pop cmds off the queue and answer them until stopped.

        Exits when self.stop is set or a None sentinel is dequeued (see on_stop).
        """
        ten.log_info(f"async_handler {index} started")

        while not self.stop:
            cmd = self.queue.get()
            if cmd is None:
                # stop sentinel injected by on_stop
                break

            cmd_name = cmd.get_name()
            start_time = datetime.now()
            ten.log_info(f"async_handler {index} processing cmd {cmd_name}")

            if cmd_name == CMD_EMBED:
                cmd_result = self.call_with_str(cmd.get_property_string("input"), ten)
                ten.return_result(cmd_result, cmd)
            elif cmd_name == CMD_EMBED_BATCH:
                inputs_list = json.loads(cmd.get_property_to_json("inputs"))
                cmd_result = self.call_with_strs(inputs_list, ten)
                ten.return_result(cmd_result, cmd)
            else:
                # FIX: missing f-prefix logged the literal text "{cmd_name}".
                ten.log_warn(f"unknown cmd {cmd_name}")

            ten.log_info(
                f"async_handler {index} finished processing cmd {cmd_name}, cost {int((datetime.now() - start_time).total_seconds() * 1000)}ms"
            )

        ten.log_info(f"async_handler {index} stopped")

    def call_with_str(self, message: str, ten: TenEnv) -> CmdResult:
        """Embed a single string; return an OK result with the vector, or an
        ERROR result carrying the DashScope status code and message."""
        start_time = datetime.now()
        # pylint: disable=undefined-variable
        response = dashscope.TextEmbedding.call(model=self.model, input=message)
        ten.log_info(
            f"embedding call finished for input [{message}], status_code {response.status_code}, cost {int((datetime.now() - start_time).total_seconds() * 1000)}ms"
        )

        if response.status_code == HTTPStatus.OK:
            cmd_result = CmdResult.create(StatusCode.OK)
            cmd_result.set_property_from_json(
                FIELD_KEY_EMBEDDING,
                json.dumps(response.output["embeddings"][0]["embedding"]),
            )
            return cmd_result
        else:
            cmd_result = CmdResult.create(StatusCode.ERROR)
            # FIX: status_code is an int/HTTPStatus; the string setter needs str.
            cmd_result.set_property_string(FIELD_KEY_CODE, str(response.status_code))
            cmd_result.set_property_string(FIELD_KEY_MESSAGE, response.message)
            return cmd_result

    def batched(
        self, inputs: List, batch_size: int = DASHSCOPE_MAX_BATCH_SIZE
    ) -> Generator[List, None, None]:
        """Yield consecutive slices of *inputs* of at most *batch_size* items."""
        for i in range(0, len(inputs), batch_size):
            yield inputs[i : i + batch_size]

    def call_with_strs(self, messages: List[str], ten: TenEnv) -> CmdResult:
        """Embed a list of strings in DashScope-sized batches and merge the
        responses; returns ERROR only if every batch failed."""
        start_time = datetime.now()
        result = None  # merged DashScope output across batches
        batch_counter = 0  # inputs processed so far; offsets text_index per batch
        for batch in self.batched(messages):
            # pylint: disable=undefined-variable
            response = dashscope.TextEmbedding.call(model=self.model, input=batch)
            if response.status_code == HTTPStatus.OK:
                if result is None:
                    result = response.output
                else:
                    for emb in response.output["embeddings"]:
                        # re-base the per-batch index to the global input index
                        emb["text_index"] += batch_counter
                        result["embeddings"].append(emb)
            else:
                # FIX: ten.log_* takes one pre-formatted string (no lazy %-args).
                ten.log_error(f"call {batch} failed, errmsg: {response}")
            batch_counter += len(batch)

        # FIX: guard the length lookup — result is None when every batch failed,
        # and the old code raised TypeError here instead of returning ERROR.
        results_len = len(result["embeddings"]) if result is not None else 0
        ten.log_info(
            f"embedding call finished for inputs len {len(messages)}, batch_counter {batch_counter}, results len {results_len}, cost {int((datetime.now() - start_time).total_seconds() * 1000)}ms "
        )
        if result is not None:
            cmd_result = CmdResult.create(StatusCode.OK)

            # too slow `set_property_to_json`, so use `set_property_string` at the moment as workaround
            # will be replaced once `set_property_to_json` improved
            cmd_result.set_property_string(
                FIELD_KEY_EMBEDDINGS, json.dumps(result["embeddings"])
            )
            return cmd_result
        else:
            cmd_result = CmdResult.create(StatusCode.ERROR)
            cmd_result.set_property_string(FIELD_KEY_MESSAGE, "All batch failed")
            ten.log_error("All batch failed")
            return cmd_result

    def on_stop(self, ten: TenEnv) -> None:
        """Drain pending cmds and shut the worker pool down cleanly."""
        ten.log_info("on_stop")
        self.stop = True
        # clear queue so workers don't process stale cmds
        while not self.queue.empty():
            self.queue.get()
        # put enough None sentinels to wake and stop all threads
        for thread in self.threads:
            self.queue.put(None)
        for thread in self.threads:
            thread.join()
        self.threads = []

        ten.on_stop_done()

    def on_cmd(self, ten: TenEnv, cmd: Cmd) -> None:
        """Enqueue known cmds for the worker pool; reject anything else."""
        cmd_name = cmd.get_name()

        if cmd_name in [CMD_EMBED, CMD_EMBED_BATCH]:
            # // embed
            # {
            #     "name": "embed",
            #     "input": "hello"
            # }

            # // embed_batch
            # {
            #     "name": "embed_batch",
            #     "inputs": ["hello", ...]
            # }

            self.queue.put(cmd)
        else:
            ten.log_warn(f"unknown cmd {cmd_name}")
            cmd_result = CmdResult.create(StatusCode.ERROR)
            ten.return_result(cmd_result, cmd)

    def get_property_string(self, ten: TenEnv, key, default):
        """Read a string property from the env, falling back to *default*
        when the property is missing or unreadable."""
        try:
            return ten.get_property_string(key)
        except Exception as e:
            ten.log_warn(f"err: {e}")
            return default