Shitao commited on
Commit
44bc074
·
1 Parent(s): 200a130
OmniGen/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (325 Bytes). View file
 
OmniGen/__pycache__/model.cpython-310.pyc ADDED
Binary file (12.5 kB). View file
 
OmniGen/__pycache__/pipeline.cpython-310.pyc ADDED
Binary file (8.5 kB). View file
 
OmniGen/__pycache__/processor.cpython-310.pyc ADDED
Binary file (11.2 kB). View file
 
OmniGen/__pycache__/scheduler.cpython-310.pyc ADDED
Binary file (2.75 kB). View file
 
OmniGen/__pycache__/transformer.cpython-310.pyc ADDED
Binary file (3.95 kB). View file
 
OmniGen/__pycache__/utils.cpython-310.pyc ADDED
Binary file (3.52 kB). View file
 
OmniGen/model.py CHANGED
@@ -5,7 +5,10 @@ import torch.nn as nn
5
  import numpy as np
6
  import math
7
  from typing import Dict
 
 
8
  from timm.models.vision_transformer import PatchEmbed, Attention, Mlp
 
9
 
10
  from OmniGen.transformer import Phi3Config, Phi3Transformer
11
 
@@ -145,7 +148,7 @@ class PatchEmbedMR(nn.Module):
145
  return x
146
 
147
 
148
- class OmniGen(nn.Module):
149
  """
150
  Diffusion model with a Transformer backbone.
151
  """
@@ -191,7 +194,7 @@ class OmniGen(nn.Module):
191
  ignore_patterns=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5'])
192
  config = Phi3Config.from_pretrained(model_name)
193
  model = cls(config)
194
- ckpt = torch.load(os.path.join(model_name, 'model.pt'))
195
  model.load_state_dict(ckpt)
196
  return model
197
 
@@ -304,7 +307,7 @@ class OmniGen(nn.Module):
304
  return latents, num_tokens, shapes
305
 
306
 
307
- def forward(self, x, timestep, text_ids, pixel_values, image_sizes, attention_mask, position_ids, padding_latent=None, past_key_values=None):
308
  """
309
 
310
  """
@@ -312,16 +315,16 @@ class OmniGen(nn.Module):
312
  x, num_tokens, shapes = self.patch_multiple_resolutions(x, padding_latent)
313
  time_token = self.time_token(timestep, dtype=x[0].dtype).unsqueeze(1)
314
 
315
- if pixel_values is not None:
316
- input_latents, _, _ = self.patch_multiple_resolutions(pixel_values, is_input_images=True)
317
- if text_ids is not None:
318
- condition_embeds = self.llm.embed_tokens(text_ids)
319
  input_img_inx = 0
320
- for b_inx in image_sizes.keys():
321
- for start_inx, end_inx in image_sizes[b_inx]:
322
  condition_embeds[b_inx, start_inx: end_inx] = input_latents[input_img_inx]
323
  input_img_inx += 1
324
- if pixel_values is not None:
325
  assert input_img_inx == len(input_latents)
326
 
327
  input_emb = torch.cat([condition_embeds, time_token, x], dim=1)
@@ -344,7 +347,9 @@ class OmniGen(nn.Module):
344
  x = self.final_layer(image_embedding, time_emb)
345
  latents = self.unpatchify(x, shapes[0], shapes[1])
346
 
347
- return latents, past_key_values
 
 
348
 
349
  @torch.no_grad()
350
  def forward_with_cfg(self, x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, cfg_scale, use_img_cfg, img_cfg_scale, past_key_values, use_kv_cache):
 
5
  import numpy as np
6
  import math
7
  from typing import Dict
8
+
9
+ from diffusers.loaders import PeftAdapterMixin
10
  from timm.models.vision_transformer import PatchEmbed, Attention, Mlp
11
+ from huggingface_hub import snapshot_download
12
 
13
  from OmniGen.transformer import Phi3Config, Phi3Transformer
14
 
 
148
  return x
149
 
150
 
151
+ class OmniGen(nn.Module, PeftAdapterMixin):
152
  """
153
  Diffusion model with a Transformer backbone.
154
  """
 
194
  ignore_patterns=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5'])
195
  config = Phi3Config.from_pretrained(model_name)
196
  model = cls(config)
