File size: 6,756 Bytes
68515e1
 
 
 
 
 
778d7a6
68515e1
 
 
a487faf
68515e1
 
778d7a6
68515e1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a487faf
 
 
 
 
 
 
 
 
 
 
 
 
 
68515e1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a487faf
68515e1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a487faf
68515e1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a487faf
68515e1
 
 
 
 
a487faf
68515e1
 
 
 
 
 
 
 
 
a487faf
68515e1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
import os
import uvicorn
import time
from fastapi import FastAPI, Request, HTTPException
import tiktoken
from json.decoder import JSONDecodeError
import multiprocessing

# 初始化 FastAPI
app = FastAPI()

# 获取编码器
encoding = tiktoken.get_encoding("cl100k_base")
debug=True

# 随机生成 ID
def general_id(zimu=4, num=6):
    import random
    import string
    letters = ''.join(random.choices(string.ascii_lowercase, k=zimu))
    numbers = ''.join(random.choices(string.digits, k=num))
    return f"chatcmpl-{letters}{numbers}"


# 异常处理函数
def handle_exception(e):
    if isinstance(e, JSONDecodeError):
        raise HTTPException(status_code=400, detail="无效的JSON格式")
    else:
        raise HTTPException(status_code=500, detail="服务器内部错误")


# chatgpt format
def is_chatgpt_format(data):
    try:
        # 检查数据是否符合 ChatGPT 的格式
        if isinstance(data, dict) and "choices" in data and "message" in data["choices"][0]:
            return True
    except Exception as e:
        pass
    return False

def get_workers_count() -> int:
    """
    Calculate optimal number of workers
    Default: 4, Maximum: 8
    Formula: min(max(4, (2 * CPU cores) + 1), 8)
    """
    try:
        cpu_cores = multiprocessing.cpu_count()
        recommended_workers = (2 * cpu_cores) + 1
        return min(max(4, recommended_workers), 8)
    except Exception as e:
        if debug:
            print(f"Worker count calculation failed: {e}, using default 4")
        return 4

# 模拟 ChatGPT 响应
def generate_response(headers, data):
    try:
        ## 尝试解析头及请求,主服务未来不是授权key不能使用

        ## 返回结果
        result = "This is a test result."


        ## 计算tokean
        # 判断请求体中的数据是否已经符合 ChatGPT 的格式
        if is_chatgpt_format(result):
            response_data = result
        else:
            # 计算时间戳
            current_timestamp = int(time.time() * 1000)

            # 计算 token 数量
            prompt_tokens = len(encoding.encode(str(data)))  # 对请求体的编码进行 token 计算
            completion_tokens = len(encoding.encode(result))  # 对返回结果的编码进行 token 计算
            total_tokens = prompt_tokens + completion_tokens

            # 构造符合 ChatGPT 格式的响应
            response_data = {
                "id": general_id(),
                "object": "chat.completion",
                "created": current_timestamp,
                "model": data.get("model", "gpt-4o"),
                "usage": {
                    "prompt_tokens": prompt_tokens,
                    "completion_tokens": completion_tokens,
                    "total_tokens": total_tokens,
                    "completion_tokens_details": {
                        "reasoning_tokens": 0,
                        "accepted_prediction_tokens": 0,
                        "rejected_prediction_tokens": 0
                    }
                },
                "choices": [
                    {
                        "message": {
                            "role": "assistant",
                            "content": result
                        },
                        "logprobs": None,
                        "finish_reason": "stop",
                        "index": 0
                    }
                ]
            }
        # 打印最终返回的数据
        print(f"Response Data: {response_data}" )
        return response_data
    except Exception as e:
        handle_exception(e)  # 异常处理函数


# 解析动态路由

def get_dynamic_routes():
    """
    根据环境变量动态生成路由路径。

    优先级: REPLACE_CHAT > PREFIX_CHAT > APPEND_CHAT
    - REPLACE_CHAT: 替换默认访问地址,直接替换默认路径,支持多个,逗号分割
    - PREFIX_CHAT: 在默认访问地址前加上前缀,支持多个,逗号分割
    - APPEND_CHAT: 添加额外的访问地址,支持多个,逗号分割

    :return: 返回一个包含动态生成的路由路径的列表
    """
    # 默认路径
    default_path = "/v1/chat/completions"

    # 获取环境变量,默认为空字符串
    replace_chat = os.getenv("REPLACE_CHAT", "")
    prefix_chat = os.getenv("PREFIX_CHAT", "")
    append_chat = os.getenv("APPEND_CHAT", "")

    # 优先级: 如果设置了 REPLACE_CHAT,直接返回替换路径
    if replace_chat:
        # 如果设置了 REPLACE_CHAT,按逗号分割,返回多个替换路径
        return [path.strip() for path in replace_chat.split(",") if path.strip()]

    # 如果没有 REPLACE_CHAT,检查是否设置了 PREFIX_CHAT
    routes = []  # 用于保存最终生成的路由列表

    # 处理 PREFIX_CHAT: 在默认路径前加前缀
    if prefix_chat:
        # 将 PREFIX_CHAT 按逗号分隔,支持多个前缀
        prefixes = prefix_chat.split(",")
        for prefix in prefixes:
            # 确保路径拼接时,前缀和默认路径之间有一个斜杠
            routes.append(prefix + default_path)
        # 如果设置了 PREFIX_CHAT,返回所有生成的前缀路径
        return routes

    # 如果没有设置 PREFIX_CHAT,检查是否设置了 APPEND_CHAT
    if append_chat:
        # 将 APPEND_CHAT 按逗号分隔,去除可能的空值和空白字符
        append_paths = [path.strip() for path in append_chat.split(",") if path.strip()]
        # 将默认路径和 APPEND_CHAT 中的路径合并
        routes = [default_path] + append_paths

    # 如果没有设置 REPLACE_CHAT、PREFIX_CHAT 或 APPEND_CHAT,返回默认路径
    if not routes:
        routes.append(default_path)

    # 返回生成的所有路由路径
    return routes


# 注册单个动态路由
def register_route(path: str):
    print(f"register  route path: {path}")
    @app.post(path)
    async def dynamic_chat_endpoint(request: Request):
        try:
            headers = request.headers
            data = request.json()
            print(f"Received Request Header: {headers}\nData: {data}")
            result = generate_response(headers, data)
            return result
        except Exception as e:
            raise HTTPException(status_code=500, detail=str(e))


# 动态设置路由
def setup_dynamic_routes():
    routes = get_dynamic_routes()
    print("Registering routes: {routes}")
    for path in routes:
        register_route(path)


@app.get("/")
def read_root():
    return "working..."


if __name__ == "__main__":
    # 注册动态路由
    setup_dynamic_routes()
    # 获取端口环境变量(默认为 7860)
    port = int(os.getenv("PORT", "7860"))
    uvicorn.run(app, host="0.0.0.0", port=port)