yym68686 committed on
Commit
f156f8a
·
1 Parent(s): 2797497

🐛 Bug: 1. Fix the bug that causes an error when Claude uploads a PNG image.

Browse files

2. Fix the bug where fields are not automatically added when the database does not have specific fields.

✨ Feature: Add user message ethics review support.

📖 Docs: Update documentation

Files changed (5) hide show
  1. README_CN.md +2 -0
  2. main.py +100 -13
  3. models.py +12 -1
  4. request.py +36 -7
  5. utils.py +2 -2
README_CN.md CHANGED
@@ -105,10 +105,12 @@ api_keys:
105
  model:
106
  - anthropic/claude-3-5-sonnet # 可以使用的模型名称,仅可以使用名为 anthropic 提供商提供的 claude-3-5-sonnet 模型。其他提供商的 claude-3-5-sonnet 模型不可以使用。这种写法不会匹配到other-provider提供的名为anthropic/claude-3-5-sonnet的模型。
107
  - <anthropic/claude-3-5-sonnet> # 通过在模型名两侧加上尖括号,这样就不会去名为anthropic的渠道下去寻找claude-3-5-sonnet模型,而是将整个 anthropic/claude-3-5-sonnet 作为模型名称。这种写法可以匹配到other-provider提供的名为 anthropic/claude-3-5-sonnet 的模型。但不会匹配到anthropic下面的claude-3-5-sonnet模型。
 
108
  preferences:
109
  USE_ROUND_ROBIN: true # 是否使用轮询负载均衡,true 为使用,false 为不使用,默认为 true。开启轮询后每次请求模型按照 model 配置的顺序依次请求。与 providers 里面原始的渠道顺序无关。因此你可以设置每个 API key 请求顺序不一样。
110
  AUTO_RETRY: true # 是否自动重试,自动重试下一个提供商,true 为自动重试,false 为不自动重试,默认为 true
111
  RATE_LIMIT: 2/min # 支持限流,每分钟最多请求次数,可以设置为整数,如 2/min,2 次每分钟、5/hour,5 次每小时、10/day,10 次每天,10/month,10 次每月,10/year,10 次每年。默认60/min,选填
 
112
 
113
  # 渠道级加权负载均衡配置示例
114
  - api: sk-KjjI60Yf0JFWtxxxxxxxxxxxxxxwmRWpWpQRo
 
105
  model:
106
  - anthropic/claude-3-5-sonnet # 可以使用的模型名称,仅可以使用名为 anthropic 提供商提供的 claude-3-5-sonnet 模型。其他提供商的 claude-3-5-sonnet 模型不可以使用。这种写法不会匹配到other-provider提供的名为anthropic/claude-3-5-sonnet的模型。
107
  - <anthropic/claude-3-5-sonnet> # 通过在模型名两侧加上尖括号,这样就不会去名为anthropic的渠道下去寻找claude-3-5-sonnet模型,而是将整个 anthropic/claude-3-5-sonnet 作为模型名称。这种写法可以匹配到other-provider提供的名为 anthropic/claude-3-5-sonnet 的模型。但不会匹配到anthropic下面的claude-3-5-sonnet模型。
108
+ - openai-test/text-moderation-latest # 当开启消息道德审查后,可以使用名为 openai-test 渠道下的 text-moderation-latest 模型进行道德审查。
109
  preferences:
110
  USE_ROUND_ROBIN: true # 是否使用轮询负载均衡,true 为使用,false 为不使用,默认为 true。开启轮询后每次请求模型按照 model 配置的顺序依次请求。与 providers 里面原始的渠道顺序无关。因此你可以设置每个 API key 请求顺序不一样。
111
  AUTO_RETRY: true # 是否自动重试,自动重试下一个提供商,true 为自动重试,false 为不自动重试,默认为 true
112
  RATE_LIMIT: 2/min # 支持限流,每分钟最多请求次数,可以设置为整数,如 2/min,2 次每分钟、5/hour,5 次每小时、10/day,10 次每天,10/month,10 次每月,10/year,10 次每年。默认60/min,选填
113
+ ENABLE_MODERATION: true # 是否开启消息道德审查,true 为开启,false 为不开启,默认为 false,当开启后,会对用户的消息进行道德审查,如果发现不当的消息,会返回错误信息。
114
 
