Spaces:
Paused
Paused
alessandro trinca tornidor
commited on
Commit
·
f182d7a
1
Parent(s):
a5b4be9
[refactor] add and use create_placeholder_variables() function
Browse files- README.md +1 -0
- main.py +10 -17
- resources/placeholders/error_happened.png +3 -0
- resources/placeholders/no_seg_out.png +3 -0
- utils/utils.py +14 -0
README.md
CHANGED
@@ -321,3 +321,4 @@ If you find this project useful in your research, please consider citing:
|
|
321 |
|
322 |
## Acknowledgement
|
323 |
- This work is built upon the [LLaVA](https://github.com/haotian-liu/LLaVA) and [SAM](https://github.com/facebookresearch/segment-anything).
|
|
|
|
321 |
|
322 |
## Acknowledgement
|
323 |
- This work is built upon the [LLaVA](https://github.com/haotian-liu/LLaVA) and [SAM](https://github.com/facebookresearch/segment-anything).
|
324 |
+
- placeholders images (error, 'no output segmentation') from Muhammad Khaleeq (https://www.vecteezy.com/members/iyikon)
|
main.py
CHANGED
@@ -20,9 +20,7 @@ from model.LISA import LISAForCausalLM
|
|
20 |
from model.llava import conversation as conversation_lib
|
21 |
from model.llava.mm_utils import tokenizer_image_token
|
22 |
from model.segment_anything.utils.transforms import ResizeLongestSide
|
23 |
-
from utils import constants, session_logger
|
24 |
-
from utils.utils import (DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN,
|
25 |
-
DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX)
|
26 |
|
27 |
|
28 |
session_logger.change_logging(logging.DEBUG)
|
@@ -34,6 +32,7 @@ FASTAPI_STATIC = os.getenv("FASTAPI_STATIC")
|
|
34 |
os.makedirs(FASTAPI_STATIC, exist_ok=True)
|
35 |
app.mount("/static", StaticFiles(directory=FASTAPI_STATIC), name="static")
|
36 |
templates = Jinja2Templates(directory="templates")
|
|
|
37 |
|
38 |
|
39 |
@app.get("/health")
|
@@ -230,6 +229,7 @@ def get_inference_model_by_args(args_to_parse):
|
|
230 |
logging.info(f"args_to_parse:{args_to_parse}, creating model...")
|
231 |
model, clip_image_processor, tokenizer, transform = get_model(args_to_parse)
|
232 |
logging.info("created model, preparing inference function")
|
|
|
233 |
|
234 |
@session_logger.set_uuid_logging
|
235 |
def inference(input_str, input_image):
|
@@ -242,22 +242,19 @@ def get_inference_model_by_args(args_to_parse):
|
|
242 |
## input valid check
|
243 |
if not re.match(r"^[A-Za-z ,.!?\'\"]+$", input_str) or len(input_str) < 1:
|
244 |
output_str = "[Error] Invalid input: ", input_str
|
245 |
-
|
246 |
-
## error happened
|
247 |
-
output_image = cv2.imread("./resources/error_happened.png")[:, :, ::-1]
|
248 |
-
return output_image, output_str
|
249 |
|
250 |
# Model Inference
|
251 |
conv = conversation_lib.conv_templates[args_to_parse.conv_type].copy()
|
252 |
conv.messages = []
|
253 |
|
254 |
prompt = input_str
|
255 |
-
prompt = DEFAULT_IMAGE_TOKEN + "\n" + prompt
|
256 |
if args_to_parse.use_mm_start_end:
|
257 |
replace_token = (
|
258 |
-
DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
|
259 |
)
|
260 |
-
prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)
|
261 |
|
262 |
conv.append_message(conv.roles[0], prompt)
|
263 |
conv.append_message(conv.roles[1], "")
|
@@ -300,7 +297,7 @@ def get_inference_model_by_args(args_to_parse):
|
|
300 |
max_new_tokens=512,
|
301 |
tokenizer=tokenizer,
|
302 |
)
|
303 |
-
output_ids = output_ids[0][output_ids[0] != IMAGE_TOKEN_INDEX]
|
304 |
|
305 |
text_output = tokenizer.decode(output_ids, skip_special_tokens=False)
|
306 |
text_output = text_output.replace("\n", "").replace(" ", " ")
|
@@ -321,12 +318,8 @@ def get_inference_model_by_args(args_to_parse):
|
|
321 |
+ pred_mask[:, :, None].astype(np.uint8) * np.array([255, 0, 0]) * 0.5
|
322 |
)[pred_mask]
|
323 |
|
324 |
-
output_str = f"
|
325 |
-
if save_img is
|
326 |
-
output_image = save_img # input_image
|
327 |
-
else:
|
328 |
-
## no seg output
|
329 |
-
output_image = cv2.imread("./resources/no_seg_out.png")[:, :, ::-1]
|
330 |
logging.info(f"output_image type: {type(output_image)}.")
|
331 |
return output_image, output_str
|
332 |
|
|
|
20 |
from model.llava import conversation as conversation_lib
|
21 |
from model.llava.mm_utils import tokenizer_image_token
|
22 |
from model.segment_anything.utils.transforms import ResizeLongestSide
|
23 |
+
from utils import constants, session_logger, utils
|
|
|
|
|
24 |
|
25 |
|
26 |
session_logger.change_logging(logging.DEBUG)
|
|
|
32 |
os.makedirs(FASTAPI_STATIC, exist_ok=True)
|
33 |
app.mount("/static", StaticFiles(directory=FASTAPI_STATIC), name="static")
|
34 |
templates = Jinja2Templates(directory="templates")
|
35 |
+
placeholders = utils.create_placeholder_variables()
|
36 |
|
37 |
|
38 |
@app.get("/health")
|
|
|
229 |
logging.info(f"args_to_parse:{args_to_parse}, creating model...")
|
230 |
model, clip_image_processor, tokenizer, transform = get_model(args_to_parse)
|
231 |
logging.info("created model, preparing inference function")
|
232 |
+
no_seg_out, error_happened = placeholders["no_seg_out"], placeholders["error_happened"]
|
233 |
|
234 |
@session_logger.set_uuid_logging
|
235 |
def inference(input_str, input_image):
|
|
|
242 |
## input valid check
|
243 |
if not re.match(r"^[A-Za-z ,.!?\'\"]+$", input_str) or len(input_str) < 1:
|
244 |
output_str = "[Error] Invalid input: ", input_str
|
245 |
+
return error_happened, output_str
|
|
|
|
|
|
|
246 |
|
247 |
# Model Inference
|
248 |
conv = conversation_lib.conv_templates[args_to_parse.conv_type].copy()
|
249 |
conv.messages = []
|
250 |
|
251 |
prompt = input_str
|
252 |
+
prompt = utils.DEFAULT_IMAGE_TOKEN + "\n" + prompt
|
253 |
if args_to_parse.use_mm_start_end:
|
254 |
replace_token = (
|
255 |
+
utils.DEFAULT_IM_START_TOKEN + utils.DEFAULT_IMAGE_TOKEN + utils.DEFAULT_IM_END_TOKEN
|
256 |
)
|
257 |
+
prompt = prompt.replace(utils.DEFAULT_IMAGE_TOKEN, replace_token)
|
258 |
|
259 |
conv.append_message(conv.roles[0], prompt)
|
260 |
conv.append_message(conv.roles[1], "")
|
|
|
297 |
max_new_tokens=512,
|
298 |
tokenizer=tokenizer,
|
299 |
)
|
300 |
+
output_ids = output_ids[0][output_ids[0] != utils.IMAGE_TOKEN_INDEX]
|
301 |
|
302 |
text_output = tokenizer.decode(output_ids, skip_special_tokens=False)
|
303 |
text_output = text_output.replace("\n", "").replace(" ", " ")
|
|
|
318 |
+ pred_mask[:, :, None].astype(np.uint8) * np.array([255, 0, 0]) * 0.5
|
319 |
)[pred_mask]
|
320 |
|
321 |
+
output_str = f"ASSISTANT: {text_output}"
|
322 |
+
output_image = no_seg_out if save_img is None else save_img
|
|
|
|
|
|
|
|
|
323 |
logging.info(f"output_image type: {type(output_image)}.")
|
324 |
return output_image, output_str
|
325 |
|
resources/placeholders/error_happened.png
ADDED
Git LFS Details
|
resources/placeholders/no_seg_out.png
ADDED
Git LFS Details
|
utils/utils.py
CHANGED
@@ -1,9 +1,11 @@
|
|
1 |
from enum import Enum
|
|
|
2 |
|
3 |
import numpy as np
|
4 |
import torch
|
5 |
import torch.distributed as dist
|
6 |
|
|
|
7 |
IGNORE_INDEX = -100
|
8 |
IMAGE_TOKEN_INDEX = -200
|
9 |
DEFAULT_IMAGE_TOKEN = "<image>"
|
@@ -40,6 +42,7 @@ ANSWER_LIST = [
|
|
40 |
"Sure, the segmentation result is [SEG].",
|
41 |
"[SEG].",
|
42 |
]
|
|
|
43 |
|
44 |
|
45 |
class Summary(Enum):
|
@@ -161,3 +164,14 @@ def dict_to_cuda(input_dict):
|
|
161 |
):
|
162 |
input_dict[k] = [ele.cuda(non_blocking=True) for ele in v]
|
163 |
return input_dict
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
from enum import Enum
|
2 |
+
from pathlib import Path
|
3 |
|
4 |
import numpy as np
|
5 |
import torch
|
6 |
import torch.distributed as dist
|
7 |
|
8 |
+
|
9 |
IGNORE_INDEX = -100
|
10 |
IMAGE_TOKEN_INDEX = -200
|
11 |
DEFAULT_IMAGE_TOKEN = "<image>"
|
|
|
42 |
"Sure, the segmentation result is [SEG].",
|
43 |
"[SEG].",
|
44 |
]
|
45 |
+
ROOT = Path(__file__).parent.parent
|
46 |
|
47 |
|
48 |
class Summary(Enum):
|
|
|
164 |
):
|
165 |
input_dict[k] = [ele.cuda(non_blocking=True) for ele in v]
|
166 |
return input_dict
|
167 |
+
|
168 |
+
|
169 |
+
def create_placeholder_variables():
|
170 |
+
import cv2
|
171 |
+
|
172 |
+
no_seg_out = cv2.imread(str(ROOT / "resources" / "placeholder" / "no_seg_out.png"))[:, :, ::-1]
|
173 |
+
error_happened = cv2.imread(str(ROOT / "resources" / "placeholder" / "error_happened.png"))[:, :, ::-1]
|
174 |
+
return {
|
175 |
+
"no_seg_out": no_seg_out,
|
176 |
+
"error_happened": error_happened
|
177 |
+
}
|