sanbo commited on
Commit
8e90b2d
·
1 Parent(s): 1c96ca8

update sth. at 2025-02-03 19:34:12

Browse files
Files changed (3) hide show
  1. app.py +1 -0
  2. app.py——ok_baks +146 -0
  3. requirements.txt +2 -1
app.py CHANGED
@@ -86,6 +86,7 @@ app.add_middleware(
86
  allow_headers=["*"],
87
  )
88
 
 
89
  @app.post("/generate_embeddings", response_model=EmbeddingResponse)
90
  @app.post("/api/v1/embeddings", response_model=EmbeddingResponse)
91
  @app.post("/hf/v1/embeddings", response_model=EmbeddingResponse)
 
86
  allow_headers=["*"],
87
  )
88
 
89
+ @app.post("/v1/embeddings", response_model=EmbeddingResponse)
90
  @app.post("/generate_embeddings", response_model=EmbeddingResponse)
91
  @app.post("/api/v1/embeddings", response_model=EmbeddingResponse)
92
  @app.post("/hf/v1/embeddings", response_model=EmbeddingResponse)
app.py——ok_baks ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import logging
3
+ import torch
4
+ import gradio as gr
5
+ from fastapi import FastAPI, HTTPException
6
+ from fastapi.middleware.cors import CORSMiddleware
7
+ from pydantic import BaseModel
8
+ from typing import List, Dict
9
+ from functools import lru_cache
10
+ import numpy as np
11
+ from threading import Lock
12
+ import uvicorn
13
+
14
+ class EmbeddingRequest(BaseModel):
15
+ input: str
16
+ model: str = "jinaai/jina-embeddings-v3"
17
+
18
+ class EmbeddingResponse(BaseModel):
19
+ status: str
20
+ embeddings: List[List[float]]
21
+
22
+ class EmbeddingService:
23
+ def __init__(self):
24
+ self.model_name = "jinaai/jina-embeddings-v3"
25
+ self.max_length = 512
26
+ self.device = torch.device("cpu")
27
+ self.model = None
28
+ self.tokenizer = None
29
+ self.lock = Lock()
30
+ self.setup_logging()
31
+ torch.set_num_threads(4) # CPU优化
32
+
33
+ def setup_logging(self):
34
+ logging.basicConfig(
35
+ level=logging.INFO,
36
+ format='%(asctime)s - %(levelname)s - %(message)s'
37
+ )
38
+ self.logger = logging.getLogger(__name__)
39
+
40
+ async def initialize(self):
41
+ try:
42
+ from transformers import AutoTokenizer, AutoModel
43
+ self.tokenizer = AutoTokenizer.from_pretrained(
44
+ self.model_name,
45
+ trust_remote_code=True
46
+ )
47
+ self.model = AutoModel.from_pretrained(
48
+ self.model_name,
49
+ trust_remote_code=True
50
+ ).to(self.device)
51
+ self.model.eval()
52
+ torch.set_grad_enabled(False)
53
+ self.logger.info(f"模型加载成功,使用设备: {self.device}")
54
+ except Exception as e:
55
+ self.logger.error(f"模型初始化失败: {str(e)}")
56
+ raise
57
+
58
+ @lru_cache(maxsize=1000)
59
+ def get_embedding(self, text: str) -> List[float]:
60
+ """同步生成嵌入向量,带缓存"""
61
+ with self.lock:
62
+ try:
63
+ inputs = self.tokenizer(
64
+ text,
65
+ return_tensors="pt",
66
+ truncation=True,
67
+ max_length=self.max_length,
68
+ padding=True
69
+ )
70
+
71
+ with torch.no_grad():
72
+ outputs = self.model(**inputs).last_hidden_state.mean(dim=1)
73
+ return outputs.numpy().tolist()[0]
74
+ except Exception as e:
75
+ self.logger.error(f"生成嵌入向量失败: {str(e)}")
76
+ raise
77
+
78
+ embedding_service = EmbeddingService()
79
+ app = FastAPI()
80
+
81
+ app.add_middleware(
82
+ CORSMiddleware,
83
+ allow_origins=["*"],
84
+ allow_credentials=True,
85
+ allow_methods=["*"],
86
+ allow_headers=["*"],
87
+ )
88
+
89
+ @app.post("/generate_embeddings", response_model=EmbeddingResponse)
90
+ @app.post("/api/v1/embeddings", response_model=EmbeddingResponse)
91
+ @app.post("/hf/v1/embeddings", response_model=EmbeddingResponse)
92
+ @app.post("/api/v1/chat/completions", response_model=EmbeddingResponse)
93
+ @app.post("/hf/v1/chat/completions", response_model=EmbeddingResponse)
94
+ async def generate_embeddings(request: EmbeddingRequest):
95
+ try:
96
+ # 使用run_in_executor避免事件循环问题
97
+ embedding = await asyncio.get_running_loop().run_in_executor(
98
+ None,
99
+ embedding_service.get_embedding,
100
+ request.input
101
+ )
102
+ return EmbeddingResponse(
103
+ status="success",
104
+ embeddings=[embedding]
105
+ )
106
+ except Exception as e:
107
+ raise HTTPException(status_code=500, detail=str(e))
108
+
109
+ @app.get("/")
110
+ async def root():
111
+ return {
112
+ "status": "active",
113
+ "model": embedding_service.model_name,
114
+ "device": str(embedding_service.device)
115
+ }
116
+
117
+ def gradio_interface(text: str) -> Dict:
118
+ try:
119
+ embedding = embedding_service.get_embedding(text)
120
+ return {
121
+ "status": "success",
122
+ "embeddings": [embedding]
123
+ }
124
+ except Exception as e:
125
+ return {
126
+ "status": "error",
127
+ "message": str(e)
128
+ }
129
+
130
+ iface = gr.Interface(
131
+ fn=gradio_interface,
132
+ inputs=gr.Textbox(lines=3, label="输入文本"),
133
+ outputs=gr.JSON(label="嵌入向量结果"),
134
+ title="Jina Embeddings V3",
135
+ description="使用jina-embeddings-v3模型生成文本嵌入向量",
136
+ examples=[["这是一个测试句子。"]]
137
+ )
138
+
139
+ @app.on_event("startup")
140
+ async def startup_event():
141
+ await embedding_service.initialize()
142
+
143
+ if __name__ == "__main__":
144
+ asyncio.run(embedding_service.initialize())
145
+ gr.mount_gradio_app(app, iface, path="/ui")
146
+ uvicorn.run(app, host="0.0.0.0", port=7860, workers=1)
requirements.txt CHANGED
@@ -9,4 +9,5 @@ numpy
9
  python-multipart
10
  sentencepiece
11
  safetensors
12
-
 
 
9
  python-multipart
10
  sentencepiece
11
  safetensors
12
+ pydantic
13
+ click