from torch import Tensor
from transformers import AutoTokenizer, AutoModel
from typing import Union
from fastapi import FastAPI
from pydantic import BaseModel
def average_pool(last_hidden_states: Tensor,
                 attention_mask: Tensor) -> Tensor:
    # Mean-pool the token embeddings, ignoring padding positions.
    last_hidden = last_hidden_states.masked_fill(
        ~attention_mask[..., None].bool(), 0.0)
    return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
# Local replacement for OpenAI's text-embedding-ada-002
embeddingTokenizer = AutoTokenizer.from_pretrained(
    './multilingual-e5-base')
embeddingModel = AutoModel.from_pretrained('./multilingual-e5-base')
class EmbeddingRequest(BaseModel):
    input: Union[str, None] = None
app = FastAPI()
@app.get("/")
async def root():
return {"message": "Hello World"}
@app.post("/text-embedding")
async def text_embedding(request: EmbeddingRequest):
    text = request.input
    # Tokenize the input (truncated to the model's 512-token limit)
    batch_dict = embeddingTokenizer([text], max_length=512,
                                    padding=True, truncation=True,
                                    return_tensors='pt')
    outputs = embeddingModel(**batch_dict)
    embeddings = average_pool(outputs.last_hidden_state,
                              batch_dict['attention_mask'])
    # Return the pooled sentence embedding as a plain list of floats
    return {
        'embedding': embeddings[0].tolist()
    }
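

# Minimal usage sketch: how the /text-embedding endpoint can be queried once
# the app is running. The uvicorn command, port 8000, and the `requests`
# dependency are assumptions for illustration, not part of this service.
#
#   uvicorn main:app --host 0.0.0.0 --port 8000
#
#   import requests
#   resp = requests.post("http://localhost:8000/text-embedding",
#                        json={"input": "Hello World"})
#   vector = resp.json()["embedding"]  # list of floats (768 values for multilingual-e5-base)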