|
import gradio as gr |
|
import PIL.Image |
|
from pathlib import Path |
|
import pandas as pd |
|
from diffusers.pipelines import StableDiffusionPipeline |
|
import torch |
|
import argparse |
|
import os |
|
import warnings |
|
from safetensors.torch import load_file |
|
import yaml |
|
|
|
# Silence every warning for the whole process to keep the demo's console output clean.
warnings.filterwarnings("ignore")

# NOTE(review): not referenced by the code visible here; predict() takes its own
# OUTPUT_DIR argument (also unused) — confirm whether this is still needed.
OUTPUT_DIR = "OUTPUT"

# GPU index used to build the module-level device string when CUDA is available.
cuda_device = 1
device = f"cuda:{cuda_device}" if torch.cuda.is_available() else "cpu"
# NOTE(review): this module-level `device` is only printed; predict() builds its
# own device string from its `device` argument.
print("DEVICE: ", device)

# Title shown at the top of the Gradio interface.
TITLE = "Demo for Generating Chest X-rays using Different Parameter-Efficient Fine-Tuned Stable Diffusion Pipelines"

# TODO: these help texts are placeholders that are displayed verbatim in the UI;
# replace them with real descriptions.
INFO_ABOUT_TEXT_PROMPT = "INFO_ABOUT_TEXT_PROMPT"
INFO_ABOUT_GUIDANCE_SCALE = "INFO_ABOUT_GUIDANCE_SCALE"
INFO_ABOUT_INFERENCE_STEPS = "INFO_ABOUT_INFERENCE_STEPS"

# Selectable text prompts offered in the "Input Text" dropdown; the first one is the default.
EXAMPLE_TEXT_PROMPTS = [
    "No acute cardiopulmonary abnormality.",
    "Normal chest radiograph.",
    "No acute intrathoracic process.",
    "Mild pulmonary edema.",
    "No focal consolidation concerning for pneumonia",
    "No radiographic evidence for acute cardiopulmonary process",
]
|
|
|
|
|
def load_adapted_unet(unet_pretraining_type, pipe):
    """
    Loads the adapted U-Net for the selected PEFT Type into ``pipe`` (in place).

    Parameters:

        unet_pretraining_type (str): The type of PEFT to use for generating the X-ray.
            - "freeze": leave ``pipe.unet`` unchanged.
            - "svdiff": replace ``pipe.unet`` with the U-Net returned by
              ``load_unet_for_svdiff`` (base model + "unet/spectral_shifts.safetensors"),
              then call ``perform_svd()`` on every sub-module that defines it.
            - "lorav2": load attention processors from "pytorch_lora_weights.safetensors".
            - anything else: load "<type>_diffusion_pytorch_model.safetensors" and apply it
              to ``pipe.unet`` with ``strict=False``, printing the result returned by
              ``load_state_dict`` (the missing/unexpected keys).
            All checkpoint paths are relative to the current working directory.

        pipe (StableDiffusionPipeline): The Stable Diffusion Pipeline whose U-Net is adapted.

    Returns:

        None
    """

    # Directory the adapter checkpoints are resolved against ('' = current working directory).
    exp_path = ""

    if unet_pretraining_type == "freeze":
        pass

    elif unet_pretraining_type == "svdiff":
        print("SV-DIFF UNET")
        sd_folder_path = "runwayml/stable-diffusion-v1-5"
        # NOTE(review): `load_unet_for_svdiff` is not imported anywhere in the visible
        # source, so this branch raises NameError unless it is provided at module scope.
        pipe.unet = load_unet_for_svdiff(
            sd_folder_path,
            spectral_shifts_ckpt=os.path.join(exp_path, "unet", "spectral_shifts.safetensors"),
            subfolder="unet",
        )
        for module in pipe.unet.modules():
            if hasattr(module, "perform_svd"):
                module.perform_svd()

    elif unet_pretraining_type == "lorav2":
        pipe.unet.load_attn_procs(os.path.join(exp_path, "pytorch_lora_weights.safetensors"))

    else:
        # Fine-tuned weights are stored per PEFT type, e.g. "full_diffusion_pytorch_model.safetensors".
        state_dict = load_file(unet_pretraining_type + "_" + "diffusion_pytorch_model.safetensors")
        # strict=False: PEFT checkpoints may contain only a subset of the U-Net's keys;
        # the printed result lists any missing/unexpected keys.
        print(pipe.unet.load_state_dict(state_dict, strict=False))
|
|
|
|
|
def loadSDModel(unet_pretraining_type, cuda_device):
    """
    Builds the base Stable Diffusion pipeline and adapts its U-Net for one PEFT type.

    Parameters:

        unet_pretraining_type (str): PEFT type forwarded to ``load_adapted_unet``.

        cuda_device (str): Accepted for interface compatibility; it is not used inside
            this function (the pipeline is returned without being moved to a device).

    Returns:

        pipe (StableDiffusionPipeline): The adapted pipeline, with ``safety_checker`` set to None.
    """
    base_model_id = "runwayml/stable-diffusion-v1-5"
    pipeline = StableDiffusionPipeline.from_pretrained(base_model_id, revision="fp16")

    # Swap in / patch the U-Net according to the requested fine-tuning strategy.
    load_adapted_unet(unet_pretraining_type, pipeline)

    # Drop the pipeline's safety checker so generated images are returned as-is.
    pipeline.safety_checker = None
    return pipeline
