Raman Dutt committed on
Commit
ee6eca1
·
1 Parent(s): 740ee27

app.py added

Browse files
Files changed (1) hide show
  1. app.py +329 -0
app.py ADDED
@@ -0,0 +1,329 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import PIL.Image
3
+ from pathlib import Path
4
+ import pandas as pd
5
+ from diffusers.pipelines import StableDiffusionPipeline
6
+ import torch
7
+ import argparse
8
+ import os
9
+ import warnings
10
+ from safetensors.torch import load_file
11
+ import yaml
12
+
# Suppress every Python warning for the remainder of the process.
warnings.filterwarnings("ignore")

# Default name of the output directory.
OUTPUT_DIR = "OUTPUT"
# Index of the CUDA device used to build the module-level `device` string.
cuda_device = 1
# Torch device string: the CUDA device above when CUDA is available, else CPU.
device = f"cuda:{cuda_device}" if torch.cuda.is_available() else "cpu"

# Title shown at the top of the Gradio page.
TITLE = "Demo for Generating Chest X-rays using Different Parameter-Efficient Fine-Tuned Stable Diffusion Pipelines"
# Help texts attached to the corresponding Gradio input widgets.
# NOTE(review): these still hold placeholder values equal to their own names.
INFO_ABOUT_TEXT_PROMPT = "INFO_ABOUT_TEXT_PROMPT"
INFO_ABOUT_GUIDANCE_SCALE = "INFO_ABOUT_GUIDANCE_SCALE"
INFO_ABOUT_INFERENCE_STEPS = "INFO_ABOUT_INFERENCE_STEPS"
# Text prompts offered in the "Input Text" dropdown.
EXAMPLE_TEXT_PROMPTS = [
    "No acute cardiopulmonary abnormality.",
    "Normal chest radiograph.",
    "No acute intrathoracic process.",
    "Mild pulmonary edema.",
    "No focal consolidation concerning for pneumonia",
    "No radiographic evidence for acute cardiopulmonary process",
]
31
+
32
+
def load_adapted_unet(unet_pretraining_type, exp_path, pipe):
    """
    Loads the adapted U-Net for the selected PEFT Type into ``pipe`` (in place)

    Parameters:
        unet_pretraining_type (str): The type of PEFT to use for generating the X-ray.
            "freeze" leaves ``pipe`` untouched, "svdiff" replaces ``pipe.unet``,
            "lorav2" loads attention processors into ``pipe.unet``, and any other
            value loads a safetensors state dict into ``pipe.unet`` non-strictly.
        exp_path (str): The path to the best trained model for the selected PEFT Type.
            For "svdiff" and "lorav2" this is a directory. For the remaining types
            it is used as the checkpoint file when it ends in ".safetensors";
            otherwise the file "<unet_pretraining_type>_diffusion_pytorch_model.safetensors"
            in the current working directory is used.
        pipe (StableDiffusionPipeline): The Stable Diffusion Pipeline to use for generating the X-ray

    Returns:
        None
    """

    if unet_pretraining_type == "freeze":
        # Keep the base U-Net exactly as it was loaded.
        return

    if unet_pretraining_type == "svdiff":
        print("SV-DIFF UNET")

        # NOTE(review): load_unet_for_svdiff is not imported in the visible part
        # of this file; this branch raises NameError until it is brought into scope.
        pipe.unet = load_unet_for_svdiff(
            "runwayml/stable-diffusion-v1-5",
            spectral_shifts_ckpt=os.path.join(
                exp_path, "unet", "spectral_shifts.safetensors"
            ),
            subfolder="unet",
        )
        for module in pipe.unet.modules():
            if hasattr(module, "perform_svd"):
                module.perform_svd()
        return

    if unet_pretraining_type == "lorav2":
        lora_weights = os.path.join(exp_path, "pytorch_lora_weights.safetensors")
        pipe.unet.load_attn_procs(lora_weights)
        return

    # Every other PEFT type ships as a (possibly partial) U-Net state dict.
    # Honour an explicit checkpoint path instead of silently discarding it; fall
    # back to the conventional file name otherwise.
    if exp_path and str(exp_path).endswith(".safetensors"):
        weights_path = exp_path
    else:
        weights_path = unet_pretraining_type + "_" + "diffusion_pytorch_model.safetensors"

    state_dict = load_file(weights_path)
    # strict=False tolerates checkpoints that hold only a subset of the U-Net
    # keys; the printed result lists the missing/unexpected keys.
    print(pipe.unet.load_state_dict(state_dict, strict=False))
