hysts HF staff commited on
Commit
80597e4
·
1 Parent(s): 9c00f5c

Modify to work in Spaces

Browse files
Files changed (5) hide show
  1. README.md +1 -0
  2. app.py +14 -22
  3. model.py +72 -26
  4. packages.txt +0 -1
  5. requirements.txt +3 -6
README.md CHANGED
@@ -5,6 +5,7 @@ colorFrom: pink
5
  colorTo: red
6
  sdk: gradio
7
  sdk_version: 3.0.19
 
8
  app_file: app.py
9
  pinned: false
10
  ---
 
5
  colorTo: red
6
  sdk: gradio
7
  sdk_version: 3.0.19
8
+ python_version: 3.9.13
9
  app_file: app.py
10
  pinned: false
11
  ---
app.py CHANGED
@@ -2,25 +2,16 @@
2
 
3
  from __future__ import annotations
4
 
5
- import argparse
6
-
7
  import gradio as gr
8
 
9
  from model import AppModel
10
 
11
- DESCRIPTION = '''# CogView2 (text2image)
12
-
13
- This is an unofficial demo for <a href="https://github.com/THUDM/CogView2">https://github.com/THUDM/CogView2</a>.
14
-
15
- [This Space](https://huggingface.co/spaces/chinhon/translation_eng2ch) is used for translation from English to Chinese.
16
  '''
17
-
18
-
19
- def parse_args() -> argparse.Namespace:
20
- parser = argparse.ArgumentParser()
21
- parser.add_argument('--only-first-stage', action='store_true')
22
- parser.add_argument('--share', action='store_true')
23
- return parser.parse_args()
24
 
25
 
26
  def set_example_text(example: list) -> dict:
@@ -28,8 +19,9 @@ def set_example_text(example: list) -> dict:
28
 
29
 
30
  def main():
31
- args = parse_args()
32
- model = AppModel(args.only_first_stage)
 
33
 
34
  with gr.Blocks(css='style.css') as demo:
35
  gr.Markdown(DESCRIPTION)
@@ -59,8 +51,8 @@ def main():
59
  label='Seed')
