cahya commited on
Commit
a4b0cb5
·
1 Parent(s): 99167bb

fix cuda device and penalti alpha

Browse files
Files changed (1) hide show
  1. app.py +13 -9
app.py CHANGED
@@ -3,24 +3,27 @@ import gradio as gr
3
  from transformers import pipeline
4
  import os
5
 
6
- device = "cuda" if torch.cuda.is_available() else "cpu"
7
  HF_AUTH_TOKEN = os.environ.get("HF_AUTH_TOKEN")
8
  text_generation_model = "cahya/indochat-tiny"
9
  text_generation = pipeline("text-generation", text_generation_model, use_auth_token=HF_AUTH_TOKEN, device=device)
10
 
11
 
12
- def get_answer(user_input, decoding_methods, top_k, top_p, temperature, repetition_penalty, penalty_alpha):
13
  if decoding_methods == "Beam Search":
14
  do_sample = False
 
15
  elif decoding_methods == "Sampling":
16
  do_sample = True
 
17
  else:
18
  do_sample = False
19
  print(user_input, decoding_methods, do_sample, top_k, top_p, temperature, repetition_penalty, penalty_alpha)
20
  prompt = f"User: {user_input}\nAssistant: "
21
  generated_text = text_generation(f"{prompt}", min_length=50, max_length=200, num_return_sequences=1,
22
- do_sample=do_sample, top_k=top_k, top_p=top_p, temperature=temperature,
23
- repetition_penalty=repetition_penalty)#, penalty_alpha=penalty_alpha)
 
24
  answer = generated_text[0]["generated_text"]
25
  answer_without_prompt = answer[len(prompt)+1:]
26
  return answer_without_prompt
@@ -28,8 +31,7 @@ def get_answer(user_input, decoding_methods, top_k, top_p, temperature, repetiti
28
 
29
  with gr.Blocks() as demo:
30
  with gr.Row():
31
- gr.Markdown(
32
- "## IndoChat")
33
  with gr.Row():
34
  with gr.Column():
35
  user_input = gr.inputs.Textbox(placeholder="",
@@ -37,8 +39,10 @@ with gr.Blocks() as demo:
37
  default="Bagaimana cara mendidik anak supaya tidak berbohong?")
38
  decoding_methods = gr.inputs.Dropdown(["Beam Search", "Sampling", "Contrastive Search"],
39
  default="Sampling")
40
- top_k = gr.inputs.Slider(label="Top K: The number of highest probability vocabulary tokens to keep",
41
- default=40, maximum=50, minimum=1, step=1)
 
 
42
  top_p = gr.inputs.Slider(label="Top P", default=0.9, step=0.05, minimum=0.1, maximum=1.0)
43
  temperature = gr.inputs.Slider(label="Temperature", default=0.5, step=0.05, minimum=0.1, maximum=1.0)
44
  repetition_penalty = gr.inputs.Slider(label="Repetition Penalty", default=1.1, step=0.05, minimum=1.0, maximum=2.0)
@@ -50,7 +54,7 @@ with gr.Blocks() as demo:
50
  with gr.Row():
51
  gr.Markdown("![visitor badge](https://visitor-badge.glitch.me/badge?page_id=cahya_indochat)")
52
 
53
- button_generate_story.click(get_answer, inputs=[user_input, decoding_methods, top_k, top_p, temperature,
54
  repetition_penalty, penalty_alpha], outputs=[generated_answer])
55
 
56
  demo.launch(enable_queue=False)
 
3
  from transformers import pipeline
4
  import os
5
 
6
+ device = torch.cuda.current_device() if torch.cuda.is_available() else "cpu"
7
  HF_AUTH_TOKEN = os.environ.get("HF_AUTH_TOKEN")
8
  text_generation_model = "cahya/indochat-tiny"
9
  text_generation = pipeline("text-generation", text_generation_model, use_auth_token=HF_AUTH_TOKEN, device=device)
10
 
11
 
12
+ def get_answer(user_input, decoding_methods, num_beams, top_k, top_p, temperature, repetition_penalty, penalty_alpha):
13
  if decoding_methods == "Beam Search":
14
  do_sample = False
15
+ penalty_alpha = 0
16
  elif decoding_methods == "Sampling":
17
  do_sample = True
18
+ penalty_alpha = 0
19
  else:
20
  do_sample = False
21
  print(user_input, decoding_methods, do_sample, top_k, top_p, temperature, repetition_penalty, penalty_alpha)
22
  prompt = f"User: {user_input}\nAssistant: "
23
  generated_text = text_generation(f"{prompt}", min_length=50, max_length=200, num_return_sequences=1,
24
+ num_beams=num_beams, do_sample=do_sample, top_k=top_k, top_p=top_p,
25
+ temperature=temperature, repetition_penalty=repetition_penalty,
26
+ penalty_alpha=penalty_alpha)
27
  answer = generated_text[0]["generated_text"]
28
  answer_without_prompt = answer[len(prompt)+1:]
29
  return answer_without_prompt
 
31
 
32
  with gr.Blocks() as demo:
33
  with gr.Row():
34
+ gr.Markdown("## IndoChat")
 
35
  with gr.Row():
36
  with gr.Column():
37
  user_input = gr.inputs.Textbox(placeholder="",
 
39
  default="Bagaimana cara mendidik anak supaya tidak berbohong?")
40
  decoding_methods = gr.inputs.Dropdown(["Beam Search", "Sampling", "Contrastive Search"],
41
  default="Sampling")
42
+ num_beams = gr.inputs.Slider(label="Number of beams for beam search",
43
+ default=1, minimum=1, maximum=10, step=1)
44
+ top_k = gr.inputs.Slider(label="Top K",
45
+ default=30, maximum=50, minimum=1, step=1)
46
  top_p = gr.inputs.Slider(label="Top P", default=0.9, step=0.05, minimum=0.1, maximum=1.0)
47
  temperature = gr.inputs.Slider(label="Temperature", default=0.5, step=0.05, minimum=0.1, maximum=1.0)
48
  repetition_penalty = gr.inputs.Slider(label="Repetition Penalty", default=1.1, step=0.05, minimum=1.0, maximum=2.0)
 
54
  with gr.Row():
55
  gr.Markdown("![visitor badge](https://visitor-badge.glitch.me/badge?page_id=cahya_indochat)")
56
 
57
+ button_generate_story.click(get_answer, inputs=[user_input, decoding_methods, num_beams, top_k, top_p, temperature,
58
  repetition_penalty, penalty_alpha], outputs=[generated_answer])
59
 
60
  demo.launch(enable_queue=False)