|
import asyncio
import logging
import os
import secrets
from enum import Enum
from typing import Optional

from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
|
|
|
# Load the embedding model once, at import time. If loading fails, the error
# is logged and ``model`` is set to None so the app still starts and the
# endpoint's "Model load failed." branch can report the problem. Without this,
# SentenceTransformer raises on failure and that branch is unreachable.
try:
    model = SentenceTransformer(
        "dunzhang/stella_en_400M_v5",
        trust_remote_code=True,
        device="cpu",
        # NOTE(review): presumably these two flags are disabled because the
        # memory-efficient / unpadded attention paths need GPU-only kernels
        # (per the model card) — confirm.
        config_kwargs={"use_memory_efficient_attention": False, "unpad_inputs": False},
    )
except Exception:
    logging.getLogger(__name__).exception("Failed to load SentenceTransformer model")
    model = None
|
|
|
class EmbeddingType(str, Enum):
    """Allowed values for ``Embedding.embedding_type``.

    The endpoint passes the selected value to ``SentenceTransformer.encode``
    as ``prompt_name``. Members subclass ``str``, so they compare equal to
    their plain string values.
    """

    # NOTE(review): presumably "sentence-to-passage" and "sentence-to-sentence"
    # query prompts defined by the model — confirm against the model card.
    s2p_query = "s2p_query"
    s2s_query = "s2s_query"


# Backward-compatible alias. This class was previously declared under the name
# ``Enum``, which shadowed ``enum.Enum``; code that still says ``Enum`` keeps working.
Enum = EmbeddingType
|
|
|
class Embedding(BaseModel):
    """Request body for ``POST /embeddings/``.

    ``input`` is the list of texts to embed. ``embedding_type`` optionally
    names a prompt that is passed to the model's ``encode`` call. Omit it, or
    send ``null``, to encode the texts without a prompt.
    """

    input: list[str]
    # Optional[...] makes the annotation agree with the None default: the
    # generated schema is nullable and an explicit JSON null validates.
    embedding_type: Optional[Enum] = None
|
|
|
|
|
# Cross-origin settings for the API. Only POST is exposed, and the
# Authorization header is explicitly allowed because the endpoint reads its
# bearer token from that header.
# NOTE(review): allow_origins=["*"] together with allow_credentials=True is
# very permissive. Auth uses the Authorization header, not cookies, so
# credentials may not be needed — confirm with the clients before tightening.
_CORS_OPTIONS = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["POST"],
    "allow_headers": ["Authorization"],
}

app = FastAPI()
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)
|
|
|
def parse(data, ndigits=8):
    """Round every value in a collection of embedding vectors.

    Args:
        data: Iterable of rows (for example the output of ``ndarray.tolist()``),
            where each row is an iterable of numbers.
        ndigits: Number of decimal places passed to ``round``. Defaults to 8.

    Returns:
        A new list of lists with the same shape as *data*.
    """
    return [[round(value, ndigits) for value in row] for row in data]
|
|
|
|
|
def _is_authorized(authorization):
    """Return True if *authorization* carries the configured bearer token.

    The expected token is read from the ``token`` environment variable on each
    call. If that variable is unset or empty, every request is rejected. A
    missing or non-Bearer header is also rejected, rather than raising.
    """
    expected = os.environ.get("token")
    if not expected or not authorization:
        return False
    scheme, _, supplied = authorization.partition(" ")
    # The auth scheme name is case-insensitive (RFC 7235).
    if scheme.lower() != "bearer":
        return False
    # Constant-time comparison, so response timing does not leak the token.
    return secrets.compare_digest(supplied.encode(), expected.encode())


@app.post("/embeddings/")
async def get_embedding(embedding: Embedding, req: Request):
    """Embed ``embedding.input`` and return one rounded vector per text.

    Raises:
        HTTPException: 401 if the bearer token is missing or wrong; 503 if the
            model is not available.
    """
    if not _is_authorized(req.headers.get("Authorization")):
        raise HTTPException(status_code=401, detail="Unauthorized.")

    # The model is None when loading failed at startup. That is a server-side
    # problem, not a client error.
    if model is None:
        raise HTTPException(status_code=503, detail="Model load failed.")

    if not embedding.input:
        return []

    encode_kwargs = {}
    if embedding.embedding_type is not None:
        # Pass the plain string value so the prompt lookup does not depend on
        # how the str-mixin enum hashes or formats itself.
        encode_kwargs["prompt_name"] = embedding.embedding_type.value

    # encode() is synchronous and CPU-bound. Running it in a worker thread
    # keeps the event loop free to serve other requests meanwhile.
    vectors = await asyncio.to_thread(model.encode, embedding.input, **encode_kwargs)
    return parse(vectors.tolist())
|
|