|
|
|
|
|
def _tunable_params_frame(unet_pretraining_type):
    """
    Builds the bar-plot data comparing full fine-tuning with the selected PEFT type.

    Parameters:

        unet_pretraining_type (str): The selected PEFT type.

    Returns:

        pd.DataFrame: Columns "Fine-Tuning Strategy" and "Number of Tunable Parameters",
        keeping only the "full" row and the row for ``unet_pretraining_type``
        (just the "full" row when the type is "full" or not listed).
    """
    # NOTE(review): the unit of these counts is not stated in the source
    # (presumably millions of parameters) — confirm before relabelling the axis.
    num_tunable_params = {
        "full": 86,
        "attention": 26.7,
        "bias": 0.343,
        "norm": 0.2,
        "norm_bias_attention": 26.7,
        "lorav2": 0.8,
        "svdiff": 0.222,
        "difffit": 0.581,
    }

    df = pd.DataFrame(
        {
            "Fine-Tuning Strategy": list(num_tunable_params),
            "Number of Tunable Parameters": list(num_tunable_params.values()),
        }
    )

    # Filter on the "Fine-Tuning Strategy" column built above; the frame has no
    # "PEFT Type" column, and looking that up raised KeyError on every call.
    selected = df["Fine-Tuning Strategy"].isin(["full", unet_pretraining_type])
    return df[selected].reset_index(drop=True)


def predict(
    unet_pretraining_type,
    input_text,
    guidance_scale=4,
    num_inference_steps=75,
    device="0",
    OUTPUT_DIR="OUTPUT",
):
    """
    Generates one chest X-ray for a prompt and a bar plot of tunable parameters.

    Parameters:

        unet_pretraining_type (str): PEFT type whose adapted U-Net is loaded.

        input_text (str): Text prompt passed to the pipeline.

        guidance_scale (float): Forwarded to the pipeline as ``guidance_scale``.

        num_inference_steps (int): Forwarded to the pipeline as ``num_inference_steps``.

        device (str): CUDA device index; the pipeline runs on ``f"cuda:{device}"``
            when CUDA is available, otherwise on "cpu".

        OUTPUT_DIR (str): Unused; kept for interface compatibility.

    Returns:

        tuple: (first generated image from the pipeline output's "images",
                gr.BarPlot comparing "full" with the selected PEFT type)
    """
    cuda_device = f"cuda:{device}" if torch.cuda.is_available() else "cpu"

    # NOTE(review): the whole pipeline is rebuilt from scratch on every call, which
    # likely dominates latency — consider caching one pipeline per PEFT type.
    print(f"Loading Pipeline for {unet_pretraining_type} Fine-Tuning")
    sd_pipeline = loadSDModel(
        unet_pretraining_type=unet_pretraining_type,
        cuda_device=cuda_device,
    )
    sd_pipeline.to(cuda_device)

    result_image = sd_pipeline(
        prompt=input_text,
        height=224,
        width=224,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
    )
    result_pil_image = result_image["images"][0]

    bar_plot = gr.BarPlot(
        value=_tunable_params_frame(unet_pretraining_type),
        x="Fine-Tuning Strategy",
        y="Number of Tunable Parameters",
        title=f"Tunable Parameters for {unet_pretraining_type} Fine-Tuning",
        vertical=True,
        height=300,
        width=300,
        interactive=True,
    )

    return result_pil_image, bar_plot
|
|
|
|
|
|
|
""" |
|
Input Parameters:

1. PEFT Type: (Dropdown) The parameter-efficient fine-tuning (PEFT) type whose adapted U-Net is used to generate the X-ray

2. Input Text: (Dropdown) The text prompt used to generate the X-ray

3. Guidance Scale: (Slider) The guidance scale used when generating the X-ray

4. Num Inference Steps: (Slider) The number of inference steps used when generating the X-ray



Output Parameters:

1. Generated X-ray Image: (Image) The generated X-ray image

2. Number of Tunable Parameters: (Bar Plot) The number of tunable parameters for the selected PEFT type, compared with full fine-tuning
|
""" |
|
# Interface inputs, created in the same order as predict()'s positional parameters.
peft_type_input = gr.Dropdown(
    ["full", "difffit", "norm", "bias", "attention", "norm_bias_attention"],
    value="full",
    label="PEFT Type",
)
prompt_input = gr.Dropdown(
    EXAMPLE_TEXT_PROMPTS,
    info=INFO_ABOUT_TEXT_PROMPT,
    label="Input Text",
    value=EXAMPLE_TEXT_PROMPTS[0],
)
guidance_scale_input = gr.Slider(
    minimum=1,
    maximum=10,
    value=4,
    step=1,
    info=INFO_ABOUT_GUIDANCE_SCALE,
    label="Guidance Scale",
)
inference_steps_input = gr.Slider(
    minimum=1,
    maximum=100,
    value=75,
    step=1,
    info=INFO_ABOUT_INFERENCE_STEPS,
    label="Num Inference Steps",
)

# live=True re-runs predict() whenever any input changes.
iface = gr.Interface(
    fn=predict,
    inputs=[peft_type_input, prompt_input, guidance_scale_input, inference_steps_input],
    outputs=[gr.Image(type="pil"), gr.BarPlot()],
    live=True,
    analytics_enabled=False,
    title=TITLE,
)

# share=True asks Gradio for a public share link in addition to the local server.
iface.launch(share=True)
|
|