gojiteji commited on
Commit
a8e3fd3
1 Parent(s): 384dde1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -37
app.py CHANGED
@@ -59,43 +59,6 @@ pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
59
  dtype=jnp.bfloat16,
60
  )
61
 
62
- prompts = ["apple"] * 1
63
-
64
-
65
-
66
-
67
-
68
-
69
-
70
-
71
-
72
- def generate_image(dense_class_vector=None, int_index=None, noise_seed_vector=None, truncation=0.4):
73
- seed = int(noise_seed_vector.sum().item()) if noise_seed_vector is not None else None
74
- noise_vector = truncated_noise_sample(truncation=truncation, batch_size=1, seed=seed)
75
- noise_vector = torch.from_numpy(noise_vector)
76
- if int_index is not None:
77
- class_vector = one_hot_from_int([int_index], batch_size=1)
78
- class_vector = torch.from_numpy(class_vector)
79
- dense_class_vector = gan_model.embeddings(class_vector)
80
- else:
81
- if isinstance(dense_class_vector, np.ndarray):
82
- dense_class_vector = torch.tensor(dense_class_vector)
83
- dense_class_vector = dense_class_vector.view(1, 128)
84
-
85
- input_vector = torch.cat([noise_vector, dense_class_vector], dim=1)
86
-
87
- # Generate an image
88
- with torch.no_grad():
89
- output = gan_model.generator(input_vector, truncation)
90
- output = output.cpu().numpy()
91
- output = output.transpose((0, 2, 3, 1))
92
- output = ((output + 1.0) / 2.0) * 256
93
- output.clip(0, 255, out=output)
94
- output = np.asarray(np.uint8(output[0]), dtype=np.uint8)
95
- return output
96
-
97
-
98
-
99
  def text_to_image(text):
100
  images = sd2_inference(pipeline, [text], params, seed = 42, num_inference_steps = 5 )
101
  img = images[0]
 
59
  dtype=jnp.bfloat16,
60
  )
61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  def text_to_image(text):
63
  images = sd2_inference(pipeline, [text], params, seed = 42, num_inference_steps = 5 )
64
  img = images[0]