Moonfanz commited on
Commit
7e8a8d2
·
verified ·
1 Parent(s): 3ef6335

Upload 4 files

Browse files
Files changed (4) hide show
  1. Dockerfile +10 -0
  2. app.py +221 -0
  3. func.py +92 -0
  4. requirements.txt +7 -0
Dockerfile ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
# Minimal runtime image for the Gemini proxy service.
FROM python:3.9-slim

WORKDIR /app

# Install dependencies in their own layer so it is cached across code changes.
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

COPY . .

# app.py reads $PORT (default 7860) for its listen port.
CMD ["python", "app.py"]
app.py ADDED
@@ -0,0 +1,221 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from flask import Flask, request, jsonify, Response, stream_with_context
3
+ import google.generativeai as genai
4
+ import json
5
+ from datetime import datetime
6
+ import os
7
+ import logging
8
+ import func
9
+
10
# Use China Standard Time for log/response timestamps.
os.environ['TZ'] = 'Asia/Shanghai'
app = Flask(__name__)

# Random per-process secret; any session state does not survive restarts.
app.secret_key = os.urandom(24)

# Shared secret clients must present as a Bearer token.
# NOTE: os.environ['password'] raises KeyError at import time if unset
# (fail-fast by design — the service is useless without it).
PASSWORD = os.environ['password']

# Bare-message log format (no level/name prefixes).
formatter = logging.Formatter('%(message)s')

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

handler = logging.StreamHandler()
handler.setFormatter(formatter)

logger.addHandler(handler)
26
+
27
# Disable Gemini's safety filtering for every harm category the API exposes.
_HARM_CATEGORIES = (
    "HARM_CATEGORY_HARASSMENT",
    "HARM_CATEGORY_HATE_SPEECH",
    "HARM_CATEGORY_SEXUALLY_EXPLICIT",
    "HARM_CATEGORY_DANGEROUS_CONTENT",
)
safety_settings = [
    {"category": category, "threshold": "BLOCK_NONE"}
    for category in _HARM_CATEGORIES
]
45
+
46
class APIKeyManager:
    """Round-robin rotation over the comma-separated GEMINI_API_KEY env var."""

    def __init__(self):
        raw = os.environ.get('GEMINI_API_KEY')
        if not raw:
            # Fail fast with a clear message instead of the original opaque
            # AttributeError raised by None.split(',') when the var is unset.
            raise RuntimeError('GEMINI_API_KEY environment variable is not set')
        # Strip stray whitespace around keys and drop empty entries
        # (e.g. produced by a trailing comma).
        self.api_keys = [key.strip() for key in raw.split(',') if key.strip()]
        # Index of the next key to hand out.
        self.current_index = 0

    def get_available_key(self):
        """Return the next API key in round-robin order, wrapping at the end."""
        if self.current_index >= len(self.api_keys):
            self.current_index = 0
        current_key = self.api_keys[self.current_index]
        self.current_index += 1
        return current_key
