File size: 2,455 Bytes
9a13713
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import gradio as gr
import spaces
import torch
import diffusers

from generation import generate_material


@spaces.GPU
def generate(prompts, seed, resolution, refinement):
    """Run material generation on the GPU and unpack the resulting texture maps.

    Args:
        prompts: The prompt forwarded to ``generate_material`` — either the
            text string or the PIL image selected in the UI.
        seed: Random seed. ``gr.Number`` yields a float by default, so it is
            coerced to ``int`` here; ``None`` (a cleared number field) becomes
            ``-1``, the value the UI labels as "random".
        resolution: Output resolution; the dropdown supplies it as a string,
            so it is converted with ``int``.
        refinement: Refinement flag, forwarded unchanged.

    Returns:
        A 5-tuple ``(basecolor, normal, height, metallic, roughness)`` read
        from the object returned by ``generate_material``.
    """
    # Mirror the int() coercion applied to resolution so the seed reaches
    # generate_material as a whole number (e.g. 42, not 42.0).
    seed = -1 if seed is None else int(seed)
    image = generate_material(prompts, seed=seed, resolution=int(resolution), refinement=refinement)
    return image.basecolor, image.normal, image.height, image.metallic, image.roughness


def interface_function(prompt_type, text_prompt, image_prompt, seed, resolution, refinement):
    """Dispatch the Generate button to ``generate`` using the active prompt.

    Args:
        prompt_type: The radio selection, ``"Text"`` or ``"Image"``.
        text_prompt: Text box contents; used when ``prompt_type == "Text"``.
        image_prompt: Uploaded image (``None`` when nothing is uploaded);
            used when ``prompt_type == "Image"``.
        seed, resolution, refinement: Forwarded unchanged to ``generate``.

    Returns:
        The five texture maps returned by ``generate``.

    Raises:
        gr.Error: If the selected prompt is empty or ``prompt_type`` is not
            recognised, so the user sees a message in the UI instead of an
            opaque failure deep inside generation.
    """
    if prompt_type == "Text":
        # Reject blank / whitespace-only text before spending GPU time on it.
        if not text_prompt or not text_prompt.strip():
            raise gr.Error("Please enter a text prompt.")
        return generate(text_prompt, seed, resolution, refinement)
    if prompt_type == "Image":
        if image_prompt is None:
            raise gr.Error("Please upload an image prompt.")
        return generate(image_prompt, seed, resolution, refinement)
    # Returning None here would not match the five output components; fail
    # with a clear message instead.
    raise gr.Error(f"Unknown prompt type: {prompt_type!r}")

def update_visibility(prompt_type):
    """Toggle which prompt input is shown for the selected prompt type.

    Returns a pair of ``gr.update`` objects for ``(text_prompt, image_prompt)``:
    only the text box is visible for ``"Text"``, only the image input for
    ``"Image"``. Any other value yields ``None``.
    """
    if prompt_type not in ("Text", "Image"):
        return None
    show_text = prompt_type == "Text"
    return gr.update(visible=show_text), gr.update(visible=not show_text)

# UI layout: a two-column input panel (prompt on the left, generation settings
# on the right) above a row showing the five output texture maps.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            prompt_type = gr.Radio(choices=["Text", "Image"], label="Prompt Type", value="Text")
            text_prompt = gr.Textbox(label="Text Prompt", visible=True, lines=3, placeholder="A brick wall")
            # Starts hidden; update_visibility swaps it with the text box
            # when the prompt type changes.
            image_prompt = gr.Image(type="pil", label="Image Prompt", visible=False)

        with gr.Column():
            # precision=0 makes Gradio round the value and deliver it as an
            # int; with the default precision=None the seed arrives as a float.
            seed = gr.Number(value=-1, label="Seed (-1 for random)", precision=0)
            # Resolution and refinement are displayed but locked in the UI
            # (interactive=False).
            resolution = gr.Dropdown(["512", "1024", "2048"], value="512", label="Resolution", interactive=False)
            refinement = gr.Checkbox(label="Refinement", interactive=False)
            generate_button = gr.Button("Generate")

    prompt_type.change(fn=update_visibility, inputs=prompt_type, outputs=[text_prompt, image_prompt])

    # Colour maps are RGB; scalar maps (height/metallic/roughness) are
    # single-channel greyscale ("L").
    with gr.Row():
        output_basecolor = gr.Image(label="Base Color", format="png", image_mode="RGB")
        output_normal = gr.Image(label="Normal Map", format="png", image_mode="RGB")
        output_height = gr.Image(label="Height Map", format="png", image_mode="L")
        output_metallic = gr.Image(label="Metallic Map", format="png", image_mode="L")
        output_roughness = gr.Image(label="Roughness Map", format="png", image_mode="L")

    # Output order must match the tuple order returned by generate().
    generate_button.click(
        fn=interface_function,
        inputs=[prompt_type, text_prompt, image_prompt, seed, resolution, refinement],
        outputs=[output_basecolor, output_normal, output_height, output_metallic, output_roughness]
    )

if __name__ == "__main__":
    demo.launch()