73
+
74
+
def loadSDModel(unet_pretraining_type, exp_path, cuda_device):
    """
    Builds a Stable Diffusion pipeline whose U-Net is adapted for one PEFT Type

    Parameters:
        unet_pretraining_type (str): The PEFT type, forwarded to load_adapted_unet
        exp_path (str): The path to the trained weights, forwarded to load_adapted_unet
        cuda_device (str): The CUDA device for generation. Accepted but not used
            here; the returned pipeline is not moved to any device.

    Returns:
        pipe (StableDiffusionPipeline): The pipeline with the adapted U-Net and
            with its safety checker removed
    """

    base_pipeline = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        revision="fp16",
    )

    # Swap in / patch the U-Net weights for the requested PEFT type.
    load_adapted_unet(unet_pretraining_type, exp_path, base_pipeline)

    # Disable the safety checker on the returned pipeline.
    base_pipeline.safety_checker = None
    return base_pipeline
97
+
98
+
def load_all_pipelines():
    """
    Loads one Stable Diffusion Pipeline per PEFT Type so they can be cached (Design Choice 2)

    Parameters:
        None

    Returns:
        tuple: Six StableDiffusionPipeline objects, in this order:
            full, norm, bias, attention, norm_bias_attention (NBA), difffit
    """

    # (PEFT type, label used in the progress message, checkpoint file of the
    # best trained model). Order defines the order of the returned tuple.
    peft_variants = (
        ("full", "Full", "full_diffusion_pytorch_model.safetensors"),
        ("norm", "Norm", "norm_diffusion_pytorch_model.safetensors"),
        ("bias", "Bias", "bias_diffusion_pytorch_model.safetensors"),
        ("attention", "Attention", "attention_diffusion_pytorch_model.safetensors"),
        (
            "norm_bias_attention",
            "NBA",
            "norm_bias_attention_diffusion_pytorch_model.safetensors",
        ),
        ("difffit", "Difffit", "difffit_diffusion_pytorch_model.safetensors"),
    )

    target_device = "cuda:0" if torch.cuda.is_available() else "cpu"

    loaded = []
    for peft_type, label, checkpoint in peft_variants:
        print(f"Loading Pipeline for {label} Fine-Tuning")
        loaded.append(
            loadSDModel(
                unet_pretraining_type=peft_type,
                exp_path=checkpoint,
                cuda_device=target_device,
            )
        )

    return tuple(loaded)
191
+
192
+
193
+ # LOAD ALL PIPELINES FIRST AND CACHE THEM
194
+ # (
195
+ # sd_pipeline_full,
196
+ # sd_pipeline_norm,
197
+ # sd_pipeline_bias,
198
+ # sd_pipeline_attention,
199
+ # sd_pipeline_NBA,
200
+ # sd_pipeline_difffit,
201
+ # ) = load_all_pipelines()
202
+
203
+ # PIPELINE_DICT = {
204
+ # "full": sd_pipeline_full,
205
+ # "norm": sd_pipeline_norm,
206
+ # "bias": sd_pipeline_bias,
207
+ # "attention": sd_pipeline_attention,
208
+ # "norm_bias_attention": sd_pipeline_NBA,
209
+ # "difffit": sd_pipeline_difffit,
210
+ # }
211
+
212
+
# Checkpoint file for every PEFT type that predict() can load on demand.
_PEFT_MODEL_PATHS = {
    "full": "full_diffusion_pytorch_model.safetensors",
    "norm": "norm_diffusion_pytorch_model.safetensors",
    "bias": "bias_diffusion_pytorch_model.safetensors",
    "attention": "attention_diffusion_pytorch_model.safetensors",
    "norm_bias_attention": "norm_bias_attention_diffusion_pytorch_model.safetensors",
    "difffit": "difffit_diffusion_pytorch_model.safetensors",
}