60
  only_first_stage = gr.Checkbox(
61
  label='Only First Stage',
62
- value=args.only_first_stage,
63
- visible=not args.only_first_stage)
64
  num_images = gr.Slider(1,
65
  16,
66
  step=1,
@@ -80,6 +72,9 @@ def main():
80
  with gr.TabItem('Output (Gallery)'):
81
  result_gallery = gr.Gallery(show_label=False)
82
 
 
 
 
83
  run_button.click(fn=model.run_with_translation,
84
  inputs=[
85
  text,
@@ -98,10 +93,7 @@ def main():
98
  inputs=examples,
99
  outputs=examples.components)
100
 
101
- demo.launch(
102
- enable_queue=True,
103
- share=args.share,
104
- )
105
 
106
 
107
  if __name__ == '__main__':
 
2
 
3
  from __future__ import annotations
4
 
 
 
5
  import gradio as gr
6
 
7
  from model import AppModel
8
 
9
+ DESCRIPTION = '# <a href="https://github.com/THUDM/CogView2">CogView2</a> (text2image)'
10
+ NOTES = '''
11
+ - This app is adapted from <a href="https://github.com/hysts/CogView2_demo">https://github.com/hysts/CogView2_demo</a>. It would be recommended to use the repo if you want to run the app yourself.
12
+ - [This Space](https://huggingface.co/spaces/chinhon/translation_eng2ch) is used for translation from English to Chinese.
 
13
  '''
14
+ FOOTER = '<img id="visitor-badge" alt="visitor badge" src="https://visitor-badge.glitch.me/badge?page_id=THUDM.CogView2" />'
 
 
 
 
 
 
15
 
16
 
17
  def set_example_text(example: list) -> dict:
 
19
 
20
 
21
  def main():
22
+ only_first_stage = True
23
+ max_inference_batch_size = 4
24
+ model = AppModel(max_inference_batch_size, only_first_stage)
25
 
26
  with gr.Blocks(css='style.css') as demo:
27
  gr.Markdown(DESCRIPTION)
 
51
  label='Seed')
52
  only_first_stage = gr.Checkbox(
53
  label='Only First Stage',
54
+ value=only_first_stage,
55
+ visible=not only_first_stage)
56
  num_images = gr.Slider(1,
57
  16,
58
  step=1,
 
72
  with gr.TabItem('Output (Gallery)'):
73
  result_gallery = gr.Gallery(show_label=False)
74
 
75
+ gr.Markdown(NOTES)
76
+ gr.Markdown(FOOTER)
77
+
78
  run_button.click(fn=model.run_with_translation,
79
  inputs=[
80
  text,
 
93
  inputs=examples,
94
  outputs=examples.components)
95
 
96
+ demo.launch(enable_queue=True)
 
 
 
97
 
98
 
99
  if __name__ == '__main__':
model.py CHANGED
@@ -1,19 +1,68 @@
1
- #This code is adapted from https://github.com/THUDM/CogView2/blob/4e55cce981eb94b9c8c1f19ba9f632fd3ee42ba8/cogview2_text2image.py
2
 
3
  from __future__ import annotations
4
 
5
  import argparse
6
  import functools
7
  import logging
 
8
  import pathlib
 
9
  import sys
10
  import time
 
11
  from typing import Any
12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  import gradio as gr
14
  import numpy as np
15
  import torch
16
- from icetk import IceTokenizer
17
  from SwissArmyTransformer import get_args
18
  from SwissArmyTransformer.arguments import set_random_seed
19
  from SwissArmyTransformer.generation.autoregressive_sampling import \
@@ -38,7 +87,8 @@ logger.setLevel(logging.DEBUG)
38
  logger.propagate = False
39
  logger.addHandler(stream_handler)
40
 
41
- ICETK_MODEL_DIR = app_dir / 'icetk_models'
 
42
 
43
 
44
  def get_masks_and_position_ids_coglm(
@@ -140,11 +190,12 @@ def get_default_args() -> argparse.Namespace:
140
 
141
 
142
  class Model:
143
- def __init__(self, only_first_stage: bool = False):
 
 
144
  self.args = get_default_args()
145
  self.args.only_first_stage = only_first_stage
146
-
147
- self.tokenizer = self.load_tokenizer()
148
 
149
  self.model, self.args = self.load_model()
150
  self.strategy = self.load_strategy()
@@ -157,19 +208,6 @@ class Model:
157
  self.max_batch_size = self.args.max_inference_batch_size
158
  self.only_first_stage = self.args.only_first_stage
159
 
160
- def load_tokenizer(self) -> IceTokenizer:
161
- logger.info('--- load_tokenizer ---')
162
- start = time.perf_counter()
163
-
164
- tokenizer = IceTokenizer(ICETK_MODEL_DIR.as_posix())
165
- tokenizer.add_special_tokens(
166
- ['<start_of_image>', '<start_of_english>', '<start_of_chinese>'])
167
-
168
- elapsed = time.perf_counter() - start
169
- logger.info(f'Elapsed: {elapsed}')
170
- logger.info('--- done ---')
171
- return tokenizer
172
-
173
  def load_model(self) -> tuple[InferenceModel, argparse.Namespace]:
174
  logger.info('--- load_model ---')
175
  start = time.perf_counter()
@@ -185,7 +223,7 @@ class Model:
185
  logger.info('--- load_strategy ---')
186
  start = time.perf_counter()
187
 
188
- invalid_slices = [slice(self.tokenizer.num_image_tokens, None)]
189
  strategy = CoglmStrategy(invalid_slices,
190
  temperature=self.args.temp_all_gen,
191
  top_k=self.args.topk_gen,
@@ -213,6 +251,7 @@ class Model:
213
  logger.info('--- update_style ---')
214
  start = time.perf_counter()
215
 
 
216
  self.args = argparse.Namespace(**(vars(self.args) | get_recipe(style)))
217
  self.query_template = self.args.query_template
218
  logger.info(f'{self.query_template=}')
@@ -233,14 +272,21 @@ class Model:
233
 
234
  def run(self, text: str, style: str, seed: int, only_first_stage: bool,
235
  num: int) -> list[np.ndarray] | None:
 
 
 
 
236
  set_random_seed(seed)
237
  seq, txt_len = self.preprocess_text(text)
238
  if seq is None:
239
  return None
240
- self.update_style(style)
241
  self.only_first_stage = only_first_stage
242
  tokens = self.generate_tokens(seq, txt_len, num)
243
  res = self.generate_images(seq, txt_len, tokens)
 
 
 
 
244
  return res
245
 
246
  @torch.inference_mode()
@@ -251,7 +297,7 @@ class Model:
251
 
252
  text = self.query_template.format(text)
253
  logger.info(f'{text=}')
254
- seq = self.tokenizer.encode(text)
255
  logger.info(f'{len(seq)=}')
256
  if len(seq) > 110:
257
  logger.info('The input text is too long.')
@@ -319,7 +365,7 @@ class Model:
319
  if self.only_first_stage:
320
  for i in range(len(tokens)):
321
  seq = tokens[i]
322
- decoded_img = self.tokenizer.decode(image_ids=seq[-400:])
323
  decoded_img = torch.nn.functional.interpolate(decoded_img,
324
  size=(480, 480))
325
  decoded_img = self.postprocess(decoded_img[0])
@@ -327,7 +373,7 @@ class Model:
327
  else: # sr
328
  iter_tokens = self.srg.sr_base(tokens[:, -400:], seq[:txt_len])
329
  for seq in iter_tokens:
330
- decoded_img = self.tokenizer.decode(image_ids=seq[-3600:])
331
  decoded_img = torch.nn.functional.interpolate(decoded_img,
332
  size=(480, 480))
333
  decoded_img = self.postprocess(decoded_img[0])
@@ -340,8 +386,8 @@ class Model:
340
 
341
 
342
  class AppModel(Model):
343
- def __init__(self, only_first_stage: bool):
344
- super().__init__(only_first_stage)
345
  self.translator = gr.Interface.load(
346
  'spaces/chinhon/translation_eng2ch')
347
 
 
1
+ # This code is adapted from https://github.com/THUDM/CogView2/blob/4e55cce981eb94b9c8c1f19ba9f632fd3ee42ba8/cogview2_text2image.py
2
 
3
  from __future__ import annotations
4
 
5
  import argparse
6
  import functools
7
  import logging
8
+ import os
9
  import pathlib
10
+ import subprocess
11
  import sys
12
  import time
13
+ import zipfile
14
  from typing import Any
15
 
16
+ if os.getenv('SYSTEM') == 'spaces':
17
+ subprocess.run('pip install icetk==0.0.3'.split())
18
+ subprocess.run('pip install SwissArmyTransformer==0.2.4'.split())
19
+ subprocess.run(
20
+ 'pip install git+https://github.com/Sleepychord/Image-Local-Attention@43fee31'
21
+ .split())
22
+ subprocess.run('git clone https://github.com/NVIDIA/apex'.split())
23
+ subprocess.run('git checkout 1403c21'.split(), cwd='apex')
24
+ subprocess.run(
25
+ 'pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./'
26
+ .split(),
27
+ cwd='apex')
28
+ subprocess.run('rm -rf apex'.split())
29
+ with open('patch') as f:
30
+ subprocess.run('patch -p1'.split(), cwd='CogView2', stdin=f)
31
+
32
+ from huggingface_hub import hf_hub_download
33
+
34
+ def download_and_extract_icetk_models() -> None:
35
+ icetk_model_dir = pathlib.Path('/home/user/.icetk_models')
36
+ icetk_model_dir.mkdir()
37
+ path = hf_hub_download('THUDM/icetk',
38
+ 'models.zip',
39
+ use_auth_token=os.getenv('HF_TOKEN'))
40
+ with zipfile.ZipFile(path) as f:
41
+ f.extractall(path=icetk_model_dir.as_posix())
42
+
43
+ def download_and_extract_cogview2_models(name: str) -> None:
44
+ path = hf_hub_download('THUDM/CogView2',
45
+ name,
46
+ use_auth_token=os.getenv('HF_TOKEN'))
47
+ with zipfile.ZipFile(path) as f:
48
+ f.extractall()
49
+ os.remove(path)
50
+
51
+ download_and_extract_icetk_models()
52
+ names = [
53
+ 'coglm.zip',
54
+ 'cogview2-dsr.zip',
55
+ #'cogview2-itersr.zip',
56
+ ]
57
+ for name in names:
58
+ download_and_extract_cogview2_models(name)
59
+
60
+ os.environ['SAT_HOME'] = '/home/user/app/sharefs/cogview-new'
61
+
62
  import gradio as gr
63
  import numpy as np
64
  import torch
65
+ from icetk import icetk as tokenizer
66
  from SwissArmyTransformer import get_args
67
  from SwissArmyTransformer.arguments import set_random_seed
68
  from SwissArmyTransformer.generation.autoregressive_sampling import \
 
87
  logger.propagate = False
88
  logger.addHandler(stream_handler)
89
 
90
+ tokenizer.add_special_tokens(
91
+ ['<start_of_image>', '<start_of_english>', '<start_of_chinese>'])
92
 
93
 
94
  def get_masks_and_position_ids_coglm(
 
190
 
191
 
192
  class Model:
193
+ def __init__(self,
194
+ max_inference_batch_size: int,
195
+ only_first_stage: bool = False):
196
  self.args = get_default_args()
197
  self.args.only_first_stage = only_first_stage
198
+ self.args.max_inference_batch_size = max_inference_batch_size
 
199
 
200
  self.model, self.args = self.load_model()
201
  self.strategy = self.load_strategy()
 
208
  self.max_batch_size = self.args.max_inference_batch_size
209
  self.only_first_stage = self.args.only_first_stage
210
 
 
 
 
 
 
 
 
 
 
 
 
 
 
211
  def load_model(self) -> tuple[InferenceModel, argparse.Namespace]:
212
  logger.info('--- load_model ---')
213
  start = time.perf_counter()
 
223
  logger.info('--- load_strategy ---')
224
  start = time.perf_counter()
225
 
226
+ invalid_slices = [slice(tokenizer.num_image_tokens, None)]
227
  strategy = CoglmStrategy(invalid_slices,
228
  temperature=self.args.temp_all_gen,
229
  top_k=self.args.topk_gen,
 
251
  logger.info('--- update_style ---')
252
  start = time.perf_counter()
253
 
254
+ self.style = style
255
  self.args = argparse.Namespace(**(vars(self.args) | get_recipe(style)))
256
  self.query_template = self.args.query_template
257
  logger.info(f'{self.query_template=}')
 
272
 
273
  def run(self, text: str, style: str, seed: int, only_first_stage: bool,
274
  num: int) -> list[np.ndarray] | None:
275
+ logger.info('==================== run ====================')
276
+ start = time.perf_counter()
277
+
278
+ self.update_style(style)
279
  set_random_seed(seed)
280
  seq, txt_len = self.preprocess_text(text)
281
  if seq is None:
282
  return None
 
283
  self.only_first_stage = only_first_stage
284
  tokens = self.generate_tokens(seq, txt_len, num)
285
  res = self.generate_images(seq, txt_len, tokens)
286
+
287
+ elapsed = time.perf_counter() - start
288
+ logger.info(f'Elapsed: {elapsed}')
289
+ logger.info('==================== done ====================')
290
  return res
291
 
292
  @torch.inference_mode()
 
297
 
298
  text = self.query_template.format(text)
299
  logger.info(f'{text=}')
300
+ seq = tokenizer.encode(text)
301
  logger.info(f'{len(seq)=}')
302
  if len(seq) > 110:
303
  logger.info('The input text is too long.')
 
365
  if self.only_first_stage:
366
  for i in range(len(tokens)):
367
  seq = tokens[i]
368
+ decoded_img = tokenizer.decode(image_ids=seq[-400:])
369
  decoded_img = torch.nn.functional.interpolate(decoded_img,
370
  size=(480, 480))
371
  decoded_img = self.postprocess(decoded_img[0])
 
373
  else: # sr
374
  iter_tokens = self.srg.sr_base(tokens[:, -400:], seq[:txt_len])
375
  for seq in iter_tokens:
376
+ decoded_img = tokenizer.decode(image_ids=seq[-3600:])
377
  decoded_img = torch.nn.functional.interpolate(decoded_img,
378
  size=(480, 480))
379
  decoded_img = self.postprocess(decoded_img[0])
 
386
 
387
 
388
  class AppModel(Model):
389
+ def __init__(self, max_inference_batch_size: int, only_first_stage: bool):
390
+ super().__init__(max_inference_batch_size, only_first_stage)
391
  self.translator = gr.Interface.load(
392
  'spaces/chinhon/translation_eng2ch')
393
 
packages.txt DELETED
@@ -1 +0,0 @@
1
- p7zip-full
 
 
requirements.txt CHANGED
@@ -1,7 +1,4 @@
1
- git+https://github.com/Sleepychord/Image-Local-Attention@43fee31
2
- gradio==3.0.17
3
- icetk==0.0.3
4
  numpy==1.22.4
5
- SwissArmyTransformer==0.2.4
6
- torch==1.11.0
7
- torchvision==0.12.0
 
1
+ --extra-index-url https://download.pytorch.org/whl/cu113
 
 
2
  numpy==1.22.4
3
+ torch==1.11.0+cu113
4
+ torchvision==0.12.0+cu113