"""HTTP service that returns sentence embeddings from a SentenceTransformer model.

Exposes a single ``POST /embeddings/`` endpoint protected by a shared token
read from the ``token`` environment variable.
"""

import os
import secrets
from enum import Enum
from typing import Optional

from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer

# Loaded once at import time so every request reuses the same instance.
# NOTE(review): trust_remote_code=True lets the model repository supply code
# that runs locally -- confirm the pinned repository/revision is trusted.
model = SentenceTransformer(
    "dunzhang/stella_en_400M_v5",
    trust_remote_code=True,
    device="cpu",
    config_kwargs={"use_memory_efficient_attention": False, "unpad_inputs": False},
)


class EmbeddingType(str, Enum):
    """Prompt names a client may select for query-side encoding.

    The value is forwarded to ``model.encode`` as ``prompt_name``.
    """

    s2p_query = "s2p_query"  # sentence-to-passage (retrieval, Q&A)
    s2s_query = "s2s_query"  # sentence-to-sentence (semantic similarity)


class Embedding(BaseModel):
    """Request body for ``POST /embeddings/``."""

    # Texts to embed; one vector is returned per entry, in the same order.
    input: list[str]
    # Optional prompt selector. Omitted or null means "encode without a prompt".
    embedding_type: Optional[EmbeddingType] = None


app = FastAPI()

# NOTE(review): wildcard origins together with allow_credentials=True is very
# permissive; left unchanged so existing browser clients keep working, but
# consider listing explicit origins.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["POST"],
    allow_headers=["Authorization"],
)


def parse(data):
    """Round every component of every vector to 8 decimal places.

    Args:
        data: An iterable of vectors, each an iterable of numbers.

    Returns:
        A new list of lists with each value passed through ``round(value, 8)``.
    """
    return [[round(val, 8) for val in dimension] for dimension in data]


def _authorize(req: Request) -> None:
    """Reject the request with 401 unless it carries the configured token.

    The expected token comes from the ``token`` environment variable. The
    supplied token is whatever follows the first 7 characters of the
    ``Authorization`` header (the length of ``"Bearer "``).

    Raises:
        HTTPException: 401 when the header is missing, the server has no
            non-empty token configured, or the tokens differ.
    """
    expected = os.environ.get("token")
    header = req.headers.get("Authorization")

    # A missing header used to crash on ``None[7:]`` (HTTP 500), and an empty
    # configured token used to match any header of 7 characters or fewer.
    if not expected or header is None:
        raise HTTPException(status_code=401, detail="Unauthorized.")

    # NOTE(review): the auth scheme itself is not verified; any 7-character
    # prefix is accepted, as before, to stay compatible with existing clients.
    supplied = header[7:]

    # Compare as bytes: compare_digest rejects non-ASCII str arguments, and
    # a constant-time comparison avoids leaking the token through timing.
    if not secrets.compare_digest(
        supplied.encode("utf-8", "surrogateescape"),
        expected.encode("utf-8", "surrogateescape"),
    ):
        raise HTTPException(status_code=401, detail="Unauthorized.")


@app.post("/embeddings/")
async def get_embedding(embedding: Embedding, req: Request):
    """Embed the submitted texts and return one rounded vector per text.

    Args:
        embedding: Parsed request body (texts plus optional prompt selector).
        req: Incoming request, used only to read the ``Authorization`` header.

    Returns:
        A list of vectors (lists of floats rounded to 8 decimals), one per
        entry of ``embedding.input``.

    Raises:
        HTTPException: 401 when authorization fails; 400 when the model is
            unavailable.
    """
    _authorize(req)

    if model is None:
        raise HTTPException(status_code=400, detail="Model load failed.")

    # Only pass prompt_name when the client selected one, so the library's
    # default behaviour applies otherwise.
    encode_kwargs = {}
    if embedding.embedding_type is not None:
        encode_kwargs["prompt_name"] = embedding.embedding_type.value

    # NOTE(review): encode() is a blocking call inside an async handler, so
    # the event loop is held for the duration of each request.
    vectors = model.encode(embedding.input, **encode_kwargs).tolist()
    return parse(vectors)