curt-park committed on
Commit
064ed26
1 Parent(s): 8d4a5a4

Remove backgrounds from cropped images

Browse files
Files changed (1) hide show
  1. app.py +9 -8
app.py CHANGED
@@ -57,7 +57,7 @@ def get_scores(crops: List[PIL.Image.Image], query: str) -> torch.Tensor:
57
  txt_features = model.encode_text(token)
58
  img_features /= img_features.norm(dim=-1, keepdim=True)
59
  txt_features /= txt_features.norm(dim=-1, keepdim=True)
60
- similarity = (100.0 * img_features @ txt_features.T).softmax(dim=0)
61
  return similarity
62
 
63
 
@@ -82,9 +82,10 @@ def filter_masks(
82
  filtered_masks.append(mask)
83
 
84
  x, y, w, h = mask["bbox"]
85
- crop = image[y: y + h, x: x + w]
 
86
  crop = cv2.cvtColor(crop, cv2.COLOR_BGR2RGB)
87
- crop = PIL.Image.fromarray(np.uint8(crop * 255)).convert("RGB")
88
  crop.resize((CLIP_WIDTH, CLIP_HEIGHT))
89
  cropped_masks.append(crop)
90
 
@@ -141,7 +142,7 @@ def segment(
141
  )
142
  image = draw_masks(image, masks)
143
  image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
144
- image = PIL.Image.fromarray(np.uint8(image)).convert("RGB")
145
  return image
146
 
147
 
@@ -161,14 +162,14 @@ demo = gr.Interface(
161
  [
162
  0.9,
163
  0.8,
164
- 0.15,
165
  os.path.join(os.path.dirname(__file__), "examples/dog.jpg"),
166
- "A dog only",
167
  ],
168
  [
169
  0.9,
170
  0.8,
171
- 0.1,
172
  os.path.join(os.path.dirname(__file__), "examples/city.jpg"),
173
  "A bridge on the water",
174
  ],
@@ -177,7 +178,7 @@ demo = gr.Interface(
177
  0.8,
178
  0.05,
179
  os.path.join(os.path.dirname(__file__), "examples/food.jpg"),
180
- "",
181
  ],
182
  [
183
  0.9,
 
57
  txt_features = model.encode_text(token)
58
  img_features /= img_features.norm(dim=-1, keepdim=True)
59
  txt_features /= txt_features.norm(dim=-1, keepdim=True)
60
+ similarity = (100 * img_features @ txt_features.T).softmax(0)
61
  return similarity
62
 
63
 
 
82
  filtered_masks.append(mask)
83
 
84
  x, y, w, h = mask["bbox"]
85
+ masked = image * np.expand_dims(mask["segmentation"], -1)
86
+ crop = masked[y: y + h, x: x + w]
87
  crop = cv2.cvtColor(crop, cv2.COLOR_BGR2RGB)
88
+ crop = PIL.Image.fromarray(crop * 255)
89
  crop.resize((CLIP_WIDTH, CLIP_HEIGHT))
90
  cropped_masks.append(crop)
91
 
 
142
  )
143
  image = draw_masks(image, masks)
144
  image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
145
+ image = PIL.Image.fromarray(image)
146
  return image
147
 
148
 
 
162
  [
163
  0.9,
164
  0.8,
165
+ 0.30,
166
  os.path.join(os.path.dirname(__file__), "examples/dog.jpg"),
167
+ "A dog",
168
  ],
169
  [
170
  0.9,
171
  0.8,
172
+ 0.05,
173
  os.path.join(os.path.dirname(__file__), "examples/city.jpg"),
174
  "A bridge on the water",
175
  ],
 
178
  0.8,
179
  0.05,
180
  os.path.join(os.path.dirname(__file__), "examples/food.jpg"),
181
+ "spoon",
182
  ],
183
  [
184
  0.9,