File size: 2,938 Bytes
902ad9c
 
 
f3fc705
 
 
 
e8f71ec
 
902ad9c
 
 
 
 
e8f71ec
 
902ad9c
 
5e6c429
 
e8f71ec
902ad9c
 
ec9c2f2
902ad9c
ec9c2f2
f3fc705
 
 
 
 
 
 
 
 
 
 
 
 
 
e8f71ec
f3fc705
902ad9c
 
f3fc705
 
 
902ad9c
f3fc705
 
 
 
902ad9c
 
 
 
 
 
 
 
f3fc705
 
902ad9c
 
 
 
f3fc705
 
 
902ad9c
 
f3fc705
902ad9c
 
 
 
f3fc705
902ad9c
 
 
 
81d492d
902ad9c
 
 
 
 
f3fc705
 
902ad9c
 
 
 
f3fc705
 
 
902ad9c
 
 
 
f3fc705
b4aec05
902ad9c
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
from torch import Tensor
from transformers import AutoTokenizer, AutoModel
from ctranslate2 import Translator
from typing import Union

from fastapi import FastAPI
from pydantic import BaseModel


def average_pool(last_hidden_states: Tensor,
                 attention_mask: Tensor) -> Tensor:
    """Mean-pool hidden states over dimension 1, ignoring masked-out positions.

    Positions where ``attention_mask`` is zero are replaced by 0.0 before
    summing, and the sum is divided by the per-row count of mask entries.

    Args:
        last_hidden_states: Hidden states to pool; presumably shaped
            (batch, sequence, hidden) -- confirm against the caller.
        attention_mask: Mask with one entry per position of
            ``last_hidden_states`` (without the trailing feature axis);
            non-zero marks a position to keep.

    Returns:
        The pooled tensor with dimension 1 reduced away.
    """
    # Broadcast the mask over the trailing feature axis and invert it so that
    # True marks the positions to blank out.
    drop_positions = ~attention_mask[..., None].bool()
    kept_states = last_hidden_states.masked_fill(drop_positions, 0.0)

    summed_states = kept_states.sum(dim=1)
    kept_counts = attention_mask.sum(dim=1)[..., None]
    return summed_states / kept_counts


# Checkpoint directories, given as relative paths.
_EMBEDDING_MODEL_DIR = './multilingual-e5-base'
_INFERENCE_MODEL_DIR = "./ct2fast-flan-alpaca-xl"

# Embedding stack (used in place of text-ada): tokenizer plus encoder model.
embeddingTokenizer = AutoTokenizer.from_pretrained(_EMBEDDING_MODEL_DIR)
embeddingModel = AutoModel.from_pretrained(_EMBEDDING_MODEL_DIR)

# Generation stack (used in place of ChatGPT): tokenizer plus a CTranslate2
# translator running int8 on CPU.
inferenceTokenizer = AutoTokenizer.from_pretrained(_INFERENCE_MODEL_DIR)
inferenceTranslator = Translator(
    _INFERENCE_MODEL_DIR,
    compute_type="int8",
    device="cpu",
    inter_threads=1,
    intra_threads=2,
)


class EmbeddingRequest(BaseModel):
    """Request body accepted by the ``/text-embedding`` endpoint."""
    # Text to embed. Optional in the schema; None when the client omits it.
    input: Union[str, None] = None


class TokensCountRequest(BaseModel):
    """Request body accepted by the ``/tokens-count`` endpoint."""
    # Text to tokenize. Optional in the schema; None when the client omits it.
    input: Union[str, None] = None


class InferenceRequest(BaseModel):
    """Request body accepted by the ``/inference`` endpoint."""
    # Prompt text to generate from. Optional in the schema; None when omitted.
    input: Union[str, None] = None
    # Requested limit on the decoding length. Note the default is 0 (not None)
    # when the client leaves the field out.
    max_length: Union[int, None] = 0


# FastAPI application object; the route decorators below register handlers on it.
app = FastAPI()


@app.get("/")
async def root():
    """Return a fixed greeting payload for the index route."""
    greeting = {"message": "Hello World"}
    return greeting


