import gradio as gr
import PIL.Image
from pathlib import Path
import pandas as pd
from diffusers.pipelines import StableDiffusionPipeline
import torch
import argparse
import functools
import os
import warnings
from safetensors.torch import load_file
import yaml

warnings.filterwarnings("ignore")

OUTPUT_DIR = "OUTPUT"
# NOTE(review): `cuda_device` / `device` are not read by `predict`, which builds
# its own device string from its `device` argument (default "0"). Confirm which
# GPU is intended before relying on these values.
cuda_device = 1
device = f"cuda:{cuda_device}" if torch.cuda.is_available() else "cpu"

TITLE = (
    "Demo for Generating Chest X-rays using Different "
    "Parameter-Efficient Fine-Tuned Stable Diffusion Pipelines"
)
INFO_ABOUT_TEXT_PROMPT = "INFO_ABOUT_TEXT_PROMPT"
INFO_ABOUT_GUIDANCE_SCALE = "INFO_ABOUT_GUIDANCE_SCALE"
INFO_ABOUT_INFERENCE_STEPS = "INFO_ABOUT_INFERENCE_STEPS"

EXAMPLE_TEXT_PROMPTS = [
    "No acute cardiopulmonary abnormality.",
    "Normal chest radiograph.",
    "No acute intrathoracic process.",
    "Mild pulmonary edema.",
    "No focal consolidation concerning for pneumonia",
    "No radiographic evidence for acute cardiopulmonary process",
]

# Base Stable Diffusion checkpoint that every PEFT variant is applied on top of.
SD_FOLDER_PATH = "runwayml/stable-diffusion-v1-5"

# Tunable-parameter figures shown in the bar plot for each PEFT type
# (values as given by the authors; units are not stated in this file).
NUM_TUNABLE_PARAMS = {
    "full": 86,
    "attention": 26.7,
    "bias": 0.343,
    "norm": 0.2,
    "norm_bias_attention": 26.7,
    "lorav2": 0.8,
    "svdiff": 0.222,
    "difffit": 0.581,
}


def load_adapted_unet(unet_pretraining_type, pipe):
    """
    Loads the adapted U-Net for the selected PEFT type into ``pipe`` (in place).

    Parameters:
        unet_pretraining_type (str): The PEFT type. "freeze" keeps the base
            U-Net; "svdiff" and "lorav2" use their dedicated loaders; any other
            value loads "<type>_diffusion_pytorch_model.safetensors" from the
            current working directory.
        pipe (StableDiffusionPipeline): The pipeline whose U-Net is adapted.

    Returns:
        None
    """
    exp_path = ""

    if unet_pretraining_type == "freeze":
        pass
    elif unet_pretraining_type == "svdiff":
        print("SV-DIFF UNET")
        # NOTE(review): `load_unet_for_svdiff` is not imported anywhere in this
        # file, so this branch raises NameError as written. "svdiff" is not
        # offered in the UI dropdown below.
        pipe.unet = load_unet_for_svdiff(
            SD_FOLDER_PATH,
            spectral_shifts_ckpt=os.path.join(
                exp_path, "unet", "spectral_shifts.safetensors"
            ),
            subfolder="unet",
        )
        for module in pipe.unet.modules():
            if hasattr(module, "perform_svd"):
                module.perform_svd()
    elif unet_pretraining_type == "lorav2":
        exp_path = os.path.join(exp_path, "pytorch_lora_weights.safetensors")
        pipe.unet.load_attn_procs(exp_path)
    else:
        state_dict = load_file(
            unet_pretraining_type + "_" + "diffusion_pytorch_model.safetensors"
        )
        # strict=False tolerates missing/unexpected keys; the returned
        # key-mismatch report is printed for inspection.
        print(pipe.unet.load_state_dict(state_dict, strict=False))


def loadSDModel(unet_pretraining_type, cuda_device):
    """
    Loads the Stable Diffusion pipeline with the U-Net adapted for a PEFT type.

    Parameters:
        unet_pretraining_type (str): The type of PEFT to use for generating the X-ray
        cuda_device (str): Accepted for API compatibility but not used here;
            the caller is responsible for moving the pipeline to a device.

    Returns:
        pipe (StableDiffusionPipeline): The pipeline with its safety checker disabled
    """
    pipe = StableDiffusionPipeline.from_pretrained(SD_FOLDER_PATH, revision="fp16")
    load_adapted_unet(unet_pretraining_type, pipe)
    pipe.safety_checker = None
    return pipe


