joaogante (HF staff) committed
Commit a1a543e · 1 Parent(s): f30163f
Files changed (2):
  1. app.py +89 -0
  2. requirements.txt +4 -0
app.py ADDED
@@ -0,0 +1,89 @@
+ from threading import Thread
+
+ import torch
+ import gradio as gr
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, TextIteratorStreamer
+
+ model_id = "declare-lab/flan-alpaca-large"
+ torch_device = "cuda" if torch.cuda.is_available() else "cpu"
+ print("Running on device:", torch_device)
+ print("CPU threads:", torch.get_num_threads())
+
+
+ if torch_device == "cuda":
+     model = AutoModelForSeq2SeqLM.from_pretrained(model_id, load_in_8bit=True, device_map="auto")
+ else:
+     model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+
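Note that the 8-bit path relies on bitsandbytes and accelerate, both pinned in requirements.txt below. In more recent transformers releases the same load is usually expressed through a quantization config object rather than the bare kwarg; a minimal sketch, assuming a version that ships BitsAndBytesConfig:

from transformers import AutoModelForSeq2SeqLM, BitsAndBytesConfig

# Equivalent 8-bit load with the config-object API of later transformers
# versions; assumed to match the load_in_8bit=True kwarg used above.
quant_config = BitsAndBytesConfig(load_in_8bit=True)
model = AutoModelForSeq2SeqLM.from_pretrained(
    "declare-lab/flan-alpaca-large", quantization_config=quant_config, device_map="auto"
)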
+ def run_generation(user_text, top_p, temperature, top_k, max_new_tokens):
+     # Tokenize the user text.
+     model_inputs = tokenizer([user_text], return_tensors="pt").to(torch_device)
+
+     # Start generation on a separate thread, so that we don't block the UI. The text is pulled from the streamer
+     # in the main thread. Adds timeout to the streamer to handle exceptions in the generation thread.
+     streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
+     generate_kwargs = dict(
+         model_inputs,
+         streamer=streamer,
+         max_new_tokens=max_new_tokens,
+         do_sample=True,
+         top_p=top_p,
+         temperature=float(temperature),
+         top_k=top_k
+     )
+     t = Thread(target=model.generate, kwargs=generate_kwargs)
+     t.start()
+
+     # Pull the generated text from the streamer, and update the model output.
+     model_output = ""
+     for new_text in streamer:
+         model_output += new_text
+         yield model_output
+
+
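For reference, the producer/consumer pattern above (a generation thread feeding TextIteratorStreamer, the main thread draining it) also works outside Gradio. A minimal console sketch reusing the model, tokenizer, and torch_device globals defined above; the prompt and token budget are illustrative:

def stream_to_stdout(prompt, max_new_tokens=64):
    # Same pattern as run_generation, but printing chunks as they arrive.
    inputs = tokenizer([prompt], return_tensors="pt").to(torch_device)
    streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
    Thread(target=model.generate, kwargs=dict(inputs, streamer=streamer, max_new_tokens=max_new_tokens)).start()
    for chunk in streamer:
        print(chunk, end="", flush=True)
    print()

# Example: stream_to_stdout("Write an email about an alpaca that likes flan")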
+ def reset_textbox():
+     return gr.update(value='')
+
+
+ with gr.Blocks() as demo:
+     duplicate_link = "https://huggingface.co/spaces/joaogante/transformers_streaming?duplicate=true"
+     gr.Markdown(
+         "# 🤗 Transformers 🔥Streaming🔥 on Gradio\n"
+         "This demo showcases the use of the "
+         "[streaming feature](https://huggingface.co/docs/transformers/main/en/generation_strategies#streaming) "
+         "of 🤗 Transformers with Gradio to generate text in real-time. It uses "
+         f"[{model_id}](https://huggingface.co/{model_id}) and the Spaces free compute tier.\n\n"
+         f"Feel free to [duplicate this Space]({duplicate_link}) to try your own models or use this space as a "
+         "template! 💛"
+     )
+
+     with gr.Row():
+         with gr.Column(scale=4):
+             user_text = gr.Textbox(
+                 placeholder="Write an email about an alpaca that likes flan",
+                 label="User input"
+             )
+             model_output = gr.Textbox(label="Model output", lines=10, interactive=False)
+             button_submit = gr.Button(value="Submit")
+
+         with gr.Column(scale=1):
+             max_new_tokens = gr.Slider(
+                 minimum=1, maximum=1000, value=250, step=1, interactive=True, label="Max New Tokens",
+             )
+             top_p = gr.Slider(
+                 minimum=0.05, maximum=1.0, value=0.95, step=0.05, interactive=True, label="Top-p (nucleus sampling)",
+             )
+             top_k = gr.Slider(
+                 minimum=1, maximum=50, value=50, step=1, interactive=True, label="Top-k",
+             )
+             temperature = gr.Slider(
+                 minimum=0.1, maximum=5.0, value=0.8, step=0.1, interactive=True, label="Temperature",
+             )
+
+     user_text.submit(run_generation, [user_text, top_p, temperature, top_k, max_new_tokens], model_output)
+     button_submit.click(run_generation, [user_text, top_p, temperature, top_k, max_new_tokens], model_output)
+
+ demo.queue(max_size=32).launch()
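reset_textbox is defined above but never wired to an event. One possible use, assuming a gradio version that supports event chaining via .then(), is clearing the input once generation has been dispatched; a sketch of what the button wiring inside the Blocks context could look like instead:

# Hypothetical wiring (not in the original commit): run generation, then
# clear the user input box. Would replace the button_submit.click(...) line.
button_submit.click(
    run_generation, [user_text, top_p, temperature, top_k, max_new_tokens], model_output
).then(reset_textbox, [], user_text)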
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ accelerate
+ bitsandbytes
+ torch
+ git+https://github.com/huggingface/transformers.git  # transformers from main (TextIteratorStreamer will be added in v4.28)