alessandro trinca tornidor committed
Commit f182d7a · 1 Parent(s): a5b4be9

[refactor] add and use create_placeholder_variables() function

README.md CHANGED
@@ -321,3 +321,4 @@ If you find this project useful in your research, please consider citing:
 
 ## Acknowledgement
 - This work is built upon the [LLaVA](https://github.com/haotian-liu/LLaVA) and [SAM](https://github.com/facebookresearch/segment-anything).
+- placeholder images (error, 'no output segmentation') from Muhammad Khaleeq (https://www.vecteezy.com/members/iyikon)
main.py CHANGED
@@ -20,9 +20,7 @@ from model.LISA import LISAForCausalLM
 from model.llava import conversation as conversation_lib
 from model.llava.mm_utils import tokenizer_image_token
 from model.segment_anything.utils.transforms import ResizeLongestSide
-from utils import constants, session_logger
-from utils.utils import (DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN,
-                         DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX)
+from utils import constants, session_logger, utils
 
 
 session_logger.change_logging(logging.DEBUG)
@@ -34,6 +32,7 @@ FASTAPI_STATIC = os.getenv("FASTAPI_STATIC")
 os.makedirs(FASTAPI_STATIC, exist_ok=True)
 app.mount("/static", StaticFiles(directory=FASTAPI_STATIC), name="static")
 templates = Jinja2Templates(directory="templates")
+placeholders = utils.create_placeholder_variables()
 
 
 @app.get("/health")
@@ -230,6 +229,7 @@ def get_inference_model_by_args(args_to_parse):
     logging.info(f"args_to_parse:{args_to_parse}, creating model...")
     model, clip_image_processor, tokenizer, transform = get_model(args_to_parse)
     logging.info("created model, preparing inference function")
+    no_seg_out, error_happened = placeholders["no_seg_out"], placeholders["error_happened"]
 
     @session_logger.set_uuid_logging
     def inference(input_str, input_image):
@@ -242,22 +242,19 @@ def get_inference_model_by_args(args_to_parse):
         ## input valid check
         if not re.match(r"^[A-Za-z ,.!?\'\"]+$", input_str) or len(input_str) < 1:
             output_str = "[Error] Invalid input: ", input_str
-            # output_image = np.zeros((128, 128, 3))
-            ## error happened
-            output_image = cv2.imread("./resources/error_happened.png")[:, :, ::-1]
-            return output_image, output_str
+            return error_happened, output_str
 
         # Model Inference
         conv = conversation_lib.conv_templates[args_to_parse.conv_type].copy()
         conv.messages = []
 
         prompt = input_str
-        prompt = DEFAULT_IMAGE_TOKEN + "\n" + prompt
+        prompt = utils.DEFAULT_IMAGE_TOKEN + "\n" + prompt
         if args_to_parse.use_mm_start_end:
             replace_token = (
-                DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
+                utils.DEFAULT_IM_START_TOKEN + utils.DEFAULT_IMAGE_TOKEN + utils.DEFAULT_IM_END_TOKEN
             )
-            prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)
+            prompt = prompt.replace(utils.DEFAULT_IMAGE_TOKEN, replace_token)
 
         conv.append_message(conv.roles[0], prompt)
         conv.append_message(conv.roles[1], "")
@@ -300,7 +297,7 @@ def get_inference_model_by_args(args_to_parse):
             max_new_tokens=512,
             tokenizer=tokenizer,
         )
-        output_ids = output_ids[0][output_ids[0] != IMAGE_TOKEN_INDEX]
+        output_ids = output_ids[0][output_ids[0] != utils.IMAGE_TOKEN_INDEX]
 
         text_output = tokenizer.decode(output_ids, skip_special_tokens=False)
         text_output = text_output.replace("\n", "").replace("  ", " ")
@@ -321,12 +318,8 @@ def get_inference_model_by_args(args_to_parse):
                 + pred_mask[:, :, None].astype(np.uint8) * np.array([255, 0, 0]) * 0.5
             )[pred_mask]
 
-        output_str = f"ASSITANT: {text_output}"
-        if save_img is not None:
-            output_image = save_img  # input_image
-        else:
-            ## no seg output
-            output_image = cv2.imread("./resources/no_seg_out.png")[:, :, ::-1]
+        output_str = f"ASSISTANT: {text_output}"
+        output_image = no_seg_out if save_img is None else save_img
         logging.info(f"output_image type: {type(output_image)}.")
         return output_image, output_str
 
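Side note on the validation branch above: `output_str = "[Error] Invalid input: ", input_str` builds a tuple, not a string, because of the trailing comma, so the error text reaches the caller as a tuple. A minimal sketch of the intended early-return flow (hypothetical `inference_stub`; the real `inference()` runs the LISA model between these two checks):

import re

def inference_stub(input_str, save_img, no_seg_out, error_happened):
    # Reject prompts containing characters outside the allowed set and hand
    # back the 'error happened' placeholder plus a proper error string.
    if not re.match(r"^[A-Za-z ,.!?\'\"]+$", input_str) or len(input_str) < 1:
        return error_happened, f"[Error] Invalid input: {input_str}"
    # ... model inference would populate save_img with the [SEG] overlay ...
    # No segmentation produced -> fall back to the 'no output' placeholder.
    output_image = no_seg_out if save_img is None else save_img
    return output_image, "ASSISTANT: ..."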
resources/placeholders/error_happened.png ADDED

Git LFS Details

  • SHA256: f485f5a4e3df0bc33f6117a03d919d8043a70077863239bb969946e46d6f7349
  • Pointer size: 130 Bytes
  • Size of remote file: 37 kB
resources/placeholders/no_seg_out.png ADDED

Git LFS Details

  • SHA256: cccb555bff1ac91973741f77617dcf039c38ce090ad5685e16cb2536d1180774
  • Pointer size: 130 Bytes
  • Size of remote file: 38.7 kB
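Both PNGs are tracked with Git LFS, so a clone without `git lfs pull` leaves ~130-byte pointer files on disk, and `cv2.imread` then returns `None` rather than raising. A small defensive loader one could use here (hypothetical `read_lfs_image` helper, not part of this commit):

from pathlib import Path

import cv2

def read_lfs_image(path: str):
    # cv2.imread returns None instead of raising when the file is missing,
    # unreadable, or still an un-pulled LFS pointer stub.
    img = cv2.imread(path)
    if img is None:
        size = Path(path).stat().st_size if Path(path).exists() else 0
        raise RuntimeError(
            f"could not decode '{path}' ({size} bytes on disk); "
            "if this is an LFS pointer, run 'git lfs pull'"
        )
    return img[:, :, ::-1]  # BGR -> RGB, as in create_placeholder_variables()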
utils/utils.py CHANGED
@@ -1,9 +1,11 @@
 from enum import Enum
+from pathlib import Path
 
 import numpy as np
 import torch
 import torch.distributed as dist
 
+
 IGNORE_INDEX = -100
 IMAGE_TOKEN_INDEX = -200
 DEFAULT_IMAGE_TOKEN = "<image>"
@@ -40,6 +42,7 @@ ANSWER_LIST = [
     "Sure, the segmentation result is [SEG].",
     "[SEG].",
 ]
+ROOT = Path(__file__).parent.parent
 
 
 class Summary(Enum):
@@ -161,3 +164,14 @@ def dict_to_cuda(input_dict):
         ):
             input_dict[k] = [ele.cuda(non_blocking=True) for ele in v]
     return input_dict
+
+
+def create_placeholder_variables():
+    import cv2
+
+    no_seg_out = cv2.imread(str(ROOT / "resources" / "placeholders" / "no_seg_out.png"))[:, :, ::-1]
+    error_happened = cv2.imread(str(ROOT / "resources" / "placeholders" / "error_happened.png"))[:, :, ::-1]
+    return {
+        "no_seg_out": no_seg_out,
+        "error_happened": error_happened
+    }
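A minimal usage sketch for the new helper, assuming the two PNGs above have been pulled from LFS; the `[:, :, ::-1]` slice converts OpenCV's BGR channel order to the RGB order expected by the web UI:

from utils import utils

placeholders = utils.create_placeholder_variables()
no_seg_out = placeholders["no_seg_out"]          # RGB numpy array
error_happened = placeholders["error_happened"]  # RGB numpy array
print(no_seg_out.shape, error_happened.shape)    # e.g. (H, W, 3) each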