pszemraj joaogante HF staff commited on
Commit
09d15e8
·
0 Parent(s):

Duplicate from joaogante/transformers_streaming

Browse files

Co-authored-by: Joao Gante <[email protected]>

Files changed (5) hide show
  1. .gitattributes +34 -0
  2. .gitignore +169 -0
  3. README.md +13 -0
  4. app.py +89 -0
  5. requirements.txt +4 -0
.gitattributes ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tflite filter=lfs diff=lfs merge=lfs -text
29
+ *.tgz filter=lfs diff=lfs merge=lfs -text
30
+ *.wasm filter=lfs diff=lfs merge=lfs -text
31
+ *.xz filter=lfs diff=lfs merge=lfs -text
32
+ *.zip filter=lfs diff=lfs merge=lfs -text
33
+ *.zst filter=lfs diff=lfs merge=lfs -text
34
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Initially taken from Github's Python gitignore file
2
+
3
+ # Byte-compiled / optimized / DLL files
4
+ __pycache__/
5
+ *.py[cod]
6
+ *$py.class
7
+
8
+ # C extensions
9
+ *.so
10
+
11
+ # tests and logs
12
+ tests/fixtures/cached_*_text.txt
13
+ logs/
14
+ lightning_logs/
15
+ lang_code_data/
16
+
17
+ # Distribution / packaging
18
+ .Python
19
+ build/
20
+ develop-eggs/
21
+ dist/
22
+ downloads/
23
+ eggs/
24
+ .eggs/
25
+ lib/
26
+ lib64/
27
+ parts/
28
+ sdist/
29
+ var/
30
+ wheels/
31
+ *.egg-info/
32
+ .installed.cfg
33
+ *.egg
34
+ MANIFEST
35
+
36
+ # PyInstaller
37
+ # Usually these files are written by a python script from a template
38
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
39
+ *.manifest
40
+ *.spec
41
+
42
+ # Installer logs
43
+ pip-log.txt
44
+ pip-delete-this-directory.txt
45
+
46
+ # Unit test / coverage reports
47
+ htmlcov/
48
+ .tox/
49
+ .nox/
50
+ .coverage
51
+ .coverage.*
52
+ .cache
53
+ nosetests.xml
54
+ coverage.xml
55
+ *.cover
56
+ .hypothesis/
57
+ .pytest_cache/
58
+
59
+ # Translations
60
+ *.mo
61
+ *.pot
62
+
63
+ # Django stuff:
64
+ *.log
65
+ local_settings.py
66
+ db.sqlite3
67
+
68
+ # Flask stuff:
69
+ instance/
70
+ .webassets-cache
71
+
72
+ # Scrapy stuff:
73
+ .scrapy
74
+
75
+ # Sphinx documentation
76
+ docs/_build/
77
+
78
+ # PyBuilder
79
+ target/
80
+
81
+ # Jupyter Notebook
82
+ .ipynb_checkpoints
83
+
84
+ # IPython
85
+ profile_default/
86
+ ipython_config.py
87
+
88
+ # pyenv
89
+ .python-version
90
+
91
+ # celery beat schedule file
92
+ celerybeat-schedule
93
+
94
+ # SageMath parsed files
95
+ *.sage.py
96
+
97
+ # Environments
98
+ .env
99
+ .venv
100
+ env/
101
+ venv/
102
+ ENV/
103
+ env.bak/
104
+ venv.bak/
105
+
106
+ # Spyder project settings
107
+ .spyderproject
108
+ .spyproject
109
+
110
+ # Rope project settings
111
+ .ropeproject
112
+
113
+ # mkdocs documentation
114
+ /site
115
+
116
+ # mypy
117
+ .mypy_cache/
118
+ .dmypy.json
119
+ dmypy.json
120
+
121
+ # Pyre type checker
122
+ .pyre/
123
+
124
+ # vscode
125
+ .vs
126
+ .vscode
127
+
128
+ # Pycharm
129
+ .idea
130
+
131
+ # TF code
132
+ tensorflow_code
133
+
134
+ # Models
135
+ proc_data
136
+
137
+ # examples
138
+ runs
139
+ /runs_old
140
+ /wandb
141
+ /examples/runs
142
+ /examples/**/*.args
143
+ /examples/rag/sweep
144
+
145
+ # data
146
+ /data
147
+ serialization_dir
148
+
149
+ # emacs
150
+ *.*~
151
+ debug.env
152
+
153
+ # vim
154
+ .*.swp
155
+
156
+ #ctags
157
+ tags
158
+
159
+ # pre-commit
160
+ .pre-commit*
161
+
162
+ # .lock
163
+ *.lock
164
+
165
+ # DS_Store (MacOS)
166
+ .DS_Store
167
+
168
+ # ruff
169
+ .ruff_cache
README.md ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Chatbot Transformers Streaming
3
+ emoji: 👀
4
+ colorFrom: gray
5
+ colorTo: blue
6
+ sdk: gradio
7
+ sdk_version: 3.23.0
8
+ app_file: app.py
9
+ pinned: false
10
+ duplicated_from: joaogante/transformers_streaming
11
+ ---
12
+
13
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from threading import Thread
2
+
3
+ import torch
4
+ import gradio as gr
5
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, TextIteratorStreamer
6
+
7
+ model_id = "declare-lab/flan-alpaca-large"
8
+ torch_device = "cuda" if torch.cuda.is_available() else "cpu"
9
+ print("Running on device:", torch_device)
10
+ print("CPU threads:", torch.get_num_threads())
11
+
12
+
13
+ if torch_device == "cuda":
14
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_id, load_in_8bit=True, device_map="auto")
15
+ else:
16
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
17
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
18
+
19
+
20
+ def run_generation(user_text, top_p, temperature, top_k, max_new_tokens):
21
+ # Get the model and tokenizer, and tokenize the user text.
22
+ model_inputs = tokenizer([user_text], return_tensors="pt").to(torch_device)
23
+
24
+ # Start generation on a separate thread, so that we don't block the UI. The text is pulled from the streamer
25
+ # in the main thread. Adds timeout to the streamer to handle exceptions in the generation thread.
26
+ streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
27
+ generate_kwargs = dict(
28
+ model_inputs,
29
+ streamer=streamer,
30
+ max_new_tokens=max_new_tokens,
31
+ do_sample=True,
32
+ top_p=top_p,
33
+ temperature=float(temperature),
34
+ top_k=top_k
35
+ )
36
+ t = Thread(target=model.generate, kwargs=generate_kwargs)
37
+ t.start()
38
+
39
+ # Pull the generated text from the streamer, and update the model output.
40
+ model_output = ""
41
+ for new_text in streamer:
42
+ model_output += new_text
43
+ yield model_output
44
+ return model_output
45
+
46
+
47
+ def reset_textbox():
48
+ return gr.update(value='')
49
+
50
+
51
+ with gr.Blocks() as demo:
52
+ duplicate_link = "https://huggingface.co/spaces/joaogante/transformers_streaming?duplicate=true"
53
+ gr.Markdown(
54
+ "# 🤗 Transformers 🔥Streaming🔥 on Gradio\n"
55
+ "This demo showcases the use of the "
56
+ "[streaming feature](https://huggingface.co/docs/transformers/main/en/generation_strategies#streaming) "
57
+ "of 🤗 Transformers with Gradio to generate text in real-time. It uses "
58
+ f"[{model_id}](https://huggingface.co/{model_id}) and the Spaces free compute tier.\n\n"
59
+ f"Feel free to [duplicate this Space]({duplicate_link}) to try your own models or use this space as a "
60
+ "template! 💛"
61
+ )
62
+
63
+ with gr.Row():
64
+ with gr.Column(scale=4):
65
+ user_text = gr.Textbox(
66
+ placeholder="Write an email about an alpaca that likes flan",
67
+ label="User input"
68
+ )
69
+ model_output = gr.Textbox(label="Model output", lines=10, interactive=False)
70
+ button_submit = gr.Button(value="Submit")
71
+
72
+ with gr.Column(scale=1):
73
+ max_new_tokens = gr.Slider(
74
+ minimum=1, maximum=1000, value=250, step=1, interactive=True, label="Max New Tokens",
75
+ )
76
+ top_p = gr.Slider(
77
+ minimum=0.05, maximum=1.0, value=0.95, step=0.05, interactive=True, label="Top-p (nucleus sampling)",
78
+ )
79
+ top_k = gr.Slider(
80
+ minimum=1, maximum=50, value=50, step=1, interactive=True, label="Top-k",
81
+ )
82
+ temperature = gr.Slider(
83
+ minimum=0.1, maximum=5.0, value=0.8, step=0.1, interactive=True, label="Temperature",
84
+ )
85
+
86
+ user_text.submit(run_generation, [user_text, top_p, temperature, top_k, max_new_tokens], model_output)
87
+ button_submit.click(run_generation, [user_text, top_p, temperature, top_k, max_new_tokens], model_output)
88
+
89
+ demo.queue(max_size=32).launch(enable_queue=True)
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ accelerate
2
+ bitsandbytes
3
+ torch
4
+ git+https://github.com/huggingface/transformers.git # transformers from main (TextIteratorStreamer will be added in v4.28)