57
+
58
key_manager = APIKeyManager()
current_api_key = key_manager.get_available_key()
# SECURITY FIX: log only a short prefix of the credential, matching the
# "[:11]..." convention used everywhere else — never write full API keys
# to logs.
logger.info(f"Current API key: {current_api_key[:11]}...")
61
+
62
# Model ids advertised through the /hf/v1/models endpoint, in OpenAI
# model-object form ({"id": ...}).
_MODEL_IDS = (
    "gemini-pro",
    "gemini-pro-vision",
    "gemini-1.0-pro",
    "gemini-1.0-pro-vision",
    "gemini-1.5-pro-002",
    "gemini-exp-1114",
    "gemini-exp-1121",
    "gemini-exp-1206",
    "gemini-2.0-flash-exp",
    "gemini-2.0-exp",
    "gemini-2.0-pro-exp",
)
GEMINI_MODELS = [{"id": model_id} for model_id in _MODEL_IDS]
75
+
76
@app.route('/hf/v1/chat/completions', methods=['POST'])
def chat_completions():
    """OpenAI-compatible chat endpoint backed by Google Gemini.

    Accepts an OpenAI-style JSON payload (messages / model / temperature /
    max_tokens / stream), proxies it to the Gemini API with safety filters
    disabled, and rotates to the next API key after every request.

    Returns either an SSE stream of chat.completion.chunk events or a single
    chat.completion JSON object; errors come back as {'error': ...} with 500.
    """
    global current_api_key
    is_authenticated, auth_error, status_code = func.authenticate_request(PASSWORD, request)
    if not is_authenticated:
        return auth_error if auth_error else jsonify({'error': 'Unauthorized'}), status_code if status_code else 401
    try:
        request_data = request.get_json()
        messages = request_data.get('messages', [])
        model = request_data.get('model', 'gemini-exp-1206')
        temperature = request_data.get('temperature', 1)
        max_tokens = request_data.get('max_tokens', 8192)
        stream = request_data.get('stream', False)

        # Only a short key prefix goes to the log.
        logger.info(f"\n{model} [r] -> {current_api_key[:11]}...")

        gemini_history, user_message, error_response = func.process_messages_for_gemini(messages)

        # BUG FIX: the error tuple was previously only print()ed, letting the
        # request continue with an invalid user_message; return it instead.
        if error_response:
            return error_response

        genai.configure(api_key=current_api_key)

        generation_config = {
            "temperature": temperature,
            "max_output_tokens": max_tokens
        }

        gen_model = genai.GenerativeModel(
            model_name=model,
            generation_config=generation_config,
            safety_settings=safety_settings
        )

        if stream:
            # A chat session is only needed when there is history to replay.
            if gemini_history:
                chat_session = gen_model.start_chat(history=gemini_history)
                response = chat_session.send_message(user_message, stream=True)
            else:
                response = gen_model.generate_content(user_message, stream=True)

            def generate():
                """Yield OpenAI-style chat.completion.chunk SSE events."""
                try:
                    for chunk in response:
                        if chunk.text:
                            data = {
                                'choices': [
                                    {
                                        'delta': {
                                            'content': chunk.text
                                        },
                                        'finish_reason': None,
                                        'index': 0
                                    }
                                ],
                                'object': 'chat.completion.chunk'
                            }

                            yield f"data: {json.dumps(data)}\n\n"
                    # Terminal chunk signalling normal completion.
                    data = {
                        'choices': [
                            {
                                'delta': {},
                                'finish_reason': 'stop',
                                'index': 0
                            }
                        ],
                        'object': 'chat.completion.chunk'
                    }

                    yield f"data: {json.dumps(data)}\n\n"
                    # OpenAI-compatible stream terminator; many clients wait
                    # for this sentinel before closing the connection.
                    yield "data: [DONE]\n\n"
                except Exception as e:
                    logger.error(f"Error during streaming: {str(e)}")

                    data = {
                        'error': {
                            'message': str(e),
                            'type': 'internal_server_error'
                        }
                    }
                    yield f"data: {json.dumps(data)}\n\n"

            return Response(stream_with_context(generate()), mimetype='text/event-stream')
        else:
            if gemini_history:
                chat_session = gen_model.start_chat(history=gemini_history)
                response = chat_session.send_message(user_message)
            else:
                response = gen_model.generate_content(user_message)

            try:
                text_content = response.candidates[0].content.parts[0].text
            except (AttributeError, IndexError, TypeError) as e:
                # Gemini may return no candidates (e.g. blocked prompt);
                # degrade to a placeholder rather than crashing.
                logger.error(f"Error getting text content: {str(e)}")
                text_content = "Error: Unable to get text content."

            response_data = {
                'id': 'chatcmpl-xxxxxxxxxxxx',
                'object': 'chat.completion',
                'created': int(datetime.now().timestamp()),
                'model': model,
                'choices': [{
                    'index': 0,
                    'message': {
                        'role': 'assistant',
                        'content': text_content
                    },
                    'finish_reason': 'stop'
                }],
                # Token accounting is not implemented; zeros keep the
                # OpenAI response shape valid.
                'usage': {
                    'prompt_tokens': 0,
                    'completion_tokens': 0,
                    'total_tokens': 0
                }
            }
            logger.info(f"Generation Success")
            return jsonify(response_data)

    except Exception as e:
        logger.error(f"Error in chat completions: {str(e)}")

        return jsonify({
            'error': {
                'message': str(e),
                'type': 'invalid_request_error'
            }
        }), 500
    finally:
        # Rotate to the next key regardless of success or failure.
        current_api_key = key_manager.get_available_key()
        logger.info(f"API KEY Switched -> {current_api_key[:11]}...")
211
+
212
@app.route('/hf/v1/models', methods=['GET'])
def list_models():
    """Return the advertised Gemini models in OpenAI list format."""
    ok, auth_error, status_code = func.authenticate_request(PASSWORD, request)
    if not ok:
        # Prefer the specific error body/status produced by the authenticator.
        return (auth_error or jsonify({'error': 'Unauthorized'})), (status_code or 401)
    return jsonify({"object": "list", "data": GEMINI_MODELS})
