sanbo110 commited on
Commit
087fb1b
·
verified ·
1 Parent(s): d608d26

Create main.go

Browse files
Files changed (1) hide show
  1. main.go +496 -0
main.go ADDED
@@ -0,0 +1,496 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package main
2
+
3
+ import (
4
+ "bufio"
5
+ "embed"
6
+ "encoding/json"
7
+ "errors"
8
+ "fmt"
9
+ "io"
10
+ "io/fs"
11
+ "log"
12
+ "net/http"
13
+ "os"
14
+ "strings"
15
+ "time"
16
+
17
+ "github.com/gin-gonic/gin"
18
+ "github.com/joho/godotenv"
19
+ )
20
+
21
+ //go:embed web/*
22
+ var staticFiles embed.FS
23
+
24
+ type Config struct {
25
+ APIPrefix string
26
+ APIKey string
27
+ MaxRetryCount int
28
+ RetryDelay time.Duration
29
+ FakeHeaders map[string]string
30
+ }
31
+
32
+ var config Config
33
+
34
+ func init() {
35
+ godotenv.Load()
36
+ config = Config{
37
+ APIKey: getEnv("API_KEY", ""),
38
+ MaxRetryCount: getIntEnv("MAX_RETRY_COUNT", 3),
39
+ RetryDelay: getDurationEnv("RETRY_DELAY", 5000),
40
+ FakeHeaders: map[string]string{
41
+ "Accept": "*/*",
42
+ "Accept-Encoding": "gzip, deflate, br, zstd",
43
+ "Accept-Language": "zh-CN,zh;q=0.9",
44
+ "Origin": "https://duckduckgo.com/",
45
+ "Cookie": "l=wt-wt; ah=wt-wt; dcm=6",
46
+ "Dnt": "1",
47
+ "Priority": "u=1, i",
48
+ "Referer": "https://duckduckgo.com/",
49
+ "Sec-Ch-Ua": `"Microsoft Edge";v="129", "Not(A:Brand";v="8", "Chromium";v="129"`,
50
+ "Sec-Ch-Ua-Mobile": "?0",
51
+ "Sec-Ch-Ua-Platform": `"Windows"`,
52
+ "Sec-Fetch-Dest": "empty",
53
+ "Sec-Fetch-Mode": "cors",
54
+ "Sec-Fetch-Site": "same-origin",
55
+ "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/127.0.0.0 Safari/537.36",
56
+ },
57
+ }
58
+ }
59
+
60
+ func authMiddleware() gin.HandlerFunc {
61
+ return func(c *gin.Context) {
62
+ apiKey := c.GetHeader("Authorization")
63
+ if apiKey == "" {
64
+ apiKey = c.Query("api_key")
65
+ }
66
+
67
+ // 当未提供或者 API Key 不合法时,允许匿名访问
68
+ if apiKey == "" || !strings.HasPrefix(apiKey, "Bearer ") {
69
+ c.Next() // 未提供 API 密钥时允许继续请求
70
+ return
71
+ }
72
+
73
+ apiKey = strings.TrimPrefix(apiKey, "Bearer ")
74
+
75
+ if apiKey != config.APIKey {
76
+ c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid API key"})
77
+ c.Abort()
78
+ return
79
+ }
80
+
81
+ c.Next()
82
+ }
83
+ }
84
+
85
+ func main() {
86
+ r := gin.Default()
87
+ r.Use(corsMiddleware())
88
+ // 1. 映射 Web 目录到 /web 路由
89
+ //r.Static("/web", "./web") // 确保 ./web 文件夹存在,并包含 index.html
90
+ subFS, err := fs.Sub(staticFiles, "web")
91
+ if err != nil {
92
+ log.Fatal(err)
93
+ }
94
+ r.StaticFS("/web", http.FS(subFS))
95
+
96
+ // 2. 根路径重定向到 /web
97
+ r.GET("/", func(c *gin.Context) {
98
+ c.Redirect(http.StatusMovedPermanently, "/web")
99
+ })
100
+ // r.GET("/", func(c *gin.Context) {
101
+ // c.JSON(http.StatusOK, gin.H{"message": "API 服务运行中~"})
102
+ // })
103
+ // 3. 健康检查
104
+ r.GET("/ping", func(c *gin.Context) {
105
+ c.JSON(http.StatusOK, gin.H{"message": "pong"})
106
+ })
107
+ // 4. API 路由组
108
+ // authorized := r.Group("/")
109
+ // authorized.Use(authMiddleware())
110
+ // {
111
+ // authorized.GET("/hf/v1/models", handleModels)
112
+ // authorized.POST("/hf/v1/chat/completions", handleCompletion)
113
+ // }
114
+ apiGroup := r.Group("/")
115
+ apiGroup.Use(authMiddleware()) // 可以选择性地提供 API 密钥
116
+ {
117
+ // 原始路径 /hf/v1/*
118
+ apiGroup.GET("/hf/v1/models", handleModels)
119
+ apiGroup.POST("/hf/v1/chat/completions", handleCompletion)
120
+
121
+ // 新路径 /api/v1/*
122
+ apiGroup.GET("/api/v1/models", handleModels)
123
+ apiGroup.POST("/api/v1/chat/completions", handleCompletion)
124
+
125
+ // 新路径 /v1/*
126
+ apiGroup.GET("/v1/models", handleModels)
127
+ apiGroup.POST("/v1/chat/completions", handleCompletion)
128
+
129
+ // 新路径 /completions
130
+ apiGroup.POST("/completions", handleCompletion)
131
+ }
132
+ // 5. 从环境变量中读取端口号
133
+ port := os.Getenv("PORT")
134
+ if port == "" {
135
+ port = "7860"
136
+ }
137
+ r.Run(":" + port)
138
+ }
139
+
140
+ func handleModels(c *gin.Context) {
141
+ models := []gin.H{
142
+ {"id": "gpt-4o-mini", "object": "model", "owned_by": "ddg"},
143
+ {"id": "claude-3-haiku", "object": "model", "owned_by": "ddg"},
144
+ {"id": "llama-3.1-70b", "object": "model", "owned_by": "ddg"},
145
+ {"id": "mixtral-8x7b", "object": "model", "owned_by": "ddg"},
146
+ }
147
+ c.JSON(http.StatusOK, gin.H{"object": "list", "data": models})
148
+ }
149
+
150
+ func handleCompletion(c *gin.Context) {
151
+ var req struct {
152
+ Model string `json:"model"`
153
+ Messages []struct {
154
+ Role string `json:"role"`
155
+ Content interface{} `json:"content"`
156
+ } `json:"messages"`
157
+ Stream bool `json:"stream"`
158
+ }
159
+
160
+ if err := c.ShouldBindJSON(&req); err != nil {
161
+ c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
162
+ return
163
+ }
164
+
165
+ model := convertModel(req.Model)
166
+ content := prepareMessages(req.Messages)
167
+ // log.Printf("messages: %v", content)
168
+
169
+ reqBody := map[string]interface{}{
170
+ "model": model,
171
+ "messages": []map[string]interface{}{
172
+ {
173
+ "role": "user",
174
+ "content": content,
175
+ },
176
+ },
177
+ }
178
+
179
+ body, err := json.Marshal(reqBody)
180
+ if err != nil {
181
+ c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("请求体序列化失败: %v", err)})
182
+ return
183
+ }
184
+
185
+ token, err := requestToken()
186
+ if err != nil {
187
+ c.JSON(http.StatusInternalServerError, gin.H{"error": "无法获取token"})
188
+ return
189
+ }
190
+
191
+ upstreamReq, err := http.NewRequest("POST", "https://duckduckgo.com/duckchat/v1/chat", strings.NewReader(string(body)))
192
+ if err != nil {
193
+ c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("创建请求失败: %v", err)})
194
+ return
195
+ }
196
+
197
+ for k, v := range config.FakeHeaders {
198
+ upstreamReq.Header.Set(k, v)
199
+ }
200
+ upstreamReq.Header.Set("x-vqd-4", token)
201
+ upstreamReq.Header.Set("Content-Type", "application/json")
202
+
203
+ client := &http.Client{
204
+ Timeout: 30 * time.Second,
205
+ }
206
+
207
+ resp, err := client.Do(upstreamReq)
208
+ if err != nil {
209
+ c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("请求失败: %v", err)})
210
+ return
211
+ }
212
+ defer resp.Body.Close()
213
+
214
+ if req.Stream {
215
+ // 启用 SSE 流式响应
216
+ c.Writer.Header().Set("Content-Type", "text/event-stream")
217
+ c.Writer.Header().Set("Cache-Control", "no-cache")
218
+ c.Writer.Header().Set("Connection", "keep-alive")
219
+
220
+ flusher, ok := c.Writer.(http.Flusher)
221
+ if !ok {
222
+ c.JSON(http.StatusInternalServerError, gin.H{"error": "Streaming not supported"})
223
+ return
224
+ }
225
+
226
+ reader := bufio.NewReader(resp.Body)
227
+ for {
228
+ line, err := reader.ReadString('\n')
229
+ if err != nil {
230
+ if err != io.EOF {
231
+ log.Printf("读取流式响应失败: %v", err)
232
+ }
233
+ break
234
+ }
235
+ if strings.HasPrefix(line, "data: ") {
236
+ // 解析响应中的 JSON 数据块
237
+ line = strings.TrimPrefix(line, "data: ")
238
+ line = strings.TrimSpace(line)
239
+ // 忽略非 JSON 数据块(例如特殊标记 [DONE])
240
+ if line == "[DONE]" {
241
+ //log.Printf("响应行 DONE, 即将跳过")
242
+ break
243
+ }
244
+ var chunk map[string]interface{}
245
+ if err := json.Unmarshal([]byte(line), &chunk); err != nil {
246
+ log.Printf("解析响应行失败: %v", err)
247
+ continue
248
+ }
249
+
250
+ // 检查 chunk 是否包含 message
251
+ if msg, exists := chunk["message"]; exists && msg != nil {
252
+ if msgStr, ok := msg.(string); ok {
253
+ response := map[string]interface{}{
254
+ "id": "chatcmpl-QXlha2FBbmROaXhpZUFyZUF3ZXNvbWUK",
255
+ "object": "chat.completion.chunk",
256
+ "created": time.Now().Unix(),
257
+ "model": model,
258
+ "choices": []map[string]interface{}{
259
+ {
260
+ "index": 0,
261
+ "delta": map[string]string{
262
+ "content": msgStr,
263
+ },
264
+ "finish_reason": nil,
265
+ },
266
+ },
267
+ }
268
+ // 将响应格式化为 SSE 数据块
269
+ sseData, _ := json.Marshal(response)
270
+ sseMessage := fmt.Sprintf("data: %s\n\n", sseData)
271
+
272
+ // 发送数据并刷新缓冲区
273
+ _, writeErr := c.Writer.Write([]byte(sseMessage))
274
+ if writeErr != nil {
275
+ log.Printf("写入响应失败: %v", writeErr)
276
+ break
277
+ }
278
+ flusher.Flush()
279
+ } else {
280
+ log.Printf("chunk[message] 不是字符串: %v", msg)
281
+ }
282
+ } else {
283
+ // 解析行中有空行
284
+ log.Println("chunk 中未包含 message 或 message 为 nil")
285
+ }
286
+ }
287
+ }
288
+ } else {
289
+ // 非流式响应,返回完整的 JSON
290
+ var fullResponse strings.Builder
291
+ reader := bufio.NewReader(resp.Body)
292
+
293
+ for {
294
+ line, err := reader.ReadString('\n')
295
+ if err == io.EOF {
296
+ break
297
+ } else if err != nil {
298
+ log.Printf("读取响应失败: %v", err)
299
+ break
300
+ }
301
+
302
+ if strings.HasPrefix(line, "data: ") {
303
+ line = strings.TrimPrefix(line, "data: ")
304
+ line = strings.TrimSpace(line)
305
+
306
+ if line == "[DONE]" {
307
+ break
308
+ }
309
+
310
+ var chunk map[string]interface{}
311
+ if err := json.Unmarshal([]byte(line), &chunk); err != nil {
312
+ log.Printf("解析响应行失败: %v", err)
313
+ continue
314
+ }
315
+
316
+ if message, exists := chunk["message"]; exists {
317
+ if msgStr, ok := message.(string); ok {
318
+ fullResponse.WriteString(msgStr)
319
+ }
320
+ }
321
+ }
322
+ }
323
+
324
+ // 返回完整 JSON 响应
325
+ response := map[string]interface{}{
326
+ "id": "chatcmpl-QXlha2FBbmROaXhpZUFyZUF3ZXNvbWUK",
327
+ "object": "chat.completion",
328
+ "created": time.Now().Unix(),
329
+ "model": model,
330
+ "usage": map[string]int{
331
+ "prompt_tokens": 0,
332
+ "completion_tokens": 0,
333
+ "total_tokens": 0,
334
+ },
335
+ "choices": []map[string]interface{}{
336
+ {
337
+ "message": map[string]string{
338
+ "role": "assistant",
339
+ "content": fullResponse.String(),
340
+ },
341
+ "index": 0,
342
+ },
343
+ },
344
+ }
345
+
346
+ c.JSON(http.StatusOK, response)
347
+ }
348
+ }
349
+
350
+
351
+ func requestToken() (string, error) {
352
+ url := "https://duckduckgo.com/duckchat/v1/status"
353
+ client := &http.Client{
354
+ Timeout: 15 * time.Second, // 设置超时时间
355
+ }
356
+
357
+ maxRetries := config.MaxRetryCount
358
+ retryDelay := config.RetryDelay
359
+
360
+ for attempt := 0; attempt < maxRetries; attempt++ {
361
+ if attempt > 0 {
362
+ log.Printf("requestToken: 第 %d 次重试,等待 %v...", attempt, retryDelay)
363
+ time.Sleep(retryDelay)
364
+ }
365
+ log.Printf("requestToken: 发送 GET 请求到 %s", url)
366
+
367
+ // 创建请求
368
+ req, err := http.NewRequest("GET", url, nil)
369
+ if err != nil {
370
+ log.Printf("requestToken: 创建请求失败: %v", err)
371
+ return "", fmt.Errorf("无法创建请求: %w", err)
372
+ }
373
+
374
+ // 添加假头部
375
+ for k, v := range config.FakeHeaders {
376
+ req.Header.Set(k, v)
377
+ }
378
+ req.Header.Set("x-vqd-accept", "1")
379
+
380
+ // 发送请求
381
+ resp, err := client.Do(req)
382
+ if err != nil {
383
+ log.Printf("requestToken: 请求失败: %v", err)
384
+ continue // 网络通信失败,进行重试
385
+ }
386
+ defer resp.Body.Close()
387
+
388
+ // 检查状态码是否为 200
389
+ if resp.StatusCode != http.StatusOK {
390
+ bodyBytes, _ := io.ReadAll(resp.Body) // 读取响应体,错误时也需要记录响应内容
391
+ bodyString := string(bodyBytes)
392
+ log.Printf("requestToken: 非200响应,状态码=%d, 响应内容: %s", resp.StatusCode, bodyString)
393
+ continue
394
+ }
395
+
396
+ // 尝试从头部提取 token
397
+ token := resp.Header.Get("x-vqd-4")
398
+ if token == "" {
399
+ log.Println("requestToken: 响应中未包含 x-vqd-4 头部")
400
+ bodyBytes, _ := io.ReadAll(resp.Body)
401
+ bodyString := string(bodyBytes)
402
+ log.Printf("requestToken: 响应内容: %s", bodyString)
403
+ continue
404
+ }
405
+
406
+ // 成功获取到 token
407
+ log.Printf("requestToken: 成功获取到 token: %s", token)
408
+ return token, nil
409
+ }
410
+
411
+ // 如果所有重试均失败,返回错误
412
+ return "", errors.New("requestToken: 无法获取到 token,多次重试仍失败")
413
+ }
414
+
415
+ func prepareMessages(messages []struct {
416
+ Role string `json:"role"`
417
+ Content interface{} `json:"content"`
418
+ }) string {
419
+ var contentBuilder strings.Builder
420
+
421
+ for _, msg := range messages {
422
+ // Determine the role - 'system' becomes 'user'
423
+ role := msg.Role
424
+ if role == "system" {
425
+ role = "user"
426
+ }
427
+
428
+ // Process the content as string
429
+ contentStr := ""
430
+ switch v := msg.Content.(type) {
431
+ case string:
432
+ contentStr = v
433
+ case []interface{}:
434
+ for _, item := range v {
435
+ if itemMap, ok := item.(map[string]interface{}); ok {
436
+ if text, exists := itemMap["text"].(string); exists {
437
+ contentStr += text
438
+ }
439
+ }
440
+ }
441
+ default:
442
+ contentStr = fmt.Sprintf("%v", msg.Content)
443
+ }
444
+
445
+ // Append the role and content to the builder
446
+ contentBuilder.WriteString(fmt.Sprintf("%s:%s;\r\n", role, contentStr))
447
+ }
448
+
449
+ return contentBuilder.String()
450
+ }
451
+
452
+ func convertModel(inputModel string) string {
453
+ switch strings.ToLower(inputModel) {
454
+ case "claude-3-haiku":
455
+ return "claude-3-haiku-20240307"
456
+ case "llama-3.1-70b":
457
+ return "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo"
458
+ case "mixtral-8x7b":
459
+ return "mistralai/Mixtral-8x7B-Instruct-v0.1"
460
+ default:
461
+ return "gpt-4o-mini"
462
+ }
463
+ }
464
+
465
+ func corsMiddleware() gin.HandlerFunc {
466
+ return func(c *gin.Context) {
467
+ c.Writer.Header().Set("Access-Control-Allow-Origin", "*")
468
+ c.Writer.Header().Set("Access-Control-Allow-Methods", "*")
469
+ c.Writer.Header().Set("Access-Control-Allow-Headers", "*")
470
+ if c.Request.Method == http.MethodOptions {
471
+ c.AbortWithStatus(http.StatusNoContent)
472
+ return
473
+ }
474
+ c.Next()
475
+ }
476
+ }
477
+
478
+ func getEnv(key, fallback string) string {
479
+ if value, exists := os.LookupEnv(key); exists {
480
+ return value
481
+ }
482
+ return fallback
483
+ }
484
+
485
+ func getIntEnv(key string, fallback int) int {
486
+ if value, exists := os.LookupEnv(key); exists {
487
+ var intValue int
488
+ fmt.Sscanf(value, "%d", &intValue)
489
+ return intValue
490
+ }
491
+ return fallback
492
+ }
493
+
494
+ func getDurationEnv(key string, fallback int) time.Duration {
495
+ return time.Duration(getIntEnv(key, fallback)) * time.Millisecond
496
+ }