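# DashScope text-embedding extension for the TEN framework.
# Handles "embed" and "embed_batch" cmds asynchronously via a pool of worker threads.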
from ten import (
    Extension,
    TenEnv,
    Cmd,
    StatusCode,
    CmdResult,
)
import json
from typing import Generator, List
from http import HTTPStatus
import threading
import queue
from datetime import datetime

CMD_EMBED = "embed"
CMD_EMBED_BATCH = "embed_batch"

FIELD_KEY_EMBEDDING = "embedding"
FIELD_KEY_EMBEDDINGS = "embeddings"
FIELD_KEY_MESSAGE = "message"
FIELD_KEY_CODE = "code"

DASHSCOPE_MAX_BATCH_SIZE = 6


class EmbeddingExtension(Extension):
    def __init__(self, name: str):
        super().__init__(name)
        self.api_key = ""
        self.model = ""
        self.stop = False
        self.queue = queue.Queue()
        self.threads = []

        # Workaround to speed up the embedding process; should be replaced by
        # https://help.aliyun.com/zh/model-studio/developer-reference/text-embedding-batch-api?spm=a2c4g.11186623.0.0.24cb7453KSjdhC
        # once v3 models are supported.
        self.parallel = 10
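
    # Read the "api_key" and "model" properties, lazily import the dashscope SDK,
    # then spawn self.parallel worker threads to drain the cmd queue.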
    def on_start(self, ten: TenEnv) -> None:
        ten.log_info("on_start")
        self.api_key = self.get_property_string(ten, "api_key", self.api_key)
        self.model = self.get_property_string(ten, "model", self.model)

        # Lazy import of packages which take a long time to load.
        global dashscope  # pylint: disable=global-statement
        import dashscope

        dashscope.api_key = self.api_key

        for i in range(self.parallel):
            thread = threading.Thread(target=self.async_handler, args=[i, ten])
            thread.start()
            self.threads.append(thread)

        ten.on_start_done()
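
    # Worker loop: pull cmds from the queue until a None sentinel (or the stop flag)
    # is seen, dispatch to call_with_str / call_with_strs, and return the result.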
    def async_handler(self, index: int, ten: TenEnv):
        ten.log_info(f"async_handler {index} started")
        while not self.stop:
            cmd = self.queue.get()
            if cmd is None:
                break

            cmd_name = cmd.get_name()
            start_time = datetime.now()
            ten.log_info(f"async_handler {index} processing cmd {cmd_name}")

            if cmd_name == CMD_EMBED:
                cmd_result = self.call_with_str(cmd.get_property_string("input"), ten)
                ten.return_result(cmd_result, cmd)
            elif cmd_name == CMD_EMBED_BATCH:
                inputs_list = json.loads(cmd.get_property_to_json("inputs"))
                cmd_result = self.call_with_strs(inputs_list, ten)
                ten.return_result(cmd_result, cmd)
            else:
                ten.log_warn(f"unknown cmd {cmd_name}")

            ten.log_info(
                f"async_handler {index} finished processing cmd {cmd_name}, cost {int((datetime.now() - start_time).total_seconds() * 1000)}ms"
            )
        ten.log_info(f"async_handler {index} stopped")
    def call_with_str(self, message: str, ten: TenEnv) -> CmdResult:
        start_time = datetime.now()
        # pylint: disable=undefined-variable
        response = dashscope.TextEmbedding.call(model=self.model, input=message)
        ten.log_info(
            f"embedding call finished for input [{message}], status_code {response.status_code}, cost {int((datetime.now() - start_time).total_seconds() * 1000)}ms"
        )

        if response.status_code == HTTPStatus.OK:
            cmd_result = CmdResult.create(StatusCode.OK)
            cmd_result.set_property_from_json(
                FIELD_KEY_EMBEDDING,
                json.dumps(response.output["embeddings"][0]["embedding"]),
            )
            return cmd_result
        else:
            cmd_result = CmdResult.create(StatusCode.ERROR)
            cmd_result.set_property_string(FIELD_KEY_CODE, str(response.status_code))
            cmd_result.set_property_string(FIELD_KEY_MESSAGE, response.message)
            return cmd_result
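
    # Yield successive slices of at most batch_size inputs (DASHSCOPE_MAX_BATCH_SIZE by default).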
    def batched(
        self, inputs: List, batch_size: int = DASHSCOPE_MAX_BATCH_SIZE
    ) -> Generator[List, None, None]:
        for i in range(0, len(inputs), batch_size):
            yield inputs[i : i + batch_size]
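
    # Embed a list of strings batch by batch, merging the per-batch embeddings
    # and re-basing each "text_index" so it refers to the original input list.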
    def call_with_strs(self, messages: List[str], ten: TenEnv) -> CmdResult:
        start_time = datetime.now()
        result = None  # merged result of all batches
        batch_counter = 0
        for batch in self.batched(messages):
            # pylint: disable=undefined-variable
            response = dashscope.TextEmbedding.call(model=self.model, input=batch)
            # ten.log_info("%s Received %s", batch, response)
            if response.status_code == HTTPStatus.OK:
                if result is None:
                    result = response.output
                else:
                    for emb in response.output["embeddings"]:
                        emb["text_index"] += batch_counter
                        result["embeddings"].append(emb)
            else:
                ten.log_error(f"call {batch} failed, errmsg: {response}")
            batch_counter += len(batch)

        ten.log_info(
            f"embedding call finished for inputs len {len(messages)}, batch_counter {batch_counter}, results len {len(result['embeddings']) if result else 0}, cost {int((datetime.now() - start_time).total_seconds() * 1000)}ms"
        )

        if result is not None:
            cmd_result = CmdResult.create(StatusCode.OK)
            # `set_property_to_json` is too slow, so use `set_property_string`
            # as a workaround for now; will be replaced once `set_property_to_json` is improved.
            cmd_result.set_property_string(
                FIELD_KEY_EMBEDDINGS, json.dumps(result["embeddings"])
            )
            return cmd_result
        else:
            cmd_result = CmdResult.create(StatusCode.ERROR)
            cmd_result.set_property_string(FIELD_KEY_MESSAGE, "All batches failed")
            ten.log_error("All batches failed")
            return cmd_result
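
    # Shut down: raise the stop flag, drain the queue, push one None sentinel
    # per worker thread, then join all workers.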
    def on_stop(self, ten: TenEnv) -> None:
        ten.log_info("on_stop")
        self.stop = True

        # Clear the queue.
        while not self.queue.empty():
            self.queue.get()

        # Put enough None sentinels to stop all threads.
        for thread in self.threads:
            self.queue.put(None)
        for thread in self.threads:
            thread.join()
        self.threads = []

        ten.on_stop_done()
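
    # Queue "embed" / "embed_batch" cmds for the worker threads; reject anything else.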
    def on_cmd(self, ten: TenEnv, cmd: Cmd) -> None:
        cmd_name = cmd.get_name()
        if cmd_name in [CMD_EMBED, CMD_EMBED_BATCH]:
            # // embed
            # {
            #   "name": "embed",
            #   "input": "hello"
            # }
            # // embed_batch
            # {
            #   "name": "embed_batch",
            #   "inputs": ["hello", ...]
            # }
            self.queue.put(cmd)
        else:
            ten.log_warn(f"unknown cmd {cmd_name}")
            cmd_result = CmdResult.create(StatusCode.ERROR)
            ten.return_result(cmd_result, cmd)
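
    # Read a string property from the TenEnv, falling back to the given default
    # if it is missing or cannot be read.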
    def get_property_string(self, ten: TenEnv, key, default):
        try:
            return ten.get_property_string(key)
        except Exception as e:
            ten.log_warn(f"err: {e}")
            return default