Srikumar26 commited on
Commit
deb128d
1 Parent(s): 98737ee

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -4
app.py CHANGED
@@ -1,6 +1,44 @@
1
- import gradio as gr
2
- import os
3
 
4
- hf_token = os.environ.get("hf_token")
 
 
5
 
6
- gr.load("models/MVRL/SD-GeoSynth", hf_token=hf_token).launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from diffusers import StableDiffusionPipeline, DDIMScheduler
3
 
4
+ pipe = StableDiffusionPipeline.from_pretrained("MVRL/GeoSynth")
5
+ scheduler = DDIMScheduler.from_pretrained("stabilityai/stable-diffusion-2-1-base")
6
+ pipe.scheduler = scheduler
7
 
8
+ def process(prompt, n_prompt, num_samples, image_resolution, ddim_steps, scale, seed, eta):
9
+ generator = torch.manual_seed(seed)
10
+ output_images = pipe(prompt,
11
+ height=image_resolution,
12
+ width=image_resolution,
13
+ num_inference_steps=ddim_steps,
14
+ guidance_scale=scale,
15
+ negative_prompt=n_prompt,
16
+ num_images_per_prompt=num_samples,
17
+ eta=eta,
18
+ generator=generator,
19
+ ).images
20
+ return output_images
21
+
22
+ block = gr.Blocks().queue()
23
+ with block:
24
+ with gr.Row():
25
+ gr.Markdown("## Control Stable Diffusion with Depth Maps")
26
+ with gr.Row():
27
+ with gr.Column():
28
+ prompt = gr.Textbox(label="Prompt")
29
+ run_button = gr.Button(label="Run")
30
+ with gr.Accordion("Advanced options", open=True):
31
+ num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
32
+ image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=64)
33
+ ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
34
+ scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=7.5, step=0.1)
35
+ seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
36
+ eta = gr.Number(label="eta (DDIM)", value=0.0)
37
+ n_prompt = gr.Textbox(label="Negative Prompt",
38
+ value='')
39
+ with gr.Column():
40
+ result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(height='auto')
41
+ ips = [prompt, n_prompt, num_samples, image_resolution, ddim_steps, scale, seed, eta]
42
+ run_button.click(fn=process, inputs=ips, outputs=[result_gallery])
43
+
44
+ block.launch()