JosephCatrambone committed
Commit b6eec52 · 1 Parent(s): 2b6d53c

Basic version is complete and working.

Files changed (2)
  1. app.py +16 -8
  2. pickle_lama_model.ipynb +12 -2
app.py CHANGED
@@ -1,17 +1,25 @@
 import gradio as gr
-from transformers import pipeline
+import torch
+from torchvision.transforms.functional import pil_to_tensor, to_pil_image
 
-pipeline = pipeline(task="image-classification", model="julien-c/hotdog-not-hotdog")
+model = torch.jit.load("models/lama.pt")
 
-def predict(input_img):
-    predictions = pipeline(input_img)
-    return input_img, {p["label"]: p["score"] for p in predictions}
+def predict(input_img, input_mask):
+    # numpy gives the image as (w,h,c)
+    # Image shape should be (1, 3, 512, 512) and be in the range 0-1.
+    # Mask shape should be (1, 1, 512, 512) AND have values 0.0 or 1.0, not in-between.
+    #out = model(torch.tensor(input_img[None, (2,0,1), :, :])/255.0, torch.tensor(1 * (input_mask[:,:,0] > 0)).unsqueeze(0))
+    out = model((pil_to_tensor(input_img.convert('RGB')) / 255.0).unsqueeze(0), 1 * (pil_to_tensor(input_mask.convert('L')) > 0).unsqueeze(0))[0]
+    return to_pil_image(out)
 
 gradio_app = gr.Interface(
     predict,
-    inputs=gr.Image(label="Select hot dog candidate", sources=['upload', 'webcam'], type="pil"),
-    outputs=[gr.Image(label="Processed Image"), gr.Label(label="Result", num_top_classes=2)],
-    title="Hot Dog? Or Not?",
+    inputs=[
+        gr.Image(label="Select Base Image", sources=['upload',], type="pil"),
+        gr.Image(label="Select Image Mask (White will be inpainted)", sources=['upload',], type="pil"),
+    ],
+    outputs=[gr.Image(label="Inpainted Image"),],
+    title="LAMA Inpainting",
 )
 
 if __name__ == "__main__":
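The comments in predict() pin down the traced model's calling convention: a float image of shape (1, 3, 512, 512) with values in [0, 1], and a mask of shape (1, 1, 512, 512) holding hard 0/1 values. A minimal smoke test along those lines, assuming models/lama.pt from this commit and that the scripted module returns a batch indexable with [0] the way predict() uses it (the dummy tensors and the expected output shape are illustrative assumptions, not taken from the commit):

import torch

# Load the TorchScript LaMa model pickled by the notebook (path from app.py).
model = torch.jit.load("models/lama.pt")
model.eval()

# Image: (1, 3, 512, 512), float, values in [0, 1].
image = torch.rand(1, 3, 512, 512)
# Mask: (1, 1, 512, 512), hard 0/1 values only, matching the comment above.
mask = 1 * (torch.rand(1, 1, 512, 512) > 0.5)

with torch.no_grad():
    out = model(image, mask)[0]

print(out.shape)  # Expected: torch.Size([3, 512, 512]) if the contract holds.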
pickle_lama_model.ipynb CHANGED
@@ -160,7 +160,7 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 69,
+    "execution_count": 76,
     "id": "163db07c-93a3-40d2-837d-4fade79b07f0",
     "metadata": {},
     "outputs": [
@@ -181,12 +181,22 @@
      },
      "metadata": {},
      "output_type": "display_data"
+     },
+     {
+      "name": "stdout",
+      "output_type": "stream",
+      "text": [
+       "tensor(1.)\n",
+       "tensor(1)\n"
+      ]
      }
     ],
     "source": [
      "print(out['predicted_image'].shape)\n",
      "import numpy\n",
-     "display(tvf.to_pil_image((out['predicted_image'])[0]))"
+     "display(tvf.to_pil_image((out['predicted_image'])[0]))\n",
+     "print(torch.max(image))\n",
+     "print(torch.max(mask))"
     ]
    },
    {
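The two added print statements are range checks on the notebook's inputs: tensor(1.) says the image is a float tensor peaking at 1.0, and tensor(1) says the mask is an integer tensor of hard 0/1 values, matching the contract documented in app.py. A sketch of the same binarization check in isolation (the file name mask.png is hypothetical):

import torch
from PIL import Image
from torchvision.transforms.functional import pil_to_tensor

# pil_to_tensor yields uint8 in [0, 255]; thresholding at 0 and multiplying
# by 1 collapses any soft edges into hard integer 0/1 values.
mask_img = Image.open("mask.png").convert("L")
mask = 1 * (pil_to_tensor(mask_img) > 0)

print(torch.max(mask))  # tensor(1) whenever the mask has any white pixels.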