197
+ ckpt = torch.load(os.path.join(model_name, 'model.pt'), map_location='cpu')
198
  model.load_state_dict(ckpt)
199
  return model
200
 
 
307
  return latents, num_tokens, shapes
308
 
309
 
310
+ def forward(self, x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, padding_latent=None, past_key_values=None, return_past_key_values=True):
311
  """
312
 
313
  """
 
315
  x, num_tokens, shapes = self.patch_multiple_resolutions(x, padding_latent)
316
  time_token = self.time_token(timestep, dtype=x[0].dtype).unsqueeze(1)
317
 
318
+ if input_img_latents is not None:
319
+ input_latents, _, _ = self.patch_multiple_resolutions(input_img_latents, is_input_images=True)
320
+ if input_ids is not None:
321
+ condition_embeds = self.llm.embed_tokens(input_ids).clone()
322
  input_img_inx = 0
323
+ for b_inx in input_image_sizes.keys():
324
+ for start_inx, end_inx in input_image_sizes[b_inx]:
325
  condition_embeds[b_inx, start_inx: end_inx] = input_latents[input_img_inx]
326
  input_img_inx += 1
327
+ if input_img_latents is not None:
328
  assert input_img_inx == len(input_latents)
329
 
330
  input_emb = torch.cat([condition_embeds, time_token, x], dim=1)
 
347
  x = self.final_layer(image_embedding, time_emb)
348
  latents = self.unpatchify(x, shapes[0], shapes[1])
349
 
350
+ if past_key_values:
351
+ return latents, past_key_values
352
+ return latents
353
 
354
  @torch.no_grad()
355
  def forward_with_cfg(self, x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, cfg_scale, use_img_cfg, img_cfg_scale, past_key_values, use_kv_cache):
OmniGen/pipeline.py CHANGED
@@ -6,6 +6,7 @@ from PIL import Image
6
  import numpy as np
7
  import torch
8
  from huggingface_hub import snapshot_download
 
9
  from diffusers.models import AutoencoderKL