115
  # 渠道级加权负载均衡配置示例
116
  - api: sk-KjjI60Yf0JFWtxxxxxxxxxxxxxxwmRWpWpQRo
main.py CHANGED
@@ -24,10 +24,50 @@ from urllib.parse import urlparse
24
  import os
25
  is_debug = bool(os.getenv("DEBUG", False))
26
 
 
 
 
27
  async def create_tables():
28
  async with engine.begin() as conn:
29
  await conn.run_sync(Base.metadata.create_all)
30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  @asynccontextmanager
32
  async def lifespan(app: FastAPI):
33
  # 启动时的代码
@@ -79,7 +119,7 @@ async def parse_request_body(request: Request):
79
 
80
  from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
81
  from sqlalchemy.orm import declarative_base, sessionmaker
82
- from sqlalchemy import Column, Integer, String, Float, DateTime, select, Boolean
83
  from sqlalchemy.sql import func
84
 
85
  # 定义数据库模型
@@ -93,6 +133,8 @@ class RequestStat(Base):
93
  token = Column(String)
94
  total_time = Column(Float)
95
  model = Column(String)
 
 
96
  timestamp = Column(DateTime(timezone=True), server_default=func.now())
97
 
98
  class ChannelStat(Base):
@@ -113,6 +155,7 @@ data_dir = os.path.dirname(db_path)
113
  os.makedirs(data_dir, exist_ok=True)
114
 
115
  # 创建异步引擎和会话
 
116
  engine = create_async_engine('sqlite+aiosqlite:///' + db_path, echo=is_debug)
117
  async_session = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
118
 
@@ -132,37 +175,76 @@ class StatsMiddleware(BaseHTTPMiddleware):
132
  start_time = time()
133
 
134
  request.state.parsed_body = await parse_request_body(request)
 
 
135
 
136
  model = "unknown"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
  if request.state.parsed_body:
138
  try:
139
  request_model = RequestModel(**request.state.parsed_body)
140
  model = request_model.model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
  except RequestValidationError:
142
  pass
143
  except Exception as e:
144
- logger.error(f"Error processing request: {str(e)}")
 
 
 
 
145
 
146
  response = await call_next(request)
147
  process_time = time() - start_time
148
 
149
- endpoint = f"{request.method} {request.url.path}"
150
- client_ip = request.client.host
151
-
152
  # 异步更新数据库
153
- await self.update_stats(endpoint, process_time, client_ip, model, token)
154
 
155
  return response
156
 
157
- async def update_stats(self, endpoint, process_time, client_ip, model, token):
158
  async with self.db as session:
159
- # 为每个请求创建一条新的记录
160
  new_request_stat = RequestStat(
161
  endpoint=endpoint,
162
  ip=client_ip,
163
  token=token,
164
  total_time=process_time,
165
- model=model
 
 
166
  )
167
  session.add(new_request_stat)
168
  await session.commit()
@@ -179,6 +261,14 @@ class StatsMiddleware(BaseHTTPMiddleware):
179
  session.add(channel_stat)
180
  await session.commit()
181
 
 
 
 
 
 
 
 
 
182
  # 配置 CORS 中间件