@app.post("/text-embedding")
async def text_embedding(request: EmbeddingRequest):
    """Compute an embedding vector for the request text.

    The text is tokenized (truncated to 512 tokens), run through the
    module-level embedding model, and mean-pooled with ``average_pool``.

    Args:
        request: Body carrying the text to embed in ``input``.

    Returns:
        A dict with the pooled embedding as a list of floats under
        ``'embedding'``.

    Raises:
        HTTPException: 422 when ``input`` is missing or null.
    """
    # Imported at function scope so this handler does not depend on
    # module-level imports the file does not have.
    import torch
    from fastapi import HTTPException

    text = request.input
    if text is None:
        # The schema allows a null input, but the tokenizer cannot encode it;
        # reject it explicitly instead of failing with an opaque server error.
        raise HTTPException(status_code=422, detail="'input' is required")

    # Process the input data
    batch_dict = embeddingTokenizer([text], max_length=512,
                                    padding=True, truncation=True, return_tensors='pt')

    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        outputs = embeddingModel(**batch_dict)
        embeddings = average_pool(outputs.last_hidden_state,
                                  batch_dict['attention_mask'])

    # create response
    return {
        'embedding': embeddings[0].tolist()
    }


@app.post('/inference')
async def inference(request: InferenceRequest):
    """Generate text from the request prompt with the CTranslate2 translator.

    Args:
        request: Body carrying the prompt in ``input`` and an optional
            ``max_length`` used as the maximum decoding length. Missing,
            null, or non-positive values fall back to 256; larger values
            are capped at 1024.

    Returns:
        A dict with the decoded output under ``'generated_text'``.

    Raises:
        HTTPException: 422 when ``input`` is missing or null.
    """
    # Imported at function scope so this handler does not depend on a
    # module-level import the file does not have.
    from fastapi import HTTPException

    default_max_length = 256
    max_length_limit = 1024

    input_text = request.input
    if input_text is None:
        # The schema allows a null prompt, but the tokenizer cannot encode it;
        # reject it explicitly instead of failing with an opaque server error.
        raise HTTPException(status_code=422, detail="'input' is required")

    # The request model defaults max_length to 0, so "not provided" arrives
    # here as 0 rather than None. Treat every non-positive or null value as
    # "use the default" instead of asking for a zero/negative decode length.
    requested = request.max_length
    if requested is None or requested <= 0:
        max_length = default_max_length
    else:
        max_length = min(max_length_limit, int(requested))

    # process request
    input_ids = inferenceTokenizer.encode(input_text)
    input_tokens = inferenceTokenizer.convert_ids_to_tokens(input_ids)

    results = inferenceTranslator.translate_batch(
        [input_tokens],
        max_input_length=0,
        max_decoding_length=max_length,
        num_hypotheses=1,
        repetition_penalty=1.3,
        sampling_topk=30,
        sampling_temperature=1.1,
        use_vmap=True,
        disable_unk=True,
    )

    # Single input and num_hypotheses=1, so take the first hypothesis of the
    # first result.
    output_tokens = results[0].hypotheses[0]
    output_text = inferenceTokenizer.decode(
        inferenceTokenizer.convert_tokens_to_ids(output_tokens))

    # create response
    return {
        'generated_text': output_text
    }


@app.post('/tokens-count')
async def tokens_count(request: TokensCountRequest):
    """Tokenize the request text with the inference tokenizer and count tokens.

    Args:
        request: Body carrying the text to tokenize in ``input``.

    Returns:
        A dict with the token strings under ``'tokens'`` and their count
        under ``'total'``.

    Raises:
        HTTPException: 422 when ``input`` is missing or null.
    """
    # Imported at function scope so this handler does not depend on a
    # module-level import the file does not have.
    from fastapi import HTTPException

    input_text = request.input
    if input_text is None:
        # The schema allows a null input, but the tokenizer cannot encode it;
        # reject it explicitly instead of failing with an opaque server error.
        raise HTTPException(status_code=422, detail="'input' is required")

    # Encode to ids first, then map back to token strings, so the count
    # reflects exactly what encode() produces.
    token_ids = inferenceTokenizer.encode(input_text)
    tokens = inferenceTokenizer.convert_ids_to_tokens(token_ids)

    # create response
    return {
        'tokens': tokens,
        'total': len(tokens)
    }