Spaces:
Runtime error
Runtime error
Cahya Wirawan
commited on
Commit
•
4964cc6
1
Parent(s):
7191ca8
add indochat
Browse files- app/{web_socket.py → api.py} +69 -0
- app/config.json +3 -0
- app/start.sh +2 -7
app/{web_socket.py → api.py}
RENAMED
@@ -59,3 +59,72 @@ async def websocket_endpoint(websocket: WebSocket):
|
|
59 |
data = await websocket.receive_text()
|
60 |
await websocket.send_text(f"Message text was: {data}")
|
61 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
59 |
data = await websocket.receive_text()
|
60 |
await websocket.send_text(f"Message text was: {data}")
|
61 |
|
62 |
+
|
63 |
+
@app.post("/api/indochat/v1")
|
64 |
+
async def indochat(
|
65 |
+
text: str = Form(default="", description="The Prompt"),
|
66 |
+
max_length: int = Form(default=250, description="Maximal length of the generated text"),
|
67 |
+
do_sample: bool = Form(default=True, description="Whether to use sampling; use greedy decoding otherwise"),
|
68 |
+
top_k: int = Form(default=50, description="The number of highest probability vocabulary tokens to keep "
|
69 |
+
"for top-k-filtering"),
|
70 |
+
top_p: float = Form(default=0.95, description="If set to float < 1, only the most probable tokens with "
|
71 |
+
"probabilities that add up to top_p or higher are kept "
|
72 |
+
"for generation"),
|
73 |
+
temperature: float = Form(default=1.0, description="The Temperature of the softmax distribution"),
|
74 |
+
penalty_alpha: float = Form(default=0.6, description="Penalty alpha"),
|
75 |
+
repetition_penalty: float = Form(default=1.0, description="Repetition penalty"),
|
76 |
+
seed: int = Form(default=42, description="Random Seed"),
|
77 |
+
max_time: float = Form(default=60.0, description="Maximal time in seconds to generate the text")
|
78 |
+
):
|
79 |
+
set_seed(seed)
|
80 |
+
if repetition_penalty == 0.0:
|
81 |
+
min_penalty = 1.05
|
82 |
+
max_penalty = 1.5
|
83 |
+
repetition_penalty = max(min_penalty + (1.0 - temperature) * (max_penalty - min_penalty), 0.8)
|
84 |
+
prompt = f"User: {text}\nAssistant: "
|
85 |
+
input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(device)
|
86 |
+
model.eval()
|
87 |
+
print("Generating text...")
|
88 |
+
print(f"max_length: {max_length}, do_sample: {do_sample}, top_k: {top_k}, top_p: {top_p}, "
|
89 |
+
f"temperature: {temperature}, repetition_penalty: {repetition_penalty}, penalty_alpha: {penalty_alpha}")
|
90 |
+
time_start = time.time()
|
91 |
+
sample_outputs = model.generate(input_ids,
|
92 |
+
penalty_alpha=penalty_alpha,
|
93 |
+
do_sample=do_sample,
|
94 |
+
min_length=200,
|
95 |
+
max_length=max_length,
|
96 |
+
top_k=top_k,
|
97 |
+
top_p=top_p,
|
98 |
+
temperature=temperature,
|
99 |
+
repetition_penalty=repetition_penalty,
|
100 |
+
num_return_sequences=1,
|
101 |
+
max_time=max_time
|
102 |
+
)
|
103 |
+
result = tokenizer.decode(sample_outputs[0], skip_special_tokens=True)
|
104 |
+
# result = result[len(prompt) + 1:]
|
105 |
+
time_end = time.time()
|
106 |
+
time_diff = time_end - time_start
|
107 |
+
print(f"result:\n{result}")
|
108 |
+
generated_text = result
|
109 |
+
return {"generated_text": generated_text, "processing_time": time_diff}
|
110 |
+
|
111 |
+
|
112 |
+
def get_text_generator(model_name: str, device: str = "cpu"):
|
113 |
+
hf_auth_token = os.getenv("HF_AUTH_TOKEN", False)
|
114 |
+
print(f"hf_auth_token: {hf_auth_token}")
|
115 |
+
print(f"Loading model with device: {device}...")
|
116 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=hf_auth_token)
|
117 |
+
model = GPT2LMHeadModel.from_pretrained(model_name, pad_token_id=tokenizer.eos_token_id,
|
118 |
+
use_auth_token=hf_auth_token)
|
119 |
+
model.to(device)
|
120 |
+
print("Model loaded")
|
121 |
+
return model, tokenizer
|
122 |
+
|
123 |
+
|
124 |
+
def get_config():
|
125 |
+
return json.load(open("config.json", "r"))
|
126 |
+
|
127 |
+
|
128 |
+
config = get_config()
|
129 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
130 |
+
model, tokenizer = get_text_generator(model_name=config["model_name"], device=device)
|
app/config.json
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"model_name": "cahya/indochat-tiny"
|
3 |
+
}
|
app/start.sh
CHANGED
@@ -3,21 +3,16 @@ set -e
|
|
3 |
|
4 |
cd /home/user/app
|
5 |
|
6 |
-
id
|
7 |
-
ls -ld /var/log/nginx/ /var/lib/nginx/ /run
|
8 |
-
ls -la /
|
9 |
-
ls -la ~
|
10 |
-
|
11 |
nginx
|
12 |
|
13 |
python whisper.py&
|
14 |
|
15 |
if [ "$DEBUG" = true ] ; then
|
16 |
echo 'Debugging - ON'
|
17 |
-
uvicorn
|
18 |
else
|
19 |
echo 'Debugging - OFF'
|
20 |
-
uvicorn
|
21 |
echo $?
|
22 |
echo END
|
23 |
fi
|
|
|
3 |
|
4 |
cd /home/user/app
|
5 |
|
|
|
|
|
|
|
|
|
|
|
6 |
nginx
|
7 |
|
8 |
python whisper.py&
|
9 |
|
10 |
if [ "$DEBUG" = true ] ; then
|
11 |
echo 'Debugging - ON'
|
12 |
+
uvicorn api:app --host 0.0.0.0 --port 7880 --reload
|
13 |
else
|
14 |
echo 'Debugging - OFF'
|
15 |
+
uvicorn api:app --host 0.0.0.0 --port 7880
|
16 |
echo $?
|
17 |
echo END
|
18 |
fi
|