@functools.lru_cache(maxsize=1)
def _get_pipeline(unet_pretraining_type, cuda_device):
    """Return a pipeline for the PEFT type, already moved to ``cuda_device``.

    Cached so repeated requests for the same PEFT type and device reuse the
    loaded model instead of reloading it from disk on every call (the interface
    is live, so every input change triggers a request). maxsize=1 keeps only
    one pipeline resident at a time; switching PEFT type evicts the previous one.
    """
    print("Loading Pipeline for {} Fine-Tuning".format(unet_pretraining_type))
    sd_pipeline = loadSDModel(
        unet_pretraining_type=unet_pretraining_type,
        cuda_device=cuda_device,
    )
    sd_pipeline.to(cuda_device)
    return sd_pipeline


def predict(
    unet_pretraining_type,
    input_text,
    guidance_scale=4,
    num_inference_steps=75,
    device="0",
    OUTPUT_DIR="OUTPUT",
):
    """
    Generates a 224x224 chest X-ray and a tunable-parameter comparison plot.

    Parameters:
        unet_pretraining_type (str): PEFT type selecting the adapted U-Net.
        input_text (str): Text prompt for generation.
        guidance_scale (float): Guidance scale passed to the pipeline.
        num_inference_steps (int): Number of inference steps passed to the pipeline.
        device (str): CUDA device index used when CUDA is available.
        OUTPUT_DIR (str): Currently unused.

    Returns:
        tuple: (first generated image, gr.BarPlot comparing "full" with the
        selected PEFT type)
    """
    cuda_device = f"cuda:{device}" if torch.cuda.is_available() else "cpu"
    sd_pipeline = _get_pipeline(unet_pretraining_type, cuda_device)

    result_image = sd_pipeline(
        prompt=input_text,
        height=224,
        width=224,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
    )
    result_pil_image = result_image["images"][0]

    # Bar plot: "full" fine-tuning shown alongside the selected PEFT type.
    df = pd.DataFrame(
        {
            "PEFT Type": list(NUM_TUNABLE_PARAMS.keys()),
            "Number of Tunable Parameters": list(NUM_TUNABLE_PARAMS.values()),
        }
    )
    df = df[df["PEFT Type"].isin(["full", unet_pretraining_type])].reset_index(
        drop=True
    )

    bar_plot = gr.BarPlot(
        value=df,
        x="PEFT Type",
        y="Number of Tunable Parameters",
        label="PEFT Type",
        title="Number of Tunable Parameters",
        vertical=True,
    )
    return result_pil_image, bar_plot


# Gradio interface.
# Inputs:
#   1. PEFT Type (Dropdown): the PEFT variant used to generate the X-ray
#   2. Input Text (Dropdown): the text prompt, chosen from EXAMPLE_TEXT_PROMPTS
#   3. Guidance Scale (Slider): guidance scale for generation
#   4. Num Inference Steps (Slider): number of inference steps
# Outputs:
#   1. Generated X-ray image (Image)
#   2. Number of tunable parameters for the selected PEFT type (BarPlot)
iface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Dropdown(
            ["full", "difffit", "norm", "bias", "attention", "norm_bias_attention"],
            value="full",
            label="PEFT Type",
        ),
        gr.Dropdown(
            EXAMPLE_TEXT_PROMPTS,
            info=INFO_ABOUT_TEXT_PROMPT,
            label="Input Text",
            value=EXAMPLE_TEXT_PROMPTS[0],
        ),
        gr.Slider(
            minimum=1,
            maximum=10,
            value=4,
            step=1,
            info=INFO_ABOUT_GUIDANCE_SCALE,
            label="Guidance Scale",
        ),
        gr.Slider(
            minimum=1,
            maximum=100,
            value=75,
            step=1,
            info=INFO_ABOUT_INFERENCE_STEPS,
            label="Num Inference Steps",
        ),
    ],
    outputs=[gr.Image(type="pil"), gr.BarPlot()],
    # live=True re-runs `predict` on input changes; `_get_pipeline` caching
    # keeps this from reloading the model on each run.
    live=True,
    analytics_enabled=False,
    title=TITLE,
)

# Launch the Gradio interface
iface.launch(share=True)