JHenzi commited on
Commit
5b9456f
·
verified ·
1 Parent(s): 55d74b4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -90
app.py CHANGED
@@ -2,112 +2,48 @@ import gradio as gr
2
  import requests
3
  import json
4
  import os
5
- from screenshot import (
6
- before_prompt,
7
- prompt_to_generation,
8
- after_generation,
9
- js_save,
10
- js_load_script,
11
- )
12
- from spaces_info import description, examples, initial_prompt_value
13
 
14
  API_URL = os.getenv("API_URL")
15
  HF_API_TOKEN = os.getenv("HF_API_TOKEN")
16
 
 
 
17
 
18
  def query(payload):
19
- print(payload)
20
  response = requests.request("POST", API_URL, json=payload, headers={"Authorization": f"Bearer {HF_API_TOKEN}"})
21
- print(response)
22
  return json.loads(response.content.decode("utf-8"))
23
 
24
-
25
- def inference(input_sentence, max_length, sample_or_greedy, seed=42):
26
- if sample_or_greedy == "Sample":
27
- parameters = {
28
- "max_new_tokens": max_length,
29
- "top_p": 0.9,
30
- "do_sample": True,
31
- "seed": seed,
32
- "early_stopping": False,
33
- "length_penalty": 0.0,
34
- "eos_token_id": None,
35
- }
36
- else:
37
- parameters = {
38
- "max_new_tokens": max_length,
39
- "do_sample": False,
40
- "seed": seed,
41
- "early_stopping": False,
42
- "length_penalty": 0.0,
43
- "eos_token_id": None,
44
- }
45
-
46
- payload = {"inputs": input_sentence, "parameters": parameters,"options" : {"use_cache": False} }
47
-
48
  data = query(payload)
49
-
50
  if "error" in data:
51
- return (None, None, f"<span style='color:red'>ERROR: {data['error']} </span>")
52
-
53
- generation = data[0]["generated_text"].split(input_sentence, 1)[1]
54
- return (
55
- before_prompt
56
- + input_sentence
57
- + prompt_to_generation
58
- + generation
59
- + after_generation,
60
- data[0]["generated_text"],
61
- "",
62
- )
63
-
64
 
65
  if __name__ == "__main__":
66
  demo = gr.Blocks()
67
  with demo:
68
  with gr.Row():
69
- gr.Markdown(value=description)
70
- with gr.Row():
71
- with gr.Column():
72
- text = gr.Textbox(
73
- label="Input",
74
- value=" ", # should be set to " " when plugged into a real API
75
- )
76
- tokens = gr.Slider(1, 64, value=32, step=1, label="Tokens to generate")
77
- sampling = gr.Radio(
78
- ["Sample", "Greedy"], label="Sample or greedy", value="Sample"
79
- )
80
- sampling2 = gr.Radio(
81
- ["Sample 1", "Sample 2", "Sample 3", "Sample 4", "Sample 5"],
82
- value="Sample 1",
83
- label="Sample other generations (only work in 'Sample' mode)",
84
- type="index",
85
- )
86
-
87
- with gr.Row():
88
- submit = gr.Button("Submit")
89
- load_image = gr.Button("Generate Image")
90
- with gr.Column():
91
- text_error = gr.Markdown(label="Log information")
92
- text_out = gr.Textbox(label="Output")
93
- display_out = gr.HTML(label="Image")
94
- display_out.set_event_trigger(
95
- "load",
96
- fn=None,
97
- inputs=None,
98
- outputs=None,
99
- no_target=True,
100
- js=js_load_script,
101
- )
102
- with gr.Row():
103
- gr.Examples(examples=examples, inputs=[text, tokens, sampling, sampling2])
104
-
105
- submit.click(
106
- inference,
107
- inputs=[text, tokens, sampling, sampling2],
108
- outputs=[display_out, text_out, text_error],
109
  )
110
 
111
- load_image.click(fn=None, inputs=None, outputs=None, _js=js_save)
112
-
113
  demo.launch()
 
2
  import requests
3
  import json
4
  import os
 
 
 
 
 
 
 
 
5
 
6
  API_URL = os.getenv("API_URL")
7
  HF_API_TOKEN = os.getenv("HF_API_TOKEN")
8
 
9
+ # Global variable to store the generated text
10
+ generated_text = ""
11
 
12
  def query(payload):
 
13
  response = requests.request("POST", API_URL, json=payload, headers={"Authorization": f"Bearer {HF_API_TOKEN}"})
 
14
  return json.loads(response.content.decode("utf-8"))
15
 
16
+ def generate_and_append_text(max_length):
17
+ global generated_text
18
+ parameters = {
19
+ "max_new_tokens": max_length,
20
+ "top_p": 0.9,
21
+ "do_sample": True,
22
+ "seed": 42,
23
+ "early_stopping": False,
24
+ "length_penalty": 0.0,
25
+ "eos_token_id": None,
26
+ }
27
+ payload = {"inputs": generated_text, "parameters": parameters, "options": {"use_cache": False}}
 
 
 
 
 
 
 
 
 
 
 
 
28
  data = query(payload)
 
29
  if "error" in data:
30
+ return f"<span style='color:red'>ERROR: {data['error']} </span>"
31
+ new_text = data[0]["generated_text"].replace(generated_text, "").strip()
32
+ generated_text += " " + new_text
33
+ return generated_text
 
 
 
 
 
 
 
 
 
34
 
35
  if __name__ == "__main__":
36
  demo = gr.Blocks()
37
  with demo:
38
  with gr.Row():
39
+ generate_button = gr.Button("Generate Text")
40
+ tokens = gr.Slider(1, 64, value=32, step=1, label="Tokens to generate")
41
+ text_out = gr.Textbox(label="Generated Text")
42
+
43
+ generate_button.click(
44
+ generate_and_append_text,
45
+ inputs=tokens,
46
+ outputs=text_out
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  )
48
 
 
 
49
  demo.launch()