183
  app.add_middleware(
184
  CORSMiddleware,
@@ -561,7 +651,7 @@ async def images_generations(
561
  return await model_handler.request_model(request, token, endpoint="/v1/images/generations")
562
 
563
  @app.post("/v1/moderations", dependencies=[Depends(rate_limit_dependency)])
564
- async def images_generations(
565
  request: ModerationRequest,
566
  token: str = Depends(verify_api_key)
567
  ):
@@ -601,9 +691,6 @@ def generate_api_key():
601
  return JSONResponse(content={"api_key": api_key})
602
 
603
  # 在 /stats 路由中返回成功和失败百分比
604
- from collections import defaultdict
605
- from sqlalchemy import func
606
-
607
  from collections import defaultdict
608
  from sqlalchemy import func, desc, case
609
 
 
24
  import os
25
  is_debug = bool(os.getenv("DEBUG", False))
26
 
27
+ from sqlalchemy import inspect, text
28
+ from sqlalchemy.sql import sqltypes
29
+
30
  async def create_tables():
31
  async with engine.begin() as conn:
32
  await conn.run_sync(Base.metadata.create_all)
33
 
34
+ # 检查并添加缺失的列
35
+ def check_and_add_columns(connection):
36
+ inspector = inspect(connection)
37
+ for table in [RequestStat, ChannelStat]:
38
+ table_name = table.__tablename__
39
+ existing_columns = {col['name']: col['type'] for col in inspector.get_columns(table_name)}
40
+
41
+ for column_name, column in table.__table__.columns.items():
42
+ if column_name not in existing_columns:
43
+ col_type = _map_sa_type_to_sql_type(column.type)
44
+ default = _get_default_sql(column.default)
45
+ connection.execute(text(f"ALTER TABLE {table_name} ADD COLUMN {column_name} {col_type}{default}"))
46
+
47
+ await conn.run_sync(check_and_add_columns)
48
+
49
+ def _map_sa_type_to_sql_type(sa_type):
50
+ type_map = {
51
+ sqltypes.Integer: "INTEGER",
52
+ sqltypes.String: "TEXT",
53
+ sqltypes.Float: "REAL",
54
+ sqltypes.Boolean: "BOOLEAN",
55
+ sqltypes.DateTime: "DATETIME",
56
+ sqltypes.Text: "TEXT"
57
+ }
58
+ return type_map.get(type(sa_type), "TEXT")
59
+
60
+ def _get_default_sql(default):
61
+ if default is None:
62
+ return ""
63
+ if isinstance(default.arg, bool):
64
+ return f" DEFAULT {str(default.arg).upper()}"
65
+ if isinstance(default.arg, (int, float)):
66
+ return f" DEFAULT {default.arg}"
67
+ if isinstance(default.arg, str):
68
+ return f" DEFAULT '{default.arg}'"
69
+ return ""
70
+
71
  @asynccontextmanager
72
  async def lifespan(app: FastAPI):
73
  # 启动时的代码
 
119
 
120
  from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
121
  from sqlalchemy.orm import declarative_base, sessionmaker
122
+ from sqlalchemy import Column, Integer, String, Float, DateTime, select, Boolean, Text
123
  from sqlalchemy.sql import func
124
 
125
  # 定义数据库模型
 
133
  token = Column(String)
134
  total_time = Column(Float)
135
  model = Column(String)
136
+ is_flagged = Column(Boolean, default=False)
137
+ moderated_content = Column(Text)
138
  timestamp = Column(DateTime(timezone=True), server_default=func.now())
139
 
140
  class ChannelStat(Base):
 
155
  os.makedirs(data_dir, exist_ok=True)
156
 
157
  # 创建异步引擎和会话
158
+ # engine = create_async_engine('sqlite+aiosqlite:///' + db_path, echo=False)
159
  engine = create_async_engine('sqlite+aiosqlite:///' + db_path, echo=is_debug)
160
  async_session = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
161
 
 
175
  start_time = time()
176
 
177
  request.state.parsed_body = await parse_request_body(request)
178
+ endpoint = f"{request.method} {request.url.path}"
179
+ client_ip = request.client.host
180
 
181
  model = "unknown"
182
+ enable_moderation = False # 默认不开启道德审查
183
+ is_flagged = False
184
+ moderated_content = ""
185
+
186
+ config = app.state.config
187
+ api_list = app.state.api_list
188
+
189
+ # 根据token决定是否启用道德审查
190
+ if token:
191
+ try:
192
+ api_index = api_list.index(token)
193
+ enable_moderation = safe_get(config, 'api_keys', api_index, "preferences", "ENABLE_MODERATION", default=False)
194
+ except ValueError:
195
+ # token不在api_list中,使用默认值(不开启)
196
+ pass
197
+ else:
198
+ # 如果token为None,检查全局设置
199
+ enable_moderation = config.get('ENABLE_MODERATION', False)
200
+
201
  if request.state.parsed_body:
202
  try:
203
  request_model = RequestModel(**request.state.parsed_body)
204
  model = request_model.model
205
+ moderated_content = request_model.get_last_text_message()
206
+
207
+ if enable_moderation and moderated_content:
208
+ moderation_response = await self.moderate_content(moderated_content, token)
209
+ moderation_result = moderation_response.body
210
+ moderation_data = json.loads(moderation_result)
211
+ is_flagged = moderation_data.get('results', [{}])[0].get('flagged', False)
212
+
213
+ if is_flagged:
214
+ logger.error(f"Content did not pass the moral check: %s", moderated_content)
215
+ process_time = time() - start_time
216
+ await self.update_stats(endpoint, process_time, client_ip, model, token, is_flagged, moderated_content)
217
+ return JSONResponse(
218
+ status_code=400,
219
+ content={"error": "Content did not pass the moral check, please modify and try again."}
220
+ )
221
  except RequestValidationError:
222
  pass
223
  except Exception as e:
224
+ if is_debug:
225
+ import traceback
226
+ traceback.print_exc()
227
+
228
+ logger.error(f"处理请求或进行道德检查时出错: {str(e)}")
229
 
230
  response = await call_next(request)
231
  process_time = time() - start_time
232
 
 
 
 
233
  # 异步更新数据库
234
+ await self.update_stats(endpoint, process_time, client_ip, model, token, is_flagged, moderated_content)
235
 
236
  return response
237
 
238
+ async def update_stats(self, endpoint, process_time, client_ip, model, token, is_flagged, moderated_content):
239
  async with self.db as session:
 
240
  new_request_stat = RequestStat(
241
  endpoint=endpoint,
242
  ip=client_ip,
243
  token=token,
244
  total_time=process_time,
245
+ model=model,
246
+ is_flagged=is_flagged,
247
+ moderated_content=moderated_content
248
  )
249
  session.add(new_request_stat)
250
  await session.commit()
 
261
  session.add(channel_stat)
262
  await session.commit()
263
 
264
+ async def moderate_content(self, content, token):
265
+ moderation_request = ModerationRequest(input=content)
266
+
267
+ # 直接调用 moderations 函数
268
+ response = await moderations(moderation_request, token)
269
+
270
+ return response
271
+
272
  # 配置 CORS 中间件
273
  app.add_middleware(
274
  CORSMiddleware,
 
651
  return await model_handler.request_model(request, token, endpoint="/v1/images/generations")
652
 
653
  @app.post("/v1/moderations", dependencies=[Depends(rate_limit_dependency)])
654
+ async def moderations(
655
  request: ModerationRequest,
656
  token: str = Depends(verify_api_key)
657
  ):
 
691
  return JSONResponse(content={"api_key": api_key})
692
 
693
  # 在 /stats 路由中返回成功和失败百分比
 
 
 
694
  from collections import defaultdict
695
  from sqlalchemy import func, desc, case
696
 
models.py CHANGED
@@ -96,4 +96,15 @@ class RequestModel(BaseModel):
96
  n: Optional[int] = 1
97
  user: Optional[str] = None
98
  tool_choice: Optional[Union[str, ToolChoice]] = None
99
- tools: Optional[List[Tool]] = None
 
 
 
 
 
 
 
 
 
 
 
 
96
  n: Optional[int] = 1
97
  user: Optional[str] = None
98
  tool_choice: Optional[Union[str, ToolChoice]] = None
99
+ tools: Optional[List[Tool]] = None
100
+
101
+ def get_last_text_message(self) -> Optional[str]:
102
+ for message in reversed(self.messages):
103
+ if message.content:
104
+ if isinstance(message.content, str):
105
+ return message.content
106
+ elif isinstance(message.content, list):
107
+ for item in reversed(message.content):
108
+ if item.type == "text" and item.text:
109
+ return item.text
110
+ return ""
request.py CHANGED
@@ -8,9 +8,20 @@ import urllib.parse
8
  from models import RequestModel
9
  from utils import c35s, c3s, c3o, c3h, gem, BaseAPI
10
 
 
 
11
  def encode_image(image_path):
12
- with open(image_path, "rb") as image_file:
13
- return base64.b64encode(image_file.read()).decode('utf-8')
 
 
 
 
 
 
 
 
 
14
 
15
  async def get_doc_from_url(url):
16
  filename = urllib.parse.unquote(url.split("/")[-1])
@@ -37,12 +48,28 @@ async def get_encode_image(image_url):
37
  filename = await get_doc_from_url(image_url)
38
  image_path = os.getcwd() + "/" + filename
39
  base64_image = encode_image(image_path)
40
- if filename.endswith(".png"):
41
- prompt = f"data:image/png;base64,{base64_image}"
42
- else:
43
- prompt = f"data:image/jpeg;base64,{base64_image}"
44
  os.remove(image_path)
45
- return prompt
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
  async def get_image_message(base64_image, engine = None):
48
  if base64_image.startswith("http"):
@@ -59,6 +86,8 @@ async def get_image_message(base64_image, engine = None):
59
  }
60
  }
61
  if "claude" == engine or "vertex-claude" == engine:
 
 
62
  return {
63
  "type": "image",
64
  "source": {
 
8
  from models import RequestModel
9
  from utils import c35s, c3s, c3o, c3h, gem, BaseAPI
10
 
11
+ import imghdr
12
+
13
  def encode_image(image_path):
14
+ with open(image_path, "rb") as image_file:
15
+ file_content = image_file.read()
16
+ file_type = imghdr.what(None, file_content)
17
+ base64_encoded = base64.b64encode(file_content).decode('utf-8')
18
+
19
+ if file_type == 'png':
20
+ return f"data:image/png;base64,{base64_encoded}"
21
+ elif file_type in ['jpeg', 'jpg']:
22
+ return f"data:image/jpeg;base64,{base64_encoded}"
23
+ else:
24
+ raise ValueError(f"不支持的图片格式: {file_type}")
25
 
26
  async def get_doc_from_url(url):
27
  filename = urllib.parse.unquote(url.split("/")[-1])
 
48
  filename = await get_doc_from_url(image_url)
49
  image_path = os.getcwd() + "/" + filename
50
  base64_image = encode_image(image_path)
 
 
 
 
51
  os.remove(image_path)
52
+ return base64_image
53
+
54
+ from PIL import Image
55
+ import io
56
+ def validate_image(image_data, image_type):
57
+ try:
58
+ decoded_image = base64.b64decode(image_data)
59
+ image = Image.open(io.BytesIO(decoded_image))
60
+
61
+ # 检查图片格式是否与声明的类型匹配
62
+ # print("image.format", image.format)
63
+ if image_type == "image/png" and image.format != "PNG":
64
+ raise ValueError("Image is not a valid PNG")
65
+ elif image_type == "image/jpeg" and image.format not in ["JPEG", "JPG"]:
66
+ raise ValueError("Image is not a valid JPEG")
67
+
68
+ # 如果没有异常,则图片有效
69
+ return True
70
+ except Exception as e:
71
+ print(f"Image validation failed: {str(e)}")
72
+ return False
73
 
74
  async def get_image_message(base64_image, engine = None):
75
  if base64_image.startswith("http"):
 
86
  }
87
  }
88
  if "claude" == engine or "vertex-claude" == engine:
89
+ # if not validate_image(base64_image.split(",")[1], image_type):
90
+ # raise ValueError(f"Invalid image format. Expected {image_type}")
91
  return {
92
  "type": "image",
93
  "source": {
utils.py CHANGED
@@ -310,10 +310,10 @@ class BaseAPI:
310
  self.audio_transcriptions: str = urlunparse(parsed_url[:2] + (before_v1 + "/v1/audio/transcriptions",) + ("",) * 3)
311
  self.moderations: str = urlunparse(parsed_url[:2] + (before_v1 + "/v1/moderations",) + ("",) * 3)
312
 
313
- def safe_get(data, *keys):
314
  for key in keys:
315
  try:
316
  data = data[key] if isinstance(data, (dict, list)) else data.get(key)
317
  except (KeyError, IndexError, AttributeError, TypeError):
318
- return None
319
  return data
 
310
  self.audio_transcriptions: str = urlunparse(parsed_url[:2] + (before_v1 + "/v1/audio/transcriptions",) + ("",) * 3)
311
  self.moderations: str = urlunparse(parsed_url[:2] + (before_v1 + "/v1/moderations",) + ("",) * 3)
312
 
313
+ def safe_get(data, *keys, default=None):
314
  for key in keys:
315
  try:
316
  data = data[key] if isinstance(data, (dict, list)) else data.get(key)
317
  except (KeyError, IndexError, AttributeError, TypeError):
318
+ return default
319
  return data