alessandro trinca tornidor committed
Commit 937bd43 · 1 Parent(s): dfbc77d

[refactor] start reducing complexity of chat.py

Files changed (1)
  1. app/chat.py +23 -76
app/chat.py CHANGED
@@ -1,70 +1,21 @@
-import argparse
+import logging
 import os
 import sys
 
 import cv2
 import numpy as np
 import torch
-import torch.nn.functional as F
 from transformers import AutoTokenizer, BitsAndBytesConfig, CLIPImageProcessor
 
 from model.LISA import LISAForCausalLM
 from model.llava import conversation as conversation_lib
 from model.llava.mm_utils import tokenizer_image_token
 from model.segment_anything.utils.transforms import ResizeLongestSide
-from utils.utils import (DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN,
-                         DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX)
-
-
-def parse_args(args):
-    parser = argparse.ArgumentParser(description="LISA chat")
-    parser.add_argument("--version", default="xinlai/LISA-13B-llama2-v1")
-    parser.add_argument("--vis_save_path", default="./vis_output", type=str)
-    parser.add_argument(
-        "--precision",
-        default="bf16",
-        type=str,
-        choices=["fp32", "bf16", "fp16"],
-        help="precision for inference",
-    )
-    parser.add_argument("--image_size", default=1024, type=int, help="image size")
-    parser.add_argument("--model_max_length", default=512, type=int)
-    parser.add_argument("--lora_r", default=8, type=int)
-    parser.add_argument(
-        "--vision-tower", default="openai/clip-vit-large-patch14", type=str
-    )
-    parser.add_argument("--local-rank", default=0, type=int, help="node rank")
-    parser.add_argument("--load_in_8bit", action="store_true", default=False)
-    parser.add_argument("--load_in_4bit", action="store_true", default=False)
-    parser.add_argument("--use_mm_start_end", action="store_true", default=True)
-    parser.add_argument(
-        "--conv_type",
-        default="llava_v1",
-        type=str,
-        choices=["llava_v1", "llava_llama_2"],
-    )
-    return parser.parse_args(args)
-
-
-def preprocess(
-    x,
-    pixel_mean=torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1),
-    pixel_std=torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1),
-    img_size=1024,
-) -> torch.Tensor:
-    """Normalize pixel values and pad to a square input."""
-    # Normalize colors
-    x = (x - pixel_mean) / pixel_std
-    # Pad
-    h, w = x.shape[-2:]
-    padh = img_size - h
-    padw = img_size - w
-    x = F.pad(x, (0, padw, 0, padh))
-    return x
+from utils import app_helpers, utils
 
 
 def main(args):
-    args = parse_args(args)
+    args = app_helpers.parse_args(args)
     os.makedirs(args.vis_save_path, exist_ok=True)
 
     # Create model
@@ -78,12 +29,7 @@ def main(args):
     tokenizer.pad_token = tokenizer.unk_token
     args.seg_token_idx = tokenizer("[SEG]", add_special_tokens=False).input_ids[0]
 
-
-    torch_dtype = torch.float32
-    if args.precision == "bf16":
-        torch_dtype = torch.bfloat16
-    elif args.precision == "fp16":
-        torch_dtype = torch.half
+    torch_dtype = change_torch_dtype_by_precision(args.precision)
 
     kwargs = {"torch_dtype": torch_dtype}
     if args.load_in_4bit:
@@ -156,12 +102,12 @@
     conv.messages = []
 
     prompt = input("Please input your prompt: ")
-    prompt = DEFAULT_IMAGE_TOKEN + "\n" + prompt
+    prompt = utils.DEFAULT_IMAGE_TOKEN + "\n" + prompt
     if args.use_mm_start_end:
         replace_token = (
-            DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
+            utils.DEFAULT_IM_START_TOKEN + utils.DEFAULT_IMAGE_TOKEN + utils.DEFAULT_IM_END_TOKEN
         )
-        prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)
+        prompt = prompt.replace(utils.DEFAULT_IMAGE_TOKEN, replace_token)
 
     conv.append_message(conv.roles[0], prompt)
     conv.append_message(conv.roles[1], "")
@@ -183,27 +129,19 @@
         .unsqueeze(0)
         .cuda()
     )
-    if args.precision == "bf16":
-        image_clip = image_clip.bfloat16()
-    elif args.precision == "fp16":
-        image_clip = image_clip.half()
-    else:
-        image_clip = image_clip.float()
+    logging.info(f"image_clip type: {type(image_clip)}.")
+    image_clip = app_helpers.set_image_precision_by_args(image_clip, args.precision)
 
     image = transform.apply_image(image_np)
     resize_list = [image.shape[:2]]
 
     image = (
-        preprocess(torch.from_numpy(image).permute(2, 0, 1).contiguous())
+        app_helpers.preprocess(torch.from_numpy(image).permute(2, 0, 1).contiguous())
        .unsqueeze(0)
        .cuda()
     )
-    if args.precision == "bf16":
-        image = image.bfloat16()
-    elif args.precision == "fp16":
-        image = image.half()
-    else:
-        image = image.float()
+    logging.info(f"image_clip type: {type(image_clip)}.")
+    image = app_helpers.set_image_precision_by_args(image, args.precision)
 
     input_ids = tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
     input_ids = input_ids.unsqueeze(0).cuda()
@@ -217,11 +155,11 @@
         max_new_tokens=512,
         tokenizer=tokenizer,
     )
-    output_ids = output_ids[0][output_ids[0] != IMAGE_TOKEN_INDEX]
+    output_ids = output_ids[0][output_ids[0] != utils.IMAGE_TOKEN_INDEX]
 
     text_output = tokenizer.decode(output_ids, skip_special_tokens=False)
     text_output = text_output.replace("\n", "").replace("  ", " ")
-    print("text_output: ", text_output)
+    logging.info(f"text_output: {text_output}.")
 
     for i, pred_mask in enumerate(pred_masks):
         if pred_mask.shape[0] == 0:
@@ -249,5 +187,14 @@
     print("{} has been saved.".format(save_path))
 
 
+def change_torch_dtype_by_precision(precision):
+    torch_dtype = torch.float32
+    if precision == "bf16":
+        torch_dtype = torch.bfloat16
+    elif precision == "fp16":
+        torch_dtype = torch.half
+    return torch_dtype
+
+
 if __name__ == "__main__":
     main(sys.argv[1:])
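
Note: the utils/app_helpers module that parse_args, preprocess, and the precision casts move into is not shown in this commit view. Below is a minimal sketch of the one genuinely new helper, set_image_precision_by_args, assuming it simply wraps the if/elif/else casts deleted above; parse_args and preprocess presumably move over unchanged.

# utils/app_helpers.py -- hypothetical sketch, not part of this diff.
# Assumes the helper mirrors the inline precision casts removed from main().
import torch

def set_image_precision_by_args(input_image, precision):
    """Cast an image tensor to the dtype chosen by the --precision flag."""
    if precision == "bf16":
        return input_image.bfloat16()
    if precision == "fp16":
        return input_image.half()
    # fp32 fallback, mirroring the old `else` branch
    return input_image.float()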
 
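A caveat on the switch from print() to logging.info(): Python's root logger emits only WARNING and above by default, and no logging.basicConfig call appears in the hunks above, so these INFO messages stay silent unless logging is configured elsewhere. A minimal sketch of the configuration the script would need at startup (an assumption, not part of this commit):

import logging

# Without something like this, the new logging.info() calls in main()
# produce no visible output under Python's default WARNING threshold.
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")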