niulx committed on
Commit
5c88297
·
verified ·
1 Parent(s): cc94ad7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +113 -113
app.py CHANGED
@@ -188,132 +188,132 @@ with gr.Blocks() as demo:
188
  with gr.Row():
189
  gr.Markdown("""# D-Edit""")
190
 
191
- if 1:
192
- with gr.Row():
193
- with gr.Column():
194
- canvas = gr.Image(value = None, type="numpy", label="Show Mask", show_label=True, height=LENGTH, width=LENGTH, interactive=True)
195
- example_inps = [['./img.png'],['./img2.png'],['./img3.png'],['./img4.png']]
196
- gr.Examples(examples=example_inps, inputs=[canvas],
197
- label='examples', cache_examples='lazy', outputs=[],
198
- fn=change_image)
199
- gr.Markdown(f"Each image must first undergo segmentation. Afterwards, you can modify the \n mask ID and the prompt for image editing, then proceed with the editing process. \n The link of D-edit paper: [https://arxiv.org/abs/2403.04880v2](https://arxiv.org/abs/2403.04880v2), [https://huggingface.co/papers/2403.04880](https://huggingface.co/papers/2403.04880)")
200
 
201
- with gr.Column():
202
- result_info0 = gr.Text(label="Response")
203
- segment_button = gr.Button("Step 1. Run segmentation")
204
- flag = gr.State(False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
205
 
206
- # mask_np_list_updated.value = copy.deepcopy(mask_np_list.value) #!!
207
- mask_np_list_updated = mask_np_list
208
- gr.Markdown("""<p style="text-align: center; font-size: 20px">Edit Mask (Do not change it during the editing process)</p>""")
209
- slider = gr.Slider(0, 20, step=1, label = 'mask id', visible=False)
210
- label = gr.Text(label='label')
211
 
 
212
 
 
 
 
 
 
 
 
213
 
214
-
215
- result_info = gr.Text(label="Response")
216
 
217
- opt_flag = gr.State(0)
218
- gr.Markdown("""<p style="text-align: center; font-size: 20px">Optimization settings</p>""")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
219
  with gr.Accordion(label="Advanced settings", open=False):
220
- num_tokens = gr.Number(value="5", label="num tokens to represent each object", interactive= True)
221
- num_tokens_global = num_tokens
222
- embedding_learning_rate = gr.Textbox(value="0.00025", label="Embedding optimization: Learning rate", interactive= True )
223
- max_emb_train_steps = gr.Number(value="6", label="embedding optimization: Training steps", interactive= True )
224
-
225
- diffusion_model_learning_rate = gr.Textbox(value="0.0002", label="UNet Optimization: Learning rate", interactive= True )
226
- max_diffusion_train_steps = gr.Number(value="28", label="UNet Optimization: Learning rate: Training steps", interactive= True )
227
-
228
- train_batch_size = gr.Number(value="20", label="Batch size", interactive= True )
229
- gradient_accumulation_steps=gr.Number(value="2", label="Gradient accumulation", interactive= True )
230
 
231
- def run_optimization_wrapper (
 
232
  mask_np_list,
233
  mask_label_list,
234
  image,
235
- opt_flag,
236
  num_tokens,
237
- embedding_learning_rate ,
238
- max_emb_train_steps ,
239
- diffusion_model_learning_rate ,
240
- max_diffusion_train_steps,
241
- train_batch_size,
242
- gradient_accumulation_steps,
243
  ):
244
- try:
245
- run_optimization = partial(
246
- run_main,
247
- mask_np_list=mask_np_list,
248
- mask_label_list=mask_label_list,
249
- image_gt=np.array(image),
250
- num_tokens=int(num_tokens),
251
- embedding_learning_rate = float(embedding_learning_rate),
252
- max_emb_train_steps = int(max_emb_train_steps),
253
- diffusion_model_learning_rate= float(diffusion_model_learning_rate),
254
- max_diffusion_train_steps = int(max_diffusion_train_steps),
255
- train_batch_size=int(train_batch_size),
256
- gradient_accumulation_steps=int(gradient_accumulation_steps)
257
- )
258
- run_optimization()
259
- gr.Info("Optimization Finished! Move to the next step.")
260
- return "Optimization finished! Move to the next step."#,gr.Button("Step 3. Run Editing",interactive = True)
261
- except Exception as e:
262
- print(e)
263
- gr.Error("e")
264
- return "Error: use a smaller batch size or try latter."#,gr.Button("Step 3. Run Editing",interactive = False)
265
-
266
-
267
-
268
-
269
- with gr.Row():
270
- with gr.Column():
271
- canvas_text_edit = gr.Image(value = None, type = "pil", label="Editing results", show_label=True,visible = True)
272
- # canvas_text_edit = gr.Gallery(label = "Edited results")
273
-
274
- with gr.Column():
275
- gr.Markdown("""<p style="text-align: center; font-size: 20px">Editing setting</p>""")
276
- tgt_prompt = gr.Textbox(value="text prompt", label="Editing: Text prompt", interactive= True )
277
- with gr.Accordion(label="Advanced settings", open=False):
278
- slider2 = gr.Slider(0, 20, step=1, label = 'mask id', visible=False)
279
- guidance_scale = gr.Textbox(value="5", label="Editing: CFG guidance scale", interactive= True )
280
- num_sampling_steps = gr.Number(value="20", label="Editing: Sampling steps", interactive= True )
281
- edge_thickness = gr.Number(value="10", label="Editing: Edge thickness", interactive= True )
282
- strength = gr.Textbox(value="0.5", label="Editing: Mask strength", interactive= True )
283
-
284
- add_button = gr.Button("Step 2. Run Editing",interactive = True)
285
- def run_edit_text_wrapper(
286
- mask_np_list,
287
- mask_label_list,
288
- image,
289
- num_tokens,
290
- guidance_scale,
291
- num_sampling_steps ,
292
- strength ,
293
- edge_thickness,
294
- tgt_prompt ,
295
- tgt_index
296
- ):
297
-
298
- run_edit_text = partial(
299
- run_main,
300
- mask_np_list=mask_np_list,
301
- mask_label_list=mask_label_list,
302
- image_gt=np.array(image),
303
- load_trained=True,
304
- text=True,
305
- num_tokens = int(num_tokens_global.value),
306
- guidance_scale = float(guidance_scale),
307
- num_sampling_steps = int(num_sampling_steps),
308
- strength = float(strength),
309
- edge_thickness = int(edge_thickness),
310
- num_imgs = 1,
311
- tgt_prompt = tgt_prompt,
312
- tgt_index = int(tgt_index)
313
- )
314
- run_edit_text()
315
- gr.Info('Image editing completed.')
316
- return load_pil_img()
317
 
318
 
319
 
 
188
  with gr.Row():
189
  gr.Markdown("""# D-Edit""")
190
 
 
 
 
 
 
 
 
 
 
191
 
192
+ with gr.Row():
193
+ with gr.Column():
194
+ canvas = gr.Image(value = None, type="numpy", label="Show Mask", show_label=True, height=LENGTH, width=LENGTH, interactive=True)
195
+ example_inps = [['./img.png'],['./img2.png'],['./img3.png'],['./img4.png']]
196
+ gr.Examples(examples=example_inps, inputs=[canvas],
197
+ label='examples', cache_examples='lazy', outputs=[],
198
+ fn=change_image)
199
+ gr.Markdown(f"Each image must first undergo segmentation. Afterwards, you can modify the \n mask ID and the prompt for image editing, then proceed with the editing process. \n The link of D-edit paper: [https://arxiv.org/abs/2403.04880v2](https://arxiv.org/abs/2403.04880v2), [https://huggingface.co/papers/2403.04880](https://huggingface.co/papers/2403.04880)")
200
+
201
+ with gr.Column():
202
+ result_info0 = gr.Text(label="Response")
203
+ segment_button = gr.Button("Step 1. Run segmentation")
204
+ flag = gr.State(False)
205
+
206
+ # mask_np_list_updated.value = copy.deepcopy(mask_np_list.value) #!!
207
+ mask_np_list_updated = mask_np_list
208
+ gr.Markdown("""<p style="text-align: center; font-size: 20px">Edit Mask (Do not change it during the editing process)</p>""")
209
+ slider = gr.Slider(0, 20, step=1, label = 'mask id', visible=False)
210
+ label = gr.Text(label='label')
211
 
212
+
213
+
 
 
 
214
 
215
+ result_info = gr.Text(label="Response")
216
 
217
+ opt_flag = gr.State(0)
218
+ gr.Markdown("""<p style="text-align: center; font-size: 20px">Optimization settings</p>""")
219
+ with gr.Accordion(label="Advanced settings", open=False):
220
+ num_tokens = gr.Number(value="5", label="num tokens to represent each object", interactive= True)
221
+ num_tokens_global = num_tokens
222
+ embedding_learning_rate = gr.Textbox(value="0.00025", label="Embedding optimization: Learning rate", interactive= True )
223
+ max_emb_train_steps = gr.Number(value="6", label="embedding optimization: Training steps", interactive= True )
224
 
225
+ diffusion_model_learning_rate = gr.Textbox(value="0.0002", label="UNet Optimization: Learning rate", interactive= True )
226
+ max_diffusion_train_steps = gr.Number(value="28", label="UNet Optimization: Learning rate: Training steps", interactive= True )
227
 
228
+ train_batch_size = gr.Number(value="20", label="Batch size", interactive= True )
229
+ gradient_accumulation_steps=gr.Number(value="2", label="Gradient accumulation", interactive= True )
230
+
231
def run_optimization_wrapper(
    mask_np_list,
    mask_label_list,
    image,
    opt_flag,
    num_tokens,
    embedding_learning_rate,
    max_emb_train_steps,
    diffusion_model_learning_rate,
    max_diffusion_train_steps,
    train_batch_size,
    gradient_accumulation_steps,
):
    """Run the optimization step through ``run_main`` and report the outcome.

    The numeric settings arrive from UI widgets (Number / Textbox), so each
    one is coerced to ``int`` or ``float`` before being forwarded.

    Args:
        mask_np_list: Forwarded unchanged to ``run_main``.
        mask_label_list: Forwarded unchanged to ``run_main``.
        image: Input image; converted with ``np.array`` and passed as
            ``image_gt``.
        opt_flag: Accepted to keep the handler signature stable; not read
            in the body.
        num_tokens: Coerced to ``int``.
        embedding_learning_rate: Coerced to ``float``.
        max_emb_train_steps: Coerced to ``int``.
        diffusion_model_learning_rate: Coerced to ``float``.
        max_diffusion_train_steps: Coerced to ``int``.
        train_batch_size: Coerced to ``int``.
        gradient_accumulation_steps: Coerced to ``int``.

    Returns:
        A status string: a success message, or a generic error message when
        anything in the conversion or in ``run_main`` raises.
    """
    try:
        run_main(
            mask_np_list=mask_np_list,
            mask_label_list=mask_label_list,
            image_gt=np.array(image),
            num_tokens=int(num_tokens),
            embedding_learning_rate=float(embedding_learning_rate),
            max_emb_train_steps=int(max_emb_train_steps),
            diffusion_model_learning_rate=float(diffusion_model_learning_rate),
            max_diffusion_train_steps=int(max_diffusion_train_steps),
            train_batch_size=int(train_batch_size),
            gradient_accumulation_steps=int(gradient_accumulation_steps),
        )
        gr.Info("Optimization Finished! Move to the next step.")
        return "Optimization finished! Move to the next step."
    except Exception as e:
        # Deliberate best-effort boundary: the handler must still return a
        # status string, so the failure is logged and surfaced, not re-raised.
        print(e)
        # gr.Error is an exception class and has no effect unless raised;
        # gr.Warning shows the actual message while still letting us return.
        gr.Warning(f"Optimization failed: {e}")
        return "Error: use a smaller batch size or try latter."
265
+
266
+
267
+
268
+ if 1:
269
+ with gr.Row():
270
+ with gr.Column():
271
+ canvas_text_edit = gr.Image(value = None, type = "pil", label="Editing results", show_label=True,visible = True)
272
+ # canvas_text_edit = gr.Gallery(label = "Edited results")
273
+
274
+ with gr.Column():
275
+ gr.Markdown("""<p style="text-align: center; font-size: 20px">Editing setting</p>""")
276
+ tgt_prompt = gr.Textbox(value="text prompt", label="Editing: Text prompt", interactive= True )
277
  with gr.Accordion(label="Advanced settings", open=False):
278
+ slider2 = gr.Slider(0, 20, step=1, label = 'mask id', visible=False)
279
+ guidance_scale = gr.Textbox(value="5", label="Editing: CFG guidance scale", interactive= True )
280
+ num_sampling_steps = gr.Number(value="20", label="Editing: Sampling steps", interactive= True )
281
+ edge_thickness = gr.Number(value="10", label="Editing: Edge thickness", interactive= True )
282
+ strength = gr.Textbox(value="0.5", label="Editing: Mask strength", interactive= True )
 
 
 
 
 
283
 
284
+ add_button = gr.Button("Step 2. Run Editing",interactive = True)
285
def run_edit_text_wrapper(
    mask_np_list,
    mask_label_list,
    image,
    num_tokens,
    guidance_scale,
    num_sampling_steps,
    strength,
    edge_thickness,
    tgt_prompt,
    tgt_index,
):
    """Run the text-guided editing step through ``run_main``.

    Calls ``run_main`` with ``load_trained=True``, ``text=True`` and
    ``num_imgs=1``, then returns whatever ``load_pil_img()`` yields.

    Args:
        mask_np_list: Forwarded unchanged to ``run_main``.
        mask_label_list: Forwarded unchanged to ``run_main``.
        image: Input image; converted with ``np.array`` and passed as
            ``image_gt``.
        num_tokens: Coerced to ``int``.
        guidance_scale: Coerced to ``float``.
        num_sampling_steps: Coerced to ``int``.
        strength: Coerced to ``float``.
        edge_thickness: Coerced to ``int``.
        tgt_prompt: Forwarded unchanged to ``run_main``.
        tgt_index: Coerced to ``int``.

    Returns:
        The result of ``load_pil_img()``.
    """
    run_main(
        mask_np_list=mask_np_list,
        mask_label_list=mask_label_list,
        image_gt=np.array(image),
        load_trained=True,
        text=True,
        # Use the value handed to this handler. Reading
        # ``num_tokens_global.value`` only ever returns the widget's
        # construction-time default, not what the user entered.
        # NOTE(review): assumes the click handler wires the ``num_tokens``
        # widget into this argument - confirm against the event binding.
        num_tokens=int(num_tokens),
        guidance_scale=float(guidance_scale),
        num_sampling_steps=int(num_sampling_steps),
        strength=float(strength),
        edge_thickness=int(edge_thickness),
        num_imgs=1,
        tgt_prompt=tgt_prompt,
        tgt_index=int(tgt_index),
    )
    gr.Info('Image editing completed.')
    return load_pil_img()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
317
 
318
 
319