219
+
220
if __name__ == '__main__':
    # Hugging Face Spaces convention: listen port comes from $PORT (7860 default).
    listen_port = int(os.environ.get('PORT', 7860))
    app.run(debug=True, host='0.0.0.0', port=listen_port)
func.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from io import BytesIO
2
+ import base64
3
+ from PIL import Image
4
+ from flask import jsonify
5
+ import logging
6
+ import json
7
+ import re
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
def authenticate_request(password, request):
    """Validate the Bearer token on an incoming request.

    Args:
        password: The expected shared secret.
        request: Flask request object whose Authorization header is checked.

    Returns:
        (True, None, None) on success, otherwise
        (False, flask_json_error_response, http_status_code).
    """
    auth_header = request.headers.get('Authorization')

    if not auth_header:
        return False, jsonify({'error': 'Authorization header is missing'}), 401

    try:
        auth_type, pass_word = auth_header.split(' ', 1)
    except ValueError:
        return False, jsonify({'error': 'Invalid Authorization header format'}), 401

    if auth_type.lower() != 'bearer':
        return False, jsonify({'error': 'Authorization type must be Bearer'}), 401

    # SECURITY FIX: constant-time comparison instead of `!=`, so the secret
    # cannot be probed via response-timing differences.
    if not hmac.compare_digest(str(pass_word), str(password)):
        return False, jsonify({'error': 'Unauthorized'}), 401

    return True, None, None
29
+
30
def process_messages_for_gemini(messages):
    """Convert OpenAI-style chat messages into Gemini history format.

    Args:
        messages: List of {'role', 'content'} dicts. `content` is either a
            plain string or a list of {'type': 'text'|'image_url', ...} parts
            (images as data: URLs).

    Returns:
        (gemini_history, user_message, error_response): everything but the
        final turn, the final turn to send, and None on success or a
        (flask_json_response, status) tuple on failure.
    """
    gemini_history = []
    for message in messages:
        role = message.get('role')
        content = message.get('content')

        if isinstance(content, str):  # plain text
            # Gemini has no 'system' role; system prompts are sent as user turns.
            if role in ('system', 'user'):
                gemini_history.append({"role": "user", "parts": [content]})
            elif role == 'assistant':
                gemini_history.append({"role": "model", "parts": [content]})
        elif isinstance(content, list):  # mixed text and images
            parts = []
            for item in content:
                if item.get('type') == 'text':
                    parts.append(item.get('text'))
                elif item.get('type') == 'image_url':
                    image_data = item.get('image_url', {}).get('url', '')
                    if image_data.startswith('data:image/'):
                        try:
                            base64_image = image_data.split(';base64,')[1]
                            image = Image.open(BytesIO(base64.b64decode(base64_image)))

                            # Normalize to RGB so palette/alpha modes don't
                            # trip downstream encoding.
                            if image.mode != 'RGB':
                                image = image.convert('RGB')

                            # Downscale oversized images in place.
                            if image.width > 2048 or image.height > 2048:
                                image.thumbnail((2048, 2048))

                            # BUG FIX: the old code re-encoded the image into a
                            # throwaway BytesIO using the uppercased media
                            # subtype as the PIL format — 'data:image/jpg'
                            # became format='JPG', which PIL rejects, turning
                            # valid images into spurious 400s. The buffer was
                            # never used (the PIL Image itself is appended),
                            # so the re-encode is simply dropped.
                            parts.append(image)
                        except Exception as e:
                            logger.error(f"Error processing image: {e}")
                            return [], None, (jsonify({'error': 'Invalid image data'}), 400)
                    else:
                        return [], None, (jsonify({'error': 'Invalid image URL format'}), 400)

            if role in ['user', 'system']:
                gemini_history.append({"role": "user", "parts": parts})
            elif role == 'assistant':
                gemini_history.append({"role": "model", "parts": parts})
            else:
                return [], None, (jsonify({'error': f'Invalid role: {role}'}), 400)

    # The final turn is sent separately; everything before it is history.
    if gemini_history:
        user_message = gemini_history[-1]
        gemini_history = gemini_history[:-1]
    else:
        # No usable messages: fall back to an empty user turn.
        user_message = {"role": "user", "parts": [""]}

    return gemini_history, user_message, None
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
# Runtime dependencies for the Gemini proxy (app.py / func.py).
Flask==2.0.3
Flask-CORS==3.0.10
requests==2.26.0
# Pinned to stay compatible with Flask 2.0.x.
Werkzeug==2.0.3
google==3.0.0
google-generativeai==0.8.3
# Image decoding/resizing for image_url chat parts (func.py).
pillow==10.4.0