import torch
from torch import Tensor
from transformers import AutoTokenizer, AutoModel
from ctranslate2 import Translator
from typing import Union
from fastapi import FastAPI
from pydantic import BaseModel

def average_pool(last_hidden_states: Tensor,
                 attention_mask: Tensor) -> Tensor:
    """Mean-pool the token embeddings, ignoring padded positions."""
    last_hidden = last_hidden_states.masked_fill(
        ~attention_mask[..., None].bool(), 0.0)
    return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
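
# Shape sketch (hypothetical sizes): for a batch of two sequences of four
# tokens with hidden size 768, last_hidden_states is (2, 4, 768) and
# attention_mask is (2, 4); masked positions are zeroed before the sum, so
# the result is a (2, 768) tensor of mean-pooled sentence embeddings.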

# OpenAI text-embedding-ada replacement
embeddingTokenizer = AutoTokenizer.from_pretrained(
    './multilingual-e5-base')
embeddingModel = AutoModel.from_pretrained('./multilingual-e5-base')

# ChatGPT replacement
inferenceTokenizer = AutoTokenizer.from_pretrained(
    "./ct2fast-flan-alpaca-xl")
inferenceTranslator = Translator(
    "./ct2fast-flan-alpaca-xl", compute_type="int8", device="cpu",
    inter_threads=1, intra_threads=2)


class EmbeddingRequest(BaseModel):
    input: Union[str, None] = None


class TokensCountRequest(BaseModel):
    input: Union[str, None] = None


class InferenceRequest(BaseModel):
    input: Union[str, None] = None
    max_length: Union[int, None] = None


app = FastAPI()


@app.get("/")
async def root():
    return {"message": "Hello World"}
@app.post("/text-embedding")
async def text_embedding(request: EmbeddingRequest):
input = request.input
# Process the input data
batch_dict = embeddingTokenizer([input], max_length=512,
padding=True, truncation=True, return_tensors='pt')
outputs = embeddingModel(**batch_dict)
embeddings = average_pool(outputs.last_hidden_state,
batch_dict['attention_mask'])
# create response
return {
'embedding': embeddings[0].tolist()
}
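
# Example call (assuming the server runs locally on port 8000):
#   curl -X POST http://localhost:8000/text-embedding \
#        -H 'Content-Type: application/json' \
#        -d '{"input": "query: how is the text pooled?"}'
# E5-family models are trained with "query: "/"passage: " prefixes, so adding
# one to the input generally improves embedding quality.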


@app.post('/inference')
async def inference(request: InferenceRequest):
    input_text = request.input
    # Clamp the requested decoding length to at most 1024; default to 256
    max_length = 256
    if request.max_length:
        max_length = min(1024, int(request.max_length))
    # process request
    input_tokens = inferenceTokenizer.convert_ids_to_tokens(
        inferenceTokenizer.encode(input_text))
    results = inferenceTranslator.translate_batch(
        [input_tokens],
        max_input_length=0,  # 0 disables input truncation
        max_decoding_length=max_length,
        num_hypotheses=1,
        repetition_penalty=1.3,
        sampling_topk=30,
        sampling_temperature=1.1,
        use_vmap=True,
        disable_unk=True)
    output_tokens = results[0].hypotheses[0]
    output_text = inferenceTokenizer.decode(
        inferenceTokenizer.convert_tokens_to_ids(output_tokens))
    # create response
    return {
        'generated_text': output_text
    }
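
# Example call (assuming the server runs locally on port 8000):
#   curl -X POST http://localhost:8000/inference \
#        -H 'Content-Type: application/json' \
#        -d '{"input": "What is the capital of France?", "max_length": 64}'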


@app.post('/tokens-count')
async def tokens_count(request: TokensCountRequest):
    input_text = request.input
    # Tokenize with the inference tokenizer so counts match /inference
    tokens = inferenceTokenizer.convert_ids_to_tokens(
        inferenceTokenizer.encode(input_text))
    # create response
    return {
        'tokens': tokens,
        'total': len(tokens)
    }
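

if __name__ == '__main__':
    # Convenience entry point: a minimal sketch assuming uvicorn is installed
    # and this file is saved as main.py; `python main.py` then serves the API
    # that the curl examples above target.
    import uvicorn
    uvicorn.run(app, host='0.0.0.0', port=8000)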