10
  from diffusers.utils import (
11
  USE_PEFT_BACKEND,
@@ -31,7 +32,7 @@ EXAMPLE_DOC_STRING = """
31
  >>> prompt = "A woman holds a bouquet of flowers and faces the camera"
32
  >>> image = pipe(
33
  ... prompt,
34
- ... guidance_scale=1.0,
35
  ... num_inference_steps=50,
36
  ... ).images[0]
37
  >>> image.save("t2i.png")
@@ -53,23 +54,42 @@ class OmniGenPipeline:
53
 
54
  self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
55
  self.model.to(self.device)
 
56
  self.vae.to(self.device)
57
 
58
  @classmethod
59
- def from_pretrained(cls, model_name):
60
  if not os.path.exists(model_name):
 
61
  cache_folder = os.getenv('HF_HUB_CACHE')
62
- print(cache_folder)
63
  model_name = snapshot_download(repo_id=model_name,
64
  cache_dir=cache_folder,
65
  ignore_patterns=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5'])
66
  logger.info(f"Downloaded model to {model_name}")
67
  model = OmniGen.from_pretrained(model_name)
68
  processor = OmniGenProcessor.from_pretrained(model_name)
69
- vae = AutoencoderKL.from_pretrained(os.path.join(model_name, "vae"))
 
 
 
 
 
 
 
70
 
71
  return cls(vae, model, processor)
72
 
 
 
 
 
 
 
 
 
 
 
 
73
  def vae_encode(self, x, dtype):
74
  if self.vae.config.shift_factor is not None:
75
  x = self.vae.encode(x).latent_dist.sample()
@@ -100,6 +120,7 @@ class OmniGenPipeline:
100
  separate_cfg_infer: bool = False,
101
  use_kv_cache: bool = True,
102
  dtype: torch.dtype = torch.bfloat16,
 
103
  ):
104
  r"""
105
  Function invoked when calling the pipeline for generation.
@@ -128,15 +149,18 @@ class OmniGenPipeline:
128
  separate_cfg_infer (`bool`, *optional*, defaults to False):
129
  Perform inference on images with different guidance separately; this can save memory when generating images of large size at the expense of slower inference.
130
  use_kv_cache (`bool`, *optional*, defaults to True): enable kv cache to speed up the inference
131
-
 
 
132
  Examples:
133
 
134
  Returns:
135
  A list with the generated images.
136
  """
137
  assert height%16 == 0 and width%16 == 0
138
- if use_kv_cache and separate_cfg_infer:
139
- raise "Currently, don't support both use_kv_cache and separate_cfg_infer"
 
140
  if input_images is None:
141
  use_img_guidance = False
142
  if isinstance(prompt, str):
@@ -149,7 +173,11 @@ class OmniGenPipeline:
149
  num_cfg = 2 if use_img_guidance else 1
150
  latent_size_h, latent_size_w = height//8, width//8
151
 
152
- latents = torch.randn(num_prompt, 4, latent_size_h, latent_size_w, device=self.device)
 
 
 
 
153
  latents = torch.cat([latents]*(1+num_cfg), 0).to(dtype)
154
 
155
  input_img_latents = []
 
6
  import numpy as np
7
  import torch
8
  from huggingface_hub import snapshot_download
9
+ from peft import LoraConfig, PeftModel
10
  from diffusers.models import AutoencoderKL
11
  from diffusers.utils import (
12
  USE_PEFT_BACKEND,
 
32
  >>> prompt = "A woman holds a bouquet of flowers and faces the camera"
33
  >>> image = pipe(
34
  ... prompt,
35
+ ... guidance_scale=3.0,
36
  ... num_inference_steps=50,
37
  ... ).images[0]
38
  >>> image.save("t2i.png")
 
54
 
55
  self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
56
  self.model.to(self.device)
57
+ self.model.eval()
58
  self.vae.to(self.device)
59
 
60
  @classmethod
61
+ def from_pretrained(cls, model_name, vae_path: str=None):
62
  if not os.path.exists(model_name):
63
+ logger.info("Model not found, downloading...")
64
  cache_folder = os.getenv('HF_HUB_CACHE')
 
65
  model_name = snapshot_download(repo_id=model_name,
66
  cache_dir=cache_folder,
67
  ignore_patterns=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5'])
68
  logger.info(f"Downloaded model to {model_name}")
69
  model = OmniGen.from_pretrained(model_name)
70
  processor = OmniGenProcessor.from_pretrained(model_name)
71
+
72
+ if os.path.exists(os.path.join(model_name, "vae")):
73
+ vae = AutoencoderKL.from_pretrained(os.path.join(model_name, "vae"))
74
+ elif vae_path is not None:
75
+ vae = AutoencoderKL.from_pretrained(vae_path).to(device)
76
+ else:
77
+ logger.info(f"No VAE found in {model_name}, downloading stabilityai/sdxl-vae from HF")
78
+ vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae").to(device)
79
 
80
  return cls(vae, model, processor)
81
 
82
+ def merge_lora(self, lora_path: str):
83
+ model = PeftModel.from_pretrained(self.model, lora_path)
84
+ model.merge_and_unload()
85
+ self.model = model
86
+
87
+ def to(self, device: Union[str, torch.device]):
88
+ if isinstance(device, str):
89
+ device = torch.device(device)
90
+ self.model.to(device)
91
+ self.vae.to(device)
92
+
93
  def vae_encode(self, x, dtype):
94
  if self.vae.config.shift_factor is not None:
95
  x = self.vae.encode(x).latent_dist.sample()
 
120
  separate_cfg_infer: bool = False,
121
  use_kv_cache: bool = True,
122
  dtype: torch.dtype = torch.bfloat16,
123
+ seed: int = None,
124
  ):
125
  r"""
126
  Function invoked when calling the pipeline for generation.
 
149
  separate_cfg_infer (`bool`, *optional*, defaults to False):
150
  Perform inference on images with different guidance separately; this can save memory when generating images of large size at the expense of slower inference.
151
  use_kv_cache (`bool`, *optional*, defaults to True): enable kv cache to speed up the inference
152
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
153
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
154
+ to make generation deterministic.
155
  Examples:
156
 
157
  Returns:
158
  A list with the generated images.
159
  """
160
  assert height%16 == 0 and width%16 == 0
161
+ if separate_cfg_infer:
162
+ use_kv_cache = False
163
+ # raise "Currently, don't support both use_kv_cache and separate_cfg_infer"
164
  if input_images is None:
165
  use_img_guidance = False
166
  if isinstance(prompt, str):
 
173
  num_cfg = 2 if use_img_guidance else 1
174
  latent_size_h, latent_size_w = height//8, width//8
175
 
176
+ if seed is not None:
177
+ generator = torch.Generator(device=self.device).manual_seed(seed)
178
+ else:
179
+ generator = None
180
+ latents = torch.randn(num_prompt, 4, latent_size_h, latent_size_w, device=self.device, generator=generator)
181
  latents = torch.cat([latents]*(1+num_cfg), 0).to(dtype)
182
 
183
  input_img_latents = []
OmniGen/processor.py CHANGED
@@ -11,28 +11,15 @@ from torchvision import transforms
11
  from transformers import AutoTokenizer
12
  from huggingface_hub import snapshot_download
13
 
 
 
 
 
 
 
 
14
 
15
- def crop_arr(pil_image, max_image_size):
16
- while min(*pil_image.size) >= 2 * max_image_size:
17
- pil_image = pil_image.resize(
18
- tuple(x // 2 for x in pil_image.size), resample=Image.BOX
19
- )
20
 
21
- if max(*pil_image.size) > max_image_size:
22
- scale = max_image_size / max(*pil_image.size)
23
- pil_image = pil_image.resize(
24
- tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
25
- )
26
-
27
- arr = np.array(pil_image)
28
- crop_y1 = (arr.shape[0] % 16) // 2
29
- crop_y2 = arr.shape[0] % 16 - crop_y1
30
-
31
- crop_x1 = (arr.shape[1] % 16) // 2
32
- crop_x2 = arr.shape[1] % 16 - crop_x1
33
-
34
- arr = arr[crop_y1:arr.shape[0]-crop_y2, crop_x1:arr.shape[1]-crop_x2]
35
- return Image.fromarray(arr)
36
 
37
 
38
  class OmniGenProcessor:
@@ -68,6 +55,7 @@ class OmniGenProcessor:
68
  return self.image_transform(image)
69
 
70
  def process_multi_modal_prompt(self, text, input_images):
 
71
  if input_images is None or len(input_images) == 0:
72
  model_inputs = self.text_tokenizer(text)
73
  return {"input_ids": model_inputs.input_ids, "pixel_values": None, "image_sizes": None}
@@ -132,7 +120,6 @@ class OmniGenProcessor:
132
  for i in range(len(instructions)):
133
  cur_instruction = instructions[i]
134
  cur_input_images = None if input_images is None else input_images[i]
135
- cur_instruction = self.add_prefix_instruction(cur_instruction)
136
  if cur_input_images is not None and len(cur_input_images) > 0:
137
  cur_input_images = [self.process_image(x) for x in cur_input_images]
138
  else:
@@ -143,14 +130,13 @@ class OmniGenProcessor:
143
 
144
 
145
  neg_mllm_input, img_cfg_mllm_input = None, None
146
- neg_instruction = self.add_prefix_instruction(negative_prompt)
147
- neg_mllm_input = self.process_multi_modal_prompt(neg_instruction, None)
148
  if use_img_cfg:
149
  if cur_input_images is not None and len(cur_input_images) >= 1:
150
  img_cfg_prompt = [f"<img><|image_{i+1}|></img>" for i in range(len(cur_input_images))]
151
- img_cfg_mllm_input = self.process_multi_modal_prompt(self.add_prefix_instruction(" ".join(img_cfg_prompt)), cur_input_images)
152
  else:
153
- img_cfg_mllm_input = neg_instruction
154
 
155
  input_data.append((mllm_input, neg_mllm_input, img_cfg_mllm_input, [height, width]))
156
 
 
11
  from transformers import AutoTokenizer
12
  from huggingface_hub import snapshot_download
13
 
14
+ from OmniGen.utils import (
15
+ create_logger,
16
+ update_ema,
17
+ requires_grad,
18
+ center_crop_arr,
19
+ crop_arr,
20
+ )
21
 
 
 
 
 
 
22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
 
25
  class OmniGenProcessor:
 
55
  return self.image_transform(image)
56
 
57
  def process_multi_modal_prompt(self, text, input_images):
58
+ text = self.add_prefix_instruction(text)
59
  if input_images is None or len(input_images) == 0:
60
  model_inputs = self.text_tokenizer(text)
61
  return {"input_ids": model_inputs.input_ids, "pixel_values": None, "image_sizes": None}
 
120
  for i in range(len(instructions)):
121
  cur_instruction = instructions[i]
122
  cur_input_images = None if input_images is None else input_images[i]
 
123
  if cur_input_images is not None and len(cur_input_images) > 0:
124
  cur_input_images = [self.process_image(x) for x in cur_input_images]
125
  else:
 
130
 
131
 
132
  neg_mllm_input, img_cfg_mllm_input = None, None
133
+ neg_mllm_input = self.process_multi_modal_prompt(negative_prompt, None)
 
134
  if use_img_cfg:
135
  if cur_input_images is not None and len(cur_input_images) >= 1:
136
  img_cfg_prompt = [f"<img><|image_{i+1}|></img>" for i in range(len(cur_input_images))]
137
+ img_cfg_mllm_input = self.process_multi_modal_prompt(" ".join(img_cfg_prompt), cur_input_images)
138
  else:
139
+ img_cfg_mllm_input = neg_mllm_input
140
 
141
  input_data.append((mllm_input, neg_mllm_input, img_cfg_mllm_input, [height, width]))
142
 
OmniGen/scheduler.py CHANGED
@@ -1,6 +1,6 @@
1
  import torch
2
  from tqdm import tqdm
3
- from transformers.cache_utils import Cache, DynamicCache, OffloadedCache
4
 
5
  class OmniGenScheduler:
6
  def __init__(self, num_steps: int=50, time_shifting_factor: int=1):
 
1
  import torch
2
  from tqdm import tqdm
3
+ from transformers.cache_utils import Cache, DynamicCache
4
 
5
  class OmniGenScheduler:
6
  def __init__(self, num_steps: int=50, time_shifting_factor: int=1):
OmniGen/transformer.py CHANGED
@@ -16,7 +16,7 @@ from transformers.modeling_outputs import (
16
  )
17
  from transformers.modeling_utils import PreTrainedModel
18
  from transformers import Phi3Config, Phi3Model
19
- from transformers.cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache, OffloadedCache
20
  from transformers.utils import logging
21
 
22
  logger = logging.get_logger(__name__)
 
16
  )
17
  from transformers.modeling_utils import PreTrainedModel
18
  from transformers import Phi3Config, Phi3Model
19
+ from transformers.cache_utils import Cache, DynamicCache, StaticCache
20
  from transformers.utils import logging
21
 
22
  logger = logging.get_logger(__name__)
OmniGen/utils.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+
3
+ from PIL import Image
4
+ import torch
5
+ import numpy as np
6
+
7
+ def create_logger(logging_dir):
8
+ """
9
+ Create a logger that writes to a log file and stdout.
10
+ """
11
+ logging.basicConfig(
12
+ level=logging.INFO,
13
+ format='[\033[34m%(asctime)s\033[0m] %(message)s',
14
+ datefmt='%Y-%m-%d %H:%M:%S',
15
+ handlers=[logging.StreamHandler(), logging.FileHandler(f"{logging_dir}/log.txt")]
16
+ )
17
+ logger = logging.getLogger(__name__)
18
+ return logger
19
+
20
+
21
+ @torch.no_grad()
22
+ def update_ema(ema_model, model, decay=0.9999):
23
+ """
24
+ Step the EMA model towards the current model.
25
+ """
26
+ ema_params = dict(ema_model.named_parameters())
27
+ for name, param in model.named_parameters():
28
+ # TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed
29
+ ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay)
30
+
31
+
32
+
33
+
34
+ def requires_grad(model, flag=True):
35
+ """
36
+ Set requires_grad flag for all parameters in a model.
37
+ """
38
+ for p in model.parameters():
39
+ p.requires_grad = flag
40
+
41
+
42
+ def center_crop_arr(pil_image, image_size):
43
+ """
44
+ Center cropping implementation from ADM.
45
+ https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
46
+ """
47
+ while min(*pil_image.size) >= 2 * image_size:
48
+ pil_image = pil_image.resize(
49
+ tuple(x // 2 for x in pil_image.size), resample=Image.BOX
50
+ )
51
+
52
+ scale = image_size / min(*pil_image.size)
53
+ pil_image = pil_image.resize(
54
+ tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
55
+ )
56
+
57
+ arr = np.array(pil_image)
58
+ crop_y = (arr.shape[0] - image_size) // 2
59
+ crop_x = (arr.shape[1] - image_size) // 2
60
+ return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size])
61
+
62
+
63
+
64
+ def crop_arr(pil_image, max_image_size):
65
+ while min(*pil_image.size) >= 2 * max_image_size:
66
+ pil_image = pil_image.resize(
67
+ tuple(x // 2 for x in pil_image.size), resample=Image.BOX
68
+ )
69
+
70
+ if max(*pil_image.size) > max_image_size:
71
+ scale = max_image_size / max(*pil_image.size)
72
+ pil_image = pil_image.resize(
73
+ tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
74
+ )
75
+
76
+ if min(*pil_image.size) < 16:
77
+ scale = 16 / min(*pil_image.size)
78
+ pil_image = pil_image.resize(
79
+ tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
80
+ )
81
+
82
+ arr = np.array(pil_image)
83
+ crop_y1 = (arr.shape[0] % 16) // 2
84
+ crop_y2 = arr.shape[0] % 16 - crop_y1
85
+
86
+ crop_x1 = (arr.shape[1] % 16) // 2
87
+ crop_x2 = arr.shape[1] % 16 - crop_x1
88
+
89
+ arr = arr[crop_y1:arr.shape[0]-crop_y2, crop_x1:arr.shape[1]-crop_x2]
90
+ return Image.fromarray(arr)
91
+
92
+
93
+
94
+ def vae_encode(vae, x, weight_dtype):
95
+ if x is not None:
96
+ if vae.config.shift_factor is not None:
97
+ x = vae.encode(x).latent_dist.sample()
98
+ x = (x - vae.config.shift_factor) * vae.config.scaling_factor
99
+ else:
100
+ x = vae.encode(x).latent_dist.sample().mul_(vae.config.scaling_factor)
101
+ x = x.to(weight_dtype)
102
+ return x
103
+
104
+ def vae_encode_list(vae, x, weight_dtype):
105
+ latents = []
106
+ for img in x:
107
+ img = vae_encode(vae, img, weight_dtype)
108
+ latents.append(img)
109
+ return latents
110
+
app.py CHANGED
@@ -11,7 +11,7 @@ pipe = OmniGenPipeline.from_pretrained(
11
 
12
  @spaces.GPU
13
  # 示例处理函数:生成图像
14
- def generate_image(text, img1, img2, img3, height, width, guidance_scale, inference_steps):
15
  input_images = [img1, img2, img3]
16
  # 去除 None
17
  input_images = [img for img in input_images if img is not None]
@@ -28,6 +28,7 @@ def generate_image(text, img1, img2, img3, height, width, guidance_scale, infere
28
  num_inference_steps=inference_steps,
29
  separate_cfg_infer=True,
30
  use_kv_cache=False,
 
31
  )
32
  img = output[0]
33
  return img
@@ -54,6 +55,7 @@ def get_example():
54
  1024,
55
  3.0,
56
  20,
 
57
  ],
58
  [
59
  "Three zebras are standing side by side on a vibrant savannah, each showcasing unique patterns and characteristics that highlight their individuality. The zebra on the left has a strikingly bold black and white stripe pattern, with wider stripes that create a dramatic contrast against its sleek body. In the middle, the zebra features a more subtle stripe arrangement, with thinner stripes that blend seamlessly into a slightly sandy-colored coat, giving it a softer appearance. On the right, the zebra's stripes are more irregular, with a distinct patch of brown fur near its shoulder, adding a layer of uniqueness to its overall look. Together, these zebras create a captivating scene, each representing the diverse beauty of their species in the wild. The right zebras is the zebras from <img><|image_1|></img>. The center zebras is from <img><|image_2|></img>. The left zebras is the zebras from <img><|image_3|></img>.",
@@ -64,22 +66,23 @@ def get_example():
64
  1024,
65
  3.0,
66
  20,
 
67
  ],
68
  ]
69
  return case
70
 
71
- def run_for_examples(text, img1, img2, img3, height, width, guidance_scale, inference_steps):
72
- return generate_image(text, img1, img2, img3, height, width, guidance_scale, inference_steps)
73
 
74
 
75
  # Gradio 接口
76
  with gr.Blocks() as demo:
77
- gr.Markdown("## Text + Multiple Images to Image Generator")
78
  with gr.Row():
79
  with gr.Column():
80
  # 文本输入框
81
  prompt_input = gr.Textbox(
82
- label="Enter your prompt", placeholder="Type your prompt here..."
83
  )
84
 
85
  with gr.Row(equal_height=True):
@@ -105,6 +108,10 @@ with gr.Blocks() as demo:
105
  label="Inference Steps", minimum=1, maximum=50, value=50, step=1
106
  )
107
 
 
 
 
 
108
  # 生成按钮
109
  generate_button = gr.Button("Generate Image")
110
 
@@ -124,6 +131,7 @@ with gr.Blocks() as demo:
124
  width_input,
125
  guidance_scale_input,
126
  num_inference_steps,
 
127
  ],
128
  outputs=output_image,
129
  )
@@ -140,6 +148,7 @@ with gr.Blocks() as demo:
140
  width_input,
141
  guidance_scale_input,
142
  num_inference_steps,
 
143
  ],
144
  outputs=output_image,
145
  )
 
11
 
12
  @spaces.GPU
13
  # 示例处理函数:生成图像
14
+ def generate_image(text, img1, img2, img3, height, width, guidance_scale, inference_steps, seed):
15
  input_images = [img1, img2, img3]
16
  # 去除 None
17
  input_images = [img for img in input_images if img is not None]
 
28
  num_inference_steps=inference_steps,
29
  separate_cfg_infer=True,
30
  use_kv_cache=False,
31
+ seed=seed,
32
  )
33
  img = output[0]
34
  return img
 
55
  1024,
56
  3.0,
57
  20,
58
+ 42,
59
  ],
60
  [
61
  "Three zebras are standing side by side on a vibrant savannah, each showcasing unique patterns and characteristics that highlight their individuality. The zebra on the left has a strikingly bold black and white stripe pattern, with wider stripes that create a dramatic contrast against its sleek body. In the middle, the zebra features a more subtle stripe arrangement, with thinner stripes that blend seamlessly into a slightly sandy-colored coat, giving it a softer appearance. On the right, the zebra's stripes are more irregular, with a distinct patch of brown fur near its shoulder, adding a layer of uniqueness to its overall look. Together, these zebras create a captivating scene, each representing the diverse beauty of their species in the wild. The right zebras is the zebras from <img><|image_1|></img>. The center zebras is from <img><|image_2|></img>. The left zebras is the zebras from <img><|image_3|></img>.",
 
66
  1024,
67
  3.0,
68
  20,
69
+ 42,
70
  ],
71
  ]
72
  return case
73
 
74
+ def run_for_examples(text, img1, img2, img3, height, width, guidance_scale, inference_steps, seed):
75
+ return generate_image(text, img1, img2, img3, height, width, guidance_scale, inference_steps, seed)
76
 
77
 
78
  # Gradio 接口
79
  with gr.Blocks() as demo:
80
+ gr.Markdown("# OmniGen: Unified Image Generation")
81
  with gr.Row():
82
  with gr.Column():
83
  # 文本输入框
84
  prompt_input = gr.Textbox(
85
+ label="Enter your prompt, use <img><|image_i|></img> tokens for images", placeholder="Type your prompt here..."
86
  )
87
 
88
  with gr.Row(equal_height=True):
 
108
  label="Inference Steps", minimum=1, maximum=50, value=50, step=1
109
  )
110
 
111
+ seed_input = gr.Slider(
112
+ label="Seed", minimum=0, maximum=2147483647, value=42, step=1
113
+ )
114
+
115
  # 生成按钮
116
  generate_button = gr.Button("Generate Image")
117
 
 
131
  width_input,
132
  guidance_scale_input,
133
  num_inference_steps,
134
+ seed_input,
135
  ],
136
  outputs=output_image,
137
  )
 
148
  width_input,
149
  guidance_scale_input,
150
  num_inference_steps,
151
+ seed_input,
152
  ],
153
  outputs=output_image,
154
  )