"""Gradio demo that generates material maps from a text or image prompt.

The UI collects a prompt (text or image), a seed, a resolution and a
refinement flag, calls ``generation.generate_material`` and displays the
resulting base color, normal, height, metallic and roughness maps.
"""

# NOTE: `spaces` is kept imported before torch/diffusers; on Hugging Face
# ZeroGPU Spaces it is expected to be imported before CUDA is initialized.
import gradio as gr
import spaces
import torch  # noqa: F401  -- not referenced here; kept as in the original (presumably preloading)
import diffusers  # noqa: F401  -- not referenced here; kept as in the original
from generation import generate_material

# Seed value the UI advertises as "random" (see the Seed field label).
RANDOM_SEED = -1


@spaces.GPU
def generate(prompts, seed, resolution, refinement):
    """Generate every material map for one prompt.

    Args:
        prompts: Text string or PIL image, passed unchanged to
            ``generate_material``.
        seed: Seed from the UI. Gradio's Number component yields floats (or
            ``None`` when the field is cleared), so it is cast to ``int``;
            ``None`` falls back to ``RANDOM_SEED``.
        resolution: Resolution string from the dropdown (e.g. ``"512"``).
        refinement: Refinement flag forwarded to ``generate_material``.

    Returns:
        Tuple ``(basecolor, normal, height, metallic, roughness)`` in the
        order of the output image components.
    """
    seed = RANDOM_SEED if seed is None else int(seed)
    material = generate_material(
        prompts, seed=seed, resolution=int(resolution), refinement=refinement
    )
    return (
        material.basecolor,
        material.normal,
        material.height,
        material.metallic,
        material.roughness,
    )


def interface_function(prompt_type, text_prompt, image_prompt, seed, resolution, refinement):
    """Forward the click to ``generate`` with the prompt selected by ``prompt_type``.

    Raises:
        gr.Error: If the selected prompt is missing/empty or ``prompt_type``
            is not ``"Text"``/``"Image"``. Validation happens here so that no
            GPU call is made for input that cannot produce a result.
    """
    if prompt_type == "Text":
        if not text_prompt or not text_prompt.strip():
            raise gr.Error("Please enter a text prompt.")
        prompt = text_prompt
    elif prompt_type == "Image":
        if image_prompt is None:
            raise gr.Error("Please upload an image prompt.")
        prompt = image_prompt
    else:
        # Returning None here could not be mapped onto the five output images.
        raise gr.Error(f"Unknown prompt type: {prompt_type!r}")
    return generate(prompt, seed, resolution, refinement)


def update_visibility(prompt_type):
    """Show the textbox for ``"Text"`` prompts and the image input otherwise.

    Returns:
        ``(text_prompt_update, image_prompt_update)`` visibility updates,
        always two values to match the two output components.
    """
    show_text = prompt_type == "Text"
    return gr.update(visible=show_text), gr.update(visible=not show_text)


with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            prompt_type = gr.Radio(choices=["Text", "Image"], label="Prompt Type", value="Text")
            text_prompt = gr.Textbox(label="Text Prompt", visible=True, lines=3, placeholder="A brick wall")
            image_prompt = gr.Image(type="pil", label="Image Prompt", visible=False)
        with gr.Column():
            # precision=0 makes Gradio round the value and deliver an int seed.
            seed = gr.Number(value=-1, label="Seed (-1 for random)", precision=0)
            resolution = gr.Dropdown(["512", "1024", "2048"], value="512", label="Resolution", interactive=False)
            refinement = gr.Checkbox(label="Refinement", interactive=False)
            generate_button = gr.Button("Generate")

    # Toggle which prompt input is shown when the prompt type changes.
    prompt_type.change(fn=update_visibility, inputs=prompt_type, outputs=[text_prompt, image_prompt])

    with gr.Row():
        output_basecolor = gr.Image(label="Base Color", format="png", image_mode="RGB")
        output_normal = gr.Image(label="Normal Map", format="png", image_mode="RGB")
        output_height = gr.Image(label="Height Map", format="png", image_mode="L")
        output_metallic = gr.Image(label="Metallic Map", format="png", image_mode="L")
        output_roughness = gr.Image(label="Roughness Map", format="png", image_mode="L")

    generate_button.click(
        fn=interface_function,
        inputs=[prompt_type, text_prompt, image_prompt, seed, resolution, refinement],
        outputs=[output_basecolor, output_normal, output_height, output_metallic, output_roughness],
    )

if __name__ == "__main__":
    demo.launch()