Raman Dutt commited on
Commit
0dc7eec
·
1 Parent(s): 3a33055

_predict_using_default_params function added for failing cases

Browse files
Files changed (1) hide show
  1. app.py +84 -10
app.py CHANGED
@@ -97,15 +97,16 @@ def loadSDModel(unet_pretraining_type, cuda_device):
97
 
98
  return pipe
99
 
 
100
 
101
- def predict(
102
- unet_pretraining_type,
103
- input_text,
104
- guidance_scale=4,
105
- num_inference_steps=75,
106
- device="0",
107
- OUTPUT_DIR="OUTPUT",
108
- ):
109
 
110
  BARPLOT_TITLE = "Tunable Parameters for {} Fine-Tuning".format(unet_pretraining_type)
111
  NUM_TUNABLE_PARAMS = {
@@ -120,7 +121,7 @@ def predict(
120
  }
121
 
122
  cuda_device = f"cuda:{device}" if torch.cuda.is_available() else "cpu"
123
-
124
  print("Loading Pipeline for {} Fine-Tuning".format(unet_pretraining_type))
125
  sd_pipeline = loadSDModel(
126
  unet_pretraining_type=unet_pretraining_type,
@@ -138,7 +139,6 @@ def predict(
138
  )
139
 
140
  result_pil_image = result_image["images"][0]
141
-
142
 
143
  # Create a Bar Plot displaying the number of tunable parameters for the selected PEFT Type
144
  df = pd.DataFrame(
@@ -168,6 +168,80 @@ def predict(
168
  return result_pil_image, bar_plot
169
 
170
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
171
  # Create a Gradio interface
172
  """
173
  Input Parameters:
 
97
 
98
  return pipe
99
 
100
+ def _predict_using_default_params():
101
 
102
+
103
+ # Defining the default parameters
104
+ unet_pretraining_type = 'full'
105
+ input_text = 'No acute cardiopulmonary abnormality.'
106
+ guidance_scale = 4
107
+ num_inference_steps = 75
108
+ device = '0'
109
+ OUTPUT_DIR = 'OUTPUT'
110
 
111
  BARPLOT_TITLE = "Tunable Parameters for {} Fine-Tuning".format(unet_pretraining_type)
112
  NUM_TUNABLE_PARAMS = {
 
121
  }
122
 
123
  cuda_device = f"cuda:{device}" if torch.cuda.is_available() else "cpu"
124
+
125
  print("Loading Pipeline for {} Fine-Tuning".format(unet_pretraining_type))
126
  sd_pipeline = loadSDModel(
127
  unet_pretraining_type=unet_pretraining_type,
 
139
  )
140
 
141
  result_pil_image = result_image["images"][0]
 
142
 
143
  # Create a Bar Plot displaying the number of tunable parameters for the selected PEFT Type
144
  df = pd.DataFrame(
 
168
  return result_pil_image, bar_plot
169
 
170
 
171
+ def predict(
172
+ unet_pretraining_type,
173
+ input_text,
174
+ guidance_scale=4,
175
+ num_inference_steps=75,
176
+ device="0",
177
+ OUTPUT_DIR="OUTPUT",
178
+ ):
179
+
180
+ try:
181
+ BARPLOT_TITLE = "Tunable Parameters for {} Fine-Tuning".format(unet_pretraining_type)
182
+ NUM_TUNABLE_PARAMS = {
183
+ "full": 86,
184
+ "attention": 26.7,
185
+ "bias": 0.343,
186
+ "norm": 0.2,
187
+ "norm_bias_attention": 26.7,
188
+ "lorav2": 0.8,
189
+ "svdiff": 0.222,
190
+ "difffit": 0.581,
191
+ }
192
+
193
+ cuda_device = f"cuda:{device}" if torch.cuda.is_available() else "cpu"
194
+
195
+ print("Loading Pipeline for {} Fine-Tuning".format(unet_pretraining_type))
196
+ sd_pipeline = loadSDModel(
197
+ unet_pretraining_type=unet_pretraining_type,
198
+ cuda_device=cuda_device,
199
+ )
200
+
201
+ sd_pipeline.to(cuda_device)
202
+
203
+ result_image = sd_pipeline(
204
+ prompt=input_text,
205
+ height=224,
206
+ width=224,
207
+ guidance_scale=guidance_scale,
208
+ num_inference_steps=num_inference_steps,
209
+ )
210
+
211
+ result_pil_image = result_image["images"][0]
212
+
213
+
214
+ # Create a Bar Plot displaying the number of tunable parameters for the selected PEFT Type
215
+ df = pd.DataFrame(
216
+ {
217
+ "Fine-Tuning Strategy": list(NUM_TUNABLE_PARAMS.keys()),
218
+ "Number of Tunable Parameters": list(NUM_TUNABLE_PARAMS.values()),
219
+ }
220
+ )
221
+
222
+ print(df)
223
+
224
+ df = df[df["Fine-Tuning Strategy"].isin(["full", unet_pretraining_type])].reset_index(
225
+ drop=True
226
+ )
227
+
228
+ bar_plot = gr.BarPlot(
229
+ value=df,
230
+ x="Fine-Tuning Strategy",
231
+ y="Number of Tunable Parameters",
232
+ title=BARPLOT_TITLE,
233
+ vertical=True,
234
+ height=300,
235
+ width=300,
236
+ interactive=True,
237
+ )
238
+
239
+ return result_pil_image, bar_plot
240
+
241
+ except:
242
+ return _predict_using_default_params()
243
+
244
+
245
  # Create a Gradio interface
246
  """
247
  Input Parameters: