# Author: rajeshchoudharyt
# Commit 3171fa3: "Add application file" (1.64 kB)
import asyncio
import os
import secrets
from enum import Enum
from typing import Optional

from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
# Shared embedding model: loaded once, when this module is imported, and
# pinned to the CPU.
# trust_remote_code=True lets sentence-transformers execute custom modelling
# code shipped in the model repository; only acceptable for a trusted repo.
# NOTE(review): the two config flags appear to be the settings the Stella
# model card recommends for CPU inference (memory-efficient attention and
# input unpadding rely on GPU-only kernels). Confirm against the card.
model = SentenceTransformer(
    "dunzhang/stella_en_400M_v5",
    device="cpu",
    trust_remote_code=True,
    config_kwargs=dict(
        use_memory_efficient_attention=False,
        unpad_inputs=False,
    ),
)
class Enum(str, Enum):
s2p_query = "s2p_query" # sentence-to-sentence
s2s_query = "s2s_query" # sentence-to-passage, Q&A
class Embedding(BaseModel):
    """Request body for ``POST /embeddings/``.

    Attributes are documented inline below.
    """

    # Texts to embed; the response holds one vector per entry, in order.
    input: list[str]
    # Optional prompt name (a member of the ``Enum`` class). Omitted or null
    # means the texts are encoded without a prompt. Declared Optional so that
    # an explicit JSON null validates instead of being rejected.
    embedding_type: Optional[Enum] = None
# ASGI application. The CORS policy lets pages on any origin send
# cross-origin POST requests that include an ``Authorization`` header.
# NOTE(review): with allow_credentials=True, Starlette reflects the caller's
# Origin instead of answering "*" for credentialed requests. Bearer-token auth
# does not need credentials mode, so confirm this combination is intended.
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_methods=["POST"],
    allow_headers=["Authorization"],
    allow_origins=["*"],
    allow_credentials=True,
)
def parse(data, ndigits=8):
    """Round every component of every embedding vector.

    Args:
        data: Iterable of vectors, each an iterable of numbers (e.g. the
            nested list produced by ``ndarray.tolist()``).
        ndigits: Decimal places to keep. Defaults to 8, the precision this
            function has always used.

    Returns:
        A new list of lists, with each value passed through ``round``.
    """
    # Each row is one full embedding vector (not a "dimension").
    return [[round(value, ndigits) for value in vector] for vector in data]
def _is_authorized(header: Optional[str], expected: Optional[str]) -> bool:
    """Check a raw ``Authorization`` header against the expected bearer token.

    Args:
        header: The ``Authorization`` header value, or ``None`` if absent.
        expected: The configured secret token, or ``None`` if unset.

    Returns:
        ``True`` only when a non-empty token is configured and *header* is
        ``"Bearer <expected>"`` (scheme matched case-insensitively).
    """
    # Fail closed: a missing header, or an unset/empty secret, never
    # authorizes. An empty secret must not match an empty credential.
    if not expected or not header:
        return False
    scheme, _, credentials = header.partition(" ")
    if scheme.lower() != "bearer":
        return False
    # Constant-time comparison so response timing does not leak the token.
    return secrets.compare_digest(credentials.encode(), expected.encode())


@app.post("/embeddings/")
async def get_embedding(embedding: Embedding, req: Request):
    """Embed ``embedding.input`` with the shared model.

    The request must carry ``Authorization: Bearer <token>``, where the token
    equals the ``token`` environment variable.

    Returns:
        One list of floats (rounded by ``parse``) per input string.

    Raises:
        HTTPException: 401 if the bearer token is missing or wrong; 400 if
            ``model`` is ``None``.
    """
    if not _is_authorized(req.headers.get("Authorization"), os.environ.get("token")):
        raise HTTPException(status_code=401, detail="Unauthorized.")
    if model is None:
        raise HTTPException(status_code=400, detail="Model load failed.")
    if not embedding.input:
        # Nothing to encode; skip the model call entirely.
        return []

    encode_kwargs = {}
    if embedding.embedding_type is not None:
        # Forward the plain string value: it is the key used to select one of
        # the model's configured prompts.
        encode_kwargs["prompt_name"] = embedding.embedding_type.value

    # encode() is blocking, CPU-bound work. Run it in a worker thread so the
    # event loop keeps serving other requests in the meantime.
    vectors = await asyncio.to_thread(model.encode, embedding.input, **encode_kwargs)
    return parse(vectors.tolist())