Cahya Wirawan committed on
Commit 4964cc6
1 Parent(s): 7191ca8

add indochat

app/{web_socket.py → api.py} RENAMED
@@ -59,3 +59,72 @@ async def websocket_endpoint(websocket: WebSocket):
         data = await websocket.receive_text()
         await websocket.send_text(f"Message text was: {data}")
 
+
+@app.post("/api/indochat/v1")
+async def indochat(
+        text: str = Form(default="", description="The Prompt"),
+        max_length: int = Form(default=250, description="Maximal length of the generated text"),
+        do_sample: bool = Form(default=True, description="Whether to use sampling; use greedy decoding otherwise"),
+        top_k: int = Form(default=50, description="The number of highest probability vocabulary tokens to keep "
+                                                  "for top-k-filtering"),
+        top_p: float = Form(default=0.95, description="If set to float < 1, only the most probable tokens with "
+                                                      "probabilities that add up to top_p or higher are kept "
+                                                      "for generation"),
+        temperature: float = Form(default=1.0, description="The Temperature of the softmax distribution"),
+        penalty_alpha: float = Form(default=0.6, description="Penalty alpha"),
+        repetition_penalty: float = Form(default=1.0, description="Repetition penalty"),
+        seed: int = Form(default=42, description="Random Seed"),
+        max_time: float = Form(default=60.0, description="Maximal time in seconds to generate the text")
+):
+    set_seed(seed)
+    if repetition_penalty == 0.0:
+        min_penalty = 1.05
+        max_penalty = 1.5
+        repetition_penalty = max(min_penalty + (1.0 - temperature) * (max_penalty - min_penalty), 0.8)
+    prompt = f"User: {text}\nAssistant: "
+    input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(device)
+    model.eval()
+    print("Generating text...")
+    print(f"max_length: {max_length}, do_sample: {do_sample}, top_k: {top_k}, top_p: {top_p}, "
+          f"temperature: {temperature}, repetition_penalty: {repetition_penalty}, penalty_alpha: {penalty_alpha}")
+    time_start = time.time()
+    sample_outputs = model.generate(input_ids,
+                                    penalty_alpha=penalty_alpha,
+                                    do_sample=do_sample,
+                                    min_length=200,
+                                    max_length=max_length,
+                                    top_k=top_k,
+                                    top_p=top_p,
+                                    temperature=temperature,
+                                    repetition_penalty=repetition_penalty,
+                                    num_return_sequences=1,
+                                    max_time=max_time
+                                    )
+    result = tokenizer.decode(sample_outputs[0], skip_special_tokens=True)
+    # result = result[len(prompt) + 1:]
+    time_end = time.time()
+    time_diff = time_end - time_start
+    print(f"result:\n{result}")
+    generated_text = result
+    return {"generated_text": generated_text, "processing_time": time_diff}
+
+
+def get_text_generator(model_name: str, device: str = "cpu"):
+    hf_auth_token = os.getenv("HF_AUTH_TOKEN", False)
+    print(f"hf_auth_token: {hf_auth_token}")
+    print(f"Loading model with device: {device}...")
+    tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=hf_auth_token)
+    model = GPT2LMHeadModel.from_pretrained(model_name, pad_token_id=tokenizer.eos_token_id,
+                                            use_auth_token=hf_auth_token)
+    model.to(device)
+    print("Model loaded")
+    return model, tokenizer
+
+
+def get_config():
+    return json.load(open("config.json", "r"))
+
+
+config = get_config()
+device = "cuda" if torch.cuda.is_available() else "cpu"
+model, tokenizer = get_text_generator(model_name=config["model_name"], device=device)
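
The renamed module now exposes text generation as a form-encoded POST endpoint. Note the fallback in the handler: when repetition_penalty is submitted as 0.0, it is derived from temperature via min_penalty + (1.0 - temperature) * (max_penalty - min_penalty), which yields 1.05 at the default temperature of 1.0. Below is a minimal sketch of calling the endpoint from Python; the localhost host and port 7880 are assumptions taken from the uvicorn command in start.sh, and nginx may proxy the service on a different port in practice.

    # Sketch: query the new /api/indochat/v1 endpoint with form fields.
    # Host/port are assumptions based on start.sh; adjust for the nginx proxy.
    import requests

    response = requests.post(
        "http://localhost:7880/api/indochat/v1",
        data={  # FastAPI Form(...) parameters arrive form-encoded
            "text": "Siapa presiden pertama Indonesia?",
            "max_length": 250,
            "top_k": 50,
            "top_p": 0.95,
            "temperature": 1.0,
            "seed": 42,
        },
        timeout=120,  # generation may run up to max_time (60 s by default)
    )
    result = response.json()
    print(result["generated_text"])
    print(f"processing_time: {result['processing_time']:.1f}s")
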
app/config.json ADDED
@@ -0,0 +1,3 @@
+{
+  "model_name": "cahya/indochat-tiny"
+}
app/start.sh CHANGED
@@ -3,21 +3,16 @@ set -e
 
 cd /home/user/app
 
-id
-ls -ld /var/log/nginx/ /var/lib/nginx/ /run
-ls -la /
-ls -la ~
-
 nginx
 
 python whisper.py&
 
 if [ "$DEBUG" = true ] ; then
   echo 'Debugging - ON'
-  uvicorn web_socket:app --host 0.0.0.0 --port 7880 --reload
+  uvicorn api:app --host 0.0.0.0 --port 7880 --reload
 else
   echo 'Debugging - OFF'
-  uvicorn web_socket:app --host 0.0.0.0 --port 7880
+  uvicorn api:app --host 0.0.0.0 --port 7880
   echo $?
   echo END
 fi