Spaces:
Runtime error
Runtime error
Raman Dutt
commited on
Commit
·
0dc7eec
1
Parent(s):
3a33055
_predict_using_default_params function added for failing cases
Browse files
app.py
CHANGED
@@ -97,15 +97,16 @@ def loadSDModel(unet_pretraining_type, cuda_device):
|
|
97 |
|
98 |
return pipe
|
99 |
|
|
|
100 |
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
|
110 |
BARPLOT_TITLE = "Tunable Parameters for {} Fine-Tuning".format(unet_pretraining_type)
|
111 |
NUM_TUNABLE_PARAMS = {
|
@@ -120,7 +121,7 @@ def predict(
|
|
120 |
}
|
121 |
|
122 |
cuda_device = f"cuda:{device}" if torch.cuda.is_available() else "cpu"
|
123 |
-
|
124 |
print("Loading Pipeline for {} Fine-Tuning".format(unet_pretraining_type))
|
125 |
sd_pipeline = loadSDModel(
|
126 |
unet_pretraining_type=unet_pretraining_type,
|
@@ -138,7 +139,6 @@ def predict(
|
|
138 |
)
|
139 |
|
140 |
result_pil_image = result_image["images"][0]
|
141 |
-
|
142 |
|
143 |
# Create a Bar Plot displaying the number of tunable parameters for the selected PEFT Type
|
144 |
df = pd.DataFrame(
|
@@ -168,6 +168,80 @@ def predict(
|
|
168 |
return result_pil_image, bar_plot
|
169 |
|
170 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
171 |
# Create a Gradio interface
|
172 |
"""
|
173 |
Input Parameters:
|
|
|
97 |
|
98 |
return pipe
|
99 |
|
100 |
+
def _predict_using_default_params():
|
101 |
|
102 |
+
|
103 |
+
# Defining the default parameters
|
104 |
+
unet_pretraining_type = 'full'
|
105 |
+
input_text = 'No acute cardiopulmonary abnormality.'
|
106 |
+
guidance_scale = 4
|
107 |
+
num_inference_steps = 75
|
108 |
+
device = '0'
|
109 |
+
OUTPUT_DIR = 'OUTPUT'
|
110 |
|
111 |
BARPLOT_TITLE = "Tunable Parameters for {} Fine-Tuning".format(unet_pretraining_type)
|
112 |
NUM_TUNABLE_PARAMS = {
|
|
|
121 |
}
|
122 |
|
123 |
cuda_device = f"cuda:{device}" if torch.cuda.is_available() else "cpu"
|
124 |
+
|
125 |
print("Loading Pipeline for {} Fine-Tuning".format(unet_pretraining_type))
|
126 |
sd_pipeline = loadSDModel(
|
127 |
unet_pretraining_type=unet_pretraining_type,
|
|
|
139 |
)
|
140 |
|
141 |
result_pil_image = result_image["images"][0]
|
|
|
142 |
|
143 |
# Create a Bar Plot displaying the number of tunable parameters for the selected PEFT Type
|
144 |
df = pd.DataFrame(
|
|
|
168 |
return result_pil_image, bar_plot
|
169 |
|
170 |
|
171 |
+
def predict(
|
172 |
+
unet_pretraining_type,
|
173 |
+
input_text,
|
174 |
+
guidance_scale=4,
|
175 |
+
num_inference_steps=75,
|
176 |
+
device="0",
|
177 |
+
OUTPUT_DIR="OUTPUT",
|
178 |
+
):
|
179 |
+
|
180 |
+
try:
|
181 |
+
BARPLOT_TITLE = "Tunable Parameters for {} Fine-Tuning".format(unet_pretraining_type)
|
182 |
+
NUM_TUNABLE_PARAMS = {
|
183 |
+
"full": 86,
|
184 |
+
"attention": 26.7,
|
185 |
+
"bias": 0.343,
|
186 |
+
"norm": 0.2,
|
187 |
+
"norm_bias_attention": 26.7,
|
188 |
+
"lorav2": 0.8,
|
189 |
+
"svdiff": 0.222,
|
190 |
+
"difffit": 0.581,
|
191 |
+
}
|
192 |
+
|
193 |
+
cuda_device = f"cuda:{device}" if torch.cuda.is_available() else "cpu"
|
194 |
+
|
195 |
+
print("Loading Pipeline for {} Fine-Tuning".format(unet_pretraining_type))
|
196 |
+
sd_pipeline = loadSDModel(
|
197 |
+
unet_pretraining_type=unet_pretraining_type,
|
198 |
+
cuda_device=cuda_device,
|
199 |
+
)
|
200 |
+
|
201 |
+
sd_pipeline.to(cuda_device)
|
202 |
+
|
203 |
+
result_image = sd_pipeline(
|
204 |
+
prompt=input_text,
|
205 |
+
height=224,
|
206 |
+
width=224,
|
207 |
+
guidance_scale=guidance_scale,
|
208 |
+
num_inference_steps=num_inference_steps,
|
209 |
+
)
|
210 |
+
|
211 |
+
result_pil_image = result_image["images"][0]
|
212 |
+
|
213 |
+
|
214 |
+
# Create a Bar Plot displaying the number of tunable parameters for the selected PEFT Type
|
215 |
+
df = pd.DataFrame(
|
216 |
+
{
|
217 |
+
"Fine-Tuning Strategy": list(NUM_TUNABLE_PARAMS.keys()),
|
218 |
+
"Number of Tunable Parameters": list(NUM_TUNABLE_PARAMS.values()),
|
219 |
+
}
|
220 |
+
)
|
221 |
+
|
222 |
+
print(df)
|
223 |
+
|
224 |
+
df = df[df["Fine-Tuning Strategy"].isin(["full", unet_pretraining_type])].reset_index(
|
225 |
+
drop=True
|
226 |
+
)
|
227 |
+
|
228 |
+
bar_plot = gr.BarPlot(
|
229 |
+
value=df,
|
230 |
+
x="Fine-Tuning Strategy",
|
231 |
+
y="Number of Tunable Parameters",
|
232 |
+
title=BARPLOT_TITLE,
|
233 |
+
vertical=True,
|
234 |
+
height=300,
|
235 |
+
width=300,
|
236 |
+
interactive=True,
|
237 |
+
)
|
238 |
+
|
239 |
+
return result_pil_image, bar_plot
|
240 |
+
|
241 |
+
except:
|
242 |
+
return _predict_using_default_params()
|
243 |
+
|
244 |
+
|
245 |
# Create a Gradio interface
|
246 |
"""
|
247 |
Input Parameters:
|