def predict(
    unet_pretraining_type,
    input_text,
    guidance_scale=4,
    num_inference_steps=75,
    device="0",
    OUTPUT_DIR="OUTPUT",
    PIPELINE_DICT=None,
):
    """
    Generates one image for a text prompt with the pipeline of the selected PEFT Type

    Parameters:
        unet_pretraining_type (str): The type of PEFT to use for generating the X-ray
        input_text (str): The text prompt passed to the pipeline
        guidance_scale (float): The guidance scale passed to the pipeline
        num_inference_steps (int): The number of inference steps; cast to int
            before use so a float coming from a UI widget is accepted
        device (str): The CUDA device index; "cuda:<device>" is used when CUDA
            is available, otherwise "cpu"
        OUTPUT_DIR (str): Accepted for backward compatibility; not used here
        PIPELINE_DICT (dict | None): Optional mapping from PEFT type to an
            already-loaded pipeline. When it contains the requested type that
            pipeline is reused; otherwise the pipeline is loaded on demand.

    Returns:
        tuple: (generated PIL image, gr.BarPlot comparing the tunable-parameter
            figure of "full" with that of the selected PEFT type)

    Raises:
        ValueError: If the PEFT type is neither in PIPELINE_DICT nor one of the
            types with a known checkpoint file
    """

    NUM_TUNABLE_PARAMS = {
        "full": 86,
        "attention": 26.7,
        "bias": 0.343,
        "norm": 0.2,
        "norm_bias_attention": 26.7,
        "lorav2": 0.8,
        "svdiff": 0.222,
        "difffit": 0.581,
    }

    cached_pipelines = PIPELINE_DICT or {}

    # Fail early, before any model is downloaded, with a message that names the
    # accepted values.
    if (
        unet_pretraining_type not in cached_pipelines
        and unet_pretraining_type not in _PEFT_MODEL_PATHS
    ):
        raise ValueError(
            f"Unsupported PEFT type {unet_pretraining_type!r}; "
            f"expected one of {sorted(_PEFT_MODEL_PATHS)}"
        )

    cuda_device = f"cuda:{device}" if torch.cuda.is_available() else "cpu"

    sd_pipeline = cached_pipelines.get(unet_pretraining_type)
    if sd_pipeline is None:
        print("Loading Pipeline for {} Fine-Tuning".format(unet_pretraining_type))
        sd_pipeline = loadSDModel(
            unet_pretraining_type=unet_pretraining_type,
            exp_path=_PEFT_MODEL_PATHS[unet_pretraining_type],
            cuda_device=cuda_device,
        )

    sd_pipeline.to(cuda_device)

    result_image = sd_pipeline(
        prompt=input_text,
        height=224,
        width=224,
        guidance_scale=guidance_scale,
        num_inference_steps=int(num_inference_steps),
    )

    result_pil_image = result_image["images"][0]

    # Create a Bar Plot displaying the number of tunable parameters for the selected PEFT Type
    df = pd.DataFrame(
        {
            "PEFT Type": list(NUM_TUNABLE_PARAMS.keys()),
            "Number of Tunable Parameters": list(NUM_TUNABLE_PARAMS.values()),
        }
    )

    # Keep only the "full" baseline and the selected type (a single row when
    # the selected type is "full").
    df = df[df["PEFT Type"].isin(["full", unet_pretraining_type])].reset_index(
        drop=True
    )

    bar_plot = gr.BarPlot(
        value=df,
        x="PEFT Type",
        y="Number of Tunable Parameters",
        label="PEFT Type",
        title="Number of Tunable Parameters",
        vertical=False,
    )

    return result_pil_image, bar_plot
281
+
282
+
# Create a Gradio interface
#
# Input Parameters:
#   1. PEFT Type: (Dropdown) The type of PEFT to use for generating the X-ray
#   2. Input Text: (Dropdown) The text prompt to use for generating the X-ray
#   3. Guidance Scale: (Slider) The guidance scale to use for generating the X-ray
#   4. Num Inference Steps: (Slider) The number of inference steps to use for generating the X-ray
#
# Output Parameters:
#   1. Generated X-ray Image: (Image) The generated X-ray image
#   2. Number of Tunable Parameters: (Bar Plot) The number of tunable parameters for the selected PEFT Type

# NOTE(review): "svdiff" is offered here, but MODEL_PATH_DICT in this file lists
# no checkpoint for it, and "norm_bias_attention" has a checkpoint but is not
# offered. Confirm the intended choice list.
peft_type_input = gr.Dropdown(
    ["full", "difffit", "svdiff", "norm", "bias", "attention"],
    label="PEFT Type",
)
prompt_input = gr.Dropdown(
    EXAMPLE_TEXT_PROMPTS,
    info=INFO_ABOUT_TEXT_PROMPT,
    label="Input Text",
)
guidance_scale_input = gr.Slider(
    minimum=1,
    maximum=10,
    value=4,
    step=1,
    info=INFO_ABOUT_GUIDANCE_SCALE,
    label="Guidance Scale",
)
inference_steps_input = gr.Slider(
    minimum=1,
    maximum=100,
    value=75,
    step=1,
    info=INFO_ABOUT_INFERENCE_STEPS,
    label="Num Inference Steps",
)

# Outputs map positionally onto predict's (image, bar plot) return value.
generated_image_output = gr.Image(type="pil")
tunable_params_output = gr.BarPlot()

iface = gr.Interface(
    fn=predict,
    inputs=[
        peft_type_input,
        prompt_input,
        guidance_scale_input,
        inference_steps_input,
    ],
    outputs=[generated_image_output, tunable_params_output],
    live=True,
    analytics_enabled=False,
    title=TITLE,
)

# Launch the Gradio interface
iface.launch(share=True)