import os
import warnings

import gradio as gr
import pandas as pd
import torch
from diffusers.pipelines import StableDiffusionPipeline
from safetensors.torch import load_file

# NOTE: assumed to come from the svdiff-pytorch package; required for the "svdiff" PEFT type.
from svdiff_pytorch import load_unet_for_svdiff

warnings.filterwarnings("ignore")

OUTPUT_DIR = "OUTPUT"
cuda_device = 1
device = f"cuda:{cuda_device}" if torch.cuda.is_available() else "cpu"
print("DEVICE: ", device)

TITLE = "Demo for Generating Chest X-rays using Different Parameter-Efficient Fine-Tuned Stable Diffusion Pipelines"
INFO_ABOUT_TEXT_PROMPT = "INFO_ABOUT_TEXT_PROMPT"
INFO_ABOUT_GUIDANCE_SCALE = "INFO_ABOUT_GUIDANCE_SCALE"
INFO_ABOUT_INFERENCE_STEPS = "INFO_ABOUT_INFERENCE_STEPS"

EXAMPLE_TEXT_PROMPTS = [
    "No acute cardiopulmonary abnormality.",
    "Normal chest radiograph.",
    "No acute intrathoracic process.",
    "Mild pulmonary edema.",
    "No focal consolidation concerning for pneumonia",
    "No radiographic evidence for acute cardiopulmonary process",
]


def load_adapted_unet(unet_pretraining_type, pipe):
    """
    Loads the adapted U-Net for the selected PEFT type.

    Parameters:
        unet_pretraining_type (str): The type of PEFT to use for generating the X-ray
        pipe (StableDiffusionPipeline): The Stable Diffusion pipeline to use for generating the X-ray

    Returns:
        None
    """
    sd_folder_path = "runwayml/stable-diffusion-v1-5"
    exp_path = ""

    if unet_pretraining_type == "freeze":
        pass
    elif unet_pretraining_type == "svdiff":
        print("SV-DIFF UNET")
        pipe.unet = load_unet_for_svdiff(
            sd_folder_path,
            spectral_shifts_ckpt=os.path.join(
                os.path.join(exp_path, "unet"), "spectral_shifts.safetensors"
            ),
            subfolder="unet",
        )
        for module in pipe.unet.modules():
            if hasattr(module, "perform_svd"):
                module.perform_svd()
    elif unet_pretraining_type == "lorav2":
        exp_path = os.path.join(exp_path, "pytorch_lora_weights.safetensors")
        pipe.unet.load_attn_procs(exp_path)
    else:
        # Remaining PEFT types (full, attention, bias, norm, norm_bias_attention, difffit):
        # load the fine-tuned U-Net weights from "<peft_type>_diffusion_pytorch_model.safetensors".
        state_dict = load_file(
            unet_pretraining_type + "_" + "diffusion_pytorch_model.safetensors"
        )
        print(pipe.unet.load_state_dict(state_dict, strict=False))


def loadSDModel(unet_pretraining_type, cuda_device):
    """
    Loads the Stable Diffusion model for the selected PEFT type.

    Parameters:
        unet_pretraining_type (str): The type of PEFT to use for generating the X-ray
        cuda_device (str): The CUDA device to use for generating the X-ray

    Returns:
        pipe (StableDiffusionPipeline): The Stable Diffusion pipeline to use for generating the X-ray
    """
    sd_folder_path = "runwayml/stable-diffusion-v1-5"

    pipe = StableDiffusionPipeline.from_pretrained(sd_folder_path, revision="fp16")
    load_adapted_unet(unet_pretraining_type, pipe)

    # Disable the safety checker for this medical-imaging demo.
    pipe.safety_checker = None

    return pipe


def predict(
    unet_pretraining_type,
    input_text,
    guidance_scale=4,
    num_inference_steps=75,
    device="0",
    OUTPUT_DIR="OUTPUT",
):
    BARPLOT_TITLE = "Tunable Parameters for {} Fine-Tuning".format(unet_pretraining_type)

    # Tunable parameter counts shown in the bar plot for each fine-tuning strategy
    NUM_TUNABLE_PARAMS = {
        "full": 86,
        "attention": 26.7,
        "bias": 0.343,
        "norm": 0.2,
        "norm_bias_attention": 26.7,
        "lorav2": 0.8,
        "svdiff": 0.222,
        "difffit": 0.581,
    }

    cuda_device = f"cuda:{device}" if torch.cuda.is_available() else "cpu"

    print("Loading Pipeline for {} Fine-Tuning".format(unet_pretraining_type))
    sd_pipeline = loadSDModel(
        unet_pretraining_type=unet_pretraining_type,
        cuda_device=cuda_device,
    )
    sd_pipeline.to(cuda_device)

    result_image = sd_pipeline(
        prompt=input_text,
        height=224,
        width=224,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
    )
    result_pil_image = result_image["images"][0]
    # Create a bar plot displaying the number of tunable parameters for the selected PEFT type
    df = pd.DataFrame(
        {
            "Fine-Tuning Strategy": list(NUM_TUNABLE_PARAMS.keys()),
            "Number of Tunable Parameters": list(NUM_TUNABLE_PARAMS.values()),
        }
    )
    # Keep only the "full" baseline and the selected strategy for comparison
    df = df[df["Fine-Tuning Strategy"].isin(["full", unet_pretraining_type])].reset_index(
        drop=True
    )

    bar_plot = gr.BarPlot(
        value=df,
        x="Fine-Tuning Strategy",
        y="Number of Tunable Parameters",
        title=BARPLOT_TITLE,
        vertical=True,
        height=300,
        width=300,
        interactive=True,
    )

    return result_pil_image, bar_plot


# Create a Gradio interface
"""
Input Parameters:
    1. PEFT Type: (Dropdown) The type of PEFT to use for generating the X-ray
    2. Input Text: (Textbox) The text prompt to use for generating the X-ray
    3. Guidance Scale: (Slider) The guidance scale to use for generating the X-ray
    4. Num Inference Steps: (Slider) The number of inference steps to use for generating the X-ray

Output Parameters:
    1. Generated X-ray Image: (Image) The generated X-ray image
    2. Number of Tunable Parameters: (Bar Plot) The number of tunable parameters for the selected PEFT type
"""

iface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Dropdown(
            ["full", "difffit", "norm", "bias", "attention", "norm_bias_attention"],
            value="full",
            label="PEFT Type",
        ),
        gr.Dropdown(
            EXAMPLE_TEXT_PROMPTS,
            info=INFO_ABOUT_TEXT_PROMPT,
            label="Input Text",
            value=EXAMPLE_TEXT_PROMPTS[0],
        ),
        gr.Slider(
            minimum=1,
            maximum=10,
            value=4,
            step=1,
            info=INFO_ABOUT_GUIDANCE_SCALE,
            label="Guidance Scale",
        ),
        gr.Slider(
            minimum=1,
            maximum=100,
            value=75,
            step=1,
            info=INFO_ABOUT_INFERENCE_STEPS,
            label="Num Inference Steps",
        ),
    ],
    outputs=[gr.Image(type="pil"), gr.BarPlot()],
    live=True,
    analytics_enabled=False,
    title=TITLE,
)

# Launch the Gradio interface
iface.launch(share=True)
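
# ---------------------------------------------------------------------------
# Illustrative usage (a minimal sketch, kept commented out so it does not run
# alongside the web app): predict() can also be called directly, e.g. to
# generate an image without the UI. This assumes the fine-tuned weight file
# for the chosen PEFT type (e.g. "difffit_diffusion_pytorch_model.safetensors")
# is available in the working directory.
#
# image, plot = predict(
#     unet_pretraining_type="difffit",
#     input_text="No acute cardiopulmonary abnormality.",
#     guidance_scale=4,
#     num_inference_steps=75,
#     device="0",
# )
# image.save("difffit_sample.png")
# ---------------------------------------------------------------------------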