Spaces:
Runtime error
Runtime error
JosephCatrambone
commited on
Commit
·
b6eec52
1
Parent(s):
2b6d53c
Basic version is complete and working.
Browse files- app.py +16 -8
- pickle_lama_model.ipynb +12 -2
app.py
CHANGED
@@ -1,17 +1,25 @@
|
|
1 |
import gradio as gr
|
2 |
-
|
|
|
3 |
|
4 |
-
|
5 |
|
6 |
-
def predict(input_img):
|
7 |
-
|
8 |
-
|
|
|
|
|
|
|
|
|
9 |
|
10 |
gradio_app = gr.Interface(
|
11 |
predict,
|
12 |
-
inputs=
|
13 |
-
|
14 |
-
|
|
|
|
|
|
|
15 |
)
|
16 |
|
17 |
if __name__ == "__main__":
|
|
|
1 |
import gradio as gr
|
2 |
+
import torch
|
3 |
+
from torchvision.transforms.functional import pil_to_tensor, to_pil_image
|
4 |
|
5 |
+
model = torch.jit.load("models/lama.pt")
|
6 |
|
7 |
+
def predict(input_img, input_mask):
|
8 |
+
# numpy gives the image as (w,h,c)
|
9 |
+
# Image shape should be (1, 3, 512, 512) and be in the range 0-1.
|
10 |
+
# Mask shape should be (1, 1, 512, 512) AND have values 0.0 or 1.0, not in-between.
|
11 |
+
#out = model(torch.tensor(input_img[None, (2,0,1), :, :])/255.0, torch.tensor(1 * (input_mask[:,:,0] > 0)).unsqueeze(0))
|
12 |
+
out = model((pil_to_tensor(input_img.convert('RGB')) / 255.0).unsqueeze(0), 1 * (pil_to_tensor(input_mask.convert('L')) > 0).unsqueeze(0))[0]
|
13 |
+
return to_pil_image(out)
|
14 |
|
15 |
gradio_app = gr.Interface(
|
16 |
predict,
|
17 |
+
inputs=[
|
18 |
+
gr.Image(label="Select Base Image", sources=['upload',], type="pil"),
|
19 |
+
gr.Image(label="Select Image Mask (White will be inpainted)", sources=['upload',], type="pil"),
|
20 |
+
],
|
21 |
+
outputs=[gr.Image(label="Inpainted Image"),],
|
22 |
+
title="LAMA Inpainting",
|
23 |
)
|
24 |
|
25 |
if __name__ == "__main__":
|
pickle_lama_model.ipynb
CHANGED
@@ -160,7 +160,7 @@
|
|
160 |
},
|
161 |
{
|
162 |
"cell_type": "code",
|
163 |
-
"execution_count":
|
164 |
"id": "163db07c-93a3-40d2-837d-4fade79b07f0",
|
165 |
"metadata": {},
|
166 |
"outputs": [
|
@@ -181,12 +181,22 @@
|
|
181 |
},
|
182 |
"metadata": {},
|
183 |
"output_type": "display_data"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
184 |
}
|
185 |
],
|
186 |
"source": [
|
187 |
"print(out['predicted_image'].shape)\n",
|
188 |
"import numpy\n",
|
189 |
-
"display(tvf.to_pil_image((out['predicted_image'])[0]))"
|
|
|
|
|
190 |
]
|
191 |
},
|
192 |
{
|
|
|
160 |
},
|
161 |
{
|
162 |
"cell_type": "code",
|
163 |
+
"execution_count": 76,
|
164 |
"id": "163db07c-93a3-40d2-837d-4fade79b07f0",
|
165 |
"metadata": {},
|
166 |
"outputs": [
|
|
|
181 |
},
|
182 |
"metadata": {},
|
183 |
"output_type": "display_data"
|
184 |
+
},
|
185 |
+
{
|
186 |
+
"name": "stdout",
|
187 |
+
"output_type": "stream",
|
188 |
+
"text": [
|
189 |
+
"tensor(1.)\n",
|
190 |
+
"tensor(1)\n"
|
191 |
+
]
|
192 |
}
|
193 |
],
|
194 |
"source": [
|
195 |
"print(out['predicted_image'].shape)\n",
|
196 |
"import numpy\n",
|
197 |
+
"display(tvf.to_pil_image((out['predicted_image'])[0]))\n",
|
198 |
+
"print(torch.max(image))\n",
|
199 |
+
"print(torch.max(mask))"
|
200 |
]
|
201 |
},
|
202 |
{
|