SNUMPR commited on
Commit
598d165
·
verified ·
1 Parent(s): 8a69306

Upload folder using huggingface_hub

Browse files
.gitattributes CHANGED
@@ -33,3 +33,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ examples/sample_demo_1.mp4 filter=lfs diff=lfs merge=lfs -text
37
+ examples/sample_demo_13.mp4 filter=lfs diff=lfs merge=lfs -text
38
+ examples/sample_demo_22.mp4 filter=lfs diff=lfs merge=lfs -text
39
+ examples/sample_demo_8.mp4 filter=lfs diff=lfs merge=lfs -text
40
+ examples/sample_img_8.png filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,12 +1,6 @@
1
  ---
2
- title: Vlm Rlaif Demo
3
- emoji: 💻
4
- colorFrom: red
5
- colorTo: pink
6
  sdk: gradio
7
- sdk_version: 4.40.0
8
- app_file: app.py
9
- pinned: false
10
  ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: vlm-rlaif-demo
3
+ app_file: gradio_web_server.py
 
 
4
  sdk: gradio
5
+ sdk_version: 3.35.2
 
 
6
  ---
 
 
__init__.py ADDED
File without changes
__pycache__/__init__.cpython-310.pyc ADDED
Binary file (134 Bytes). View file
 
__pycache__/gradio_utils.cpython-310.pyc ADDED
Binary file (5.63 kB). View file
 
__pycache__/gradio_web_server.cpython-310.pyc ADDED
Binary file (5.91 kB). View file
 
__pycache__/utils.cpython-310.pyc ADDED
Binary file (603 Bytes). View file
 
asset/Model.png ADDED
cli.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+
4
+ import torch
5
+
6
+ import sys
7
+ sys.path.insert(0, os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "Evaluation"))
8
+ from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, \
9
+ DEFAULT_VIDEO_TOKEN
10
+ from llava.conversation import conv_templates, SeparatorStyle
11
+ from llava.model.builder import load_pretrained_model
12
+ from llava.utils import disable_torch_init
13
+ from llava.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
14
+ from serve.utils import load_image, image_ext, video_ext
15
+
16
+ from PIL import Image
17
+
18
+ import requests
19
+ from PIL import Image
20
+ from io import BytesIO
21
+ from transformers import TextStreamer
22
+
23
+
24
+
25
+ def main(args):
26
+ # Model
27
+ disable_torch_init()
28
+
29
+ model_name = get_model_name_from_path(args.model_path)
30
+ tokenizer, model, processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name,
31
+ args.load_8bit, args.load_4bit,
32
+ device=args.device, cache_dir=args.cache_dir)
33
+ image_processor, video_processor = processor['image'], processor['video']
34
+ if 'llama-2' in model_name.lower():
35
+ conv_mode = "llava_llama_2"
36
+ elif "v1" in model_name.lower():
37
+ conv_mode = "llava_v1"
38
+ elif "mpt" in model_name.lower():
39
+ conv_mode = "mpt"
40
+ else:
41
+ conv_mode = "llava_v0"
42
+
43
+ if args.conv_mode is not None and conv_mode != args.conv_mode:
44
+ print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode))
45
+ else:
46
+ args.conv_mode = conv_mode
47
+
48
+ conv = conv_templates[args.conv_mode].copy()
49
+ if "mpt" in model_name.lower():
50
+ roles = ('user', 'assistant')
51
+ else:
52
+ roles = conv.roles
53
+
54
+ tensor = []
55
+ special_token = []
56
+ args.file = args.file if isinstance(args.file, list) else [args.file]
57
+ for file in args.file:
58
+ if os.path.splitext(file)[-1].lower() in video_ext: # video extension
59
+ video_tensor = video_processor(file, return_tensors='pt')['pixel_values'][0].to(model.device, dtype=torch.float16)
60
+ special_token += [DEFAULT_IMAGE_TOKEN] * model.get_video_tower().config.num_frames
61
+ elif os.path.splitext(os.listdir(file)[0]).lower() in image_ext: # frames folder
62
+ vidframes_list = sorted(glob(file + '/*'))
63
+ images = load_frames(vidframes_list, model.get_video_tower().config.num_frames)
64
+ # Similar operation in model_worker.py
65
+ video_tensor = process_images(images, image_processor, args)
66
+ video_tensor = video_tensor.to(model.device, dtype=torch.float16)
67
+ video_tensor = video_tensor.unsqueeze(0)
68
+ special_token += [DEFAULT_IMAGE_TOKEN] * model.get_video_tower().config.num_frames
69
+ else:
70
+ raise ValueError(f'Support video of {video_ext} and frames of {image_ext}, but found {os.path.splitext(file)[-1].lower()}')
71
+ print(video_tensor.shape)
72
+ tensor.append(video_tensor)
73
+
74
+
75
+
76
+
77
+ while True:
78
+ try:
79
+ inp = input(f"{roles[0]}: ")
80
+ except EOFError:
81
+ inp = ""
82
+ if not inp:
83
+ print("exit...")
84
+ break
85
+
86
+ print(f"{roles[1]}: ", end="")
87
+
88
+ if file is not None:
89
+ # first message
90
+ if getattr(model.config, "mm_use_im_start_end", False):
91
+ inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + inp
92
+ # inp = ''.join([DEFAULT_IM_START_TOKEN + i + DEFAULT_IM_END_TOKEN for i in special_token]) + '\n' + inp
93
+ else:
94
+ inp = DEFAULT_IMAGE_TOKEN + '\n' + inp
95
+ # inp = ''.join(special_token) + '\n' + inp
96
+ conv.append_message(conv.roles[0], inp)
97
+ file = None
98
+ else:
99
+ # later messages
100
+ conv.append_message(conv.roles[0], inp)
101
+ conv.append_message(conv.roles[1], None)
102
+ prompt = conv.get_prompt()
103
+
104
+ input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
105
+ stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
106
+ keywords = [stop_str]
107
+ stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
108
+ streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
109
+
110
+ with torch.inference_mode():
111
+ output_ids = model.generate(
112
+ input_ids,
113
+ images=tensor, # video as fake images
114
+ do_sample=True if args.temperature > 0 else False,
115
+ temperature=args.temperature,
116
+ max_new_tokens=args.max_new_tokens,
117
+ streamer=streamer,
118
+ use_cache=True,
119
+ stopping_criteria=[stopping_criteria])
120
+
121
+ outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
122
+ conv.messages[-1][-1] = outputs
123
+
124
+ if args.debug:
125
+ print("\n", {"prompt": prompt, "outputs": outputs}, "\n")
126
+
127
+
128
+ if __name__ == "__main__":
129
+ parser = argparse.ArgumentParser()
130
+ parser.add_argument("--model-path", type=str, default="LanguageBind/Video-LLaVA-7B")
131
+ parser.add_argument("--model-base", type=str, default=None)
132
+ parser.add_argument("--cache-dir", type=str, default=None)
133
+ parser.add_argument("--file", nargs='+', type=str, required=True)
134
+ parser.add_argument("--device", type=str, default="cuda")
135
+ parser.add_argument("--conv-mode", type=str, default=None)
136
+ parser.add_argument("--temperature", type=float, default=0.2)
137
+ parser.add_argument("--max-new-tokens", type=int, default=512)
138
+ parser.add_argument("--load-8bit", action="store_true")
139
+ parser.add_argument("--load-4bit", action="store_true")
140
+ parser.add_argument("--debug", action="store_true")
141
+ args = parser.parse_args()
142
+ main(args)
controller.py ADDED
@@ -0,0 +1,298 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ A controller manages distributed workers.
3
+ It sends worker addresses to clients.
4
+ """
5
+ import argparse
6
+ import asyncio
7
+ import dataclasses
8
+ from enum import Enum, auto
9
+ import json
10
+ import logging
11
+ import time
12
+ from typing import List, Union
13
+ import threading
14
+
15
+ from fastapi import FastAPI, Request
16
+ from fastapi.responses import StreamingResponse
17
+ import numpy as np
18
+ import requests
19
+ import uvicorn
20
+
21
+ from videollava.constants import CONTROLLER_HEART_BEAT_EXPIRATION
22
+ from videollava.utils import build_logger, server_error_msg
23
+
24
+
25
+ logger = build_logger("controller", "controller.log")
26
+
27
+
28
+ class DispatchMethod(Enum):
29
+ LOTTERY = auto()
30
+ SHORTEST_QUEUE = auto()
31
+
32
+ @classmethod
33
+ def from_str(cls, name):
34
+ if name == "lottery":
35
+ return cls.LOTTERY
36
+ elif name == "shortest_queue":
37
+ return cls.SHORTEST_QUEUE
38
+ else:
39
+ raise ValueError(f"Invalid dispatch method")
40
+
41
+
42
+ @dataclasses.dataclass
43
+ class WorkerInfo:
44
+ model_names: List[str]
45
+ speed: int
46
+ queue_length: int
47
+ check_heart_beat: bool
48
+ last_heart_beat: str
49
+
50
+
51
+ def heart_beat_controller(controller):
52
+ while True:
53
+ time.sleep(CONTROLLER_HEART_BEAT_EXPIRATION)
54
+ controller.remove_stable_workers_by_expiration()
55
+
56
+
57
+ class Controller:
58
+ def __init__(self, dispatch_method: str):
59
+ # Dict[str -> WorkerInfo]
60
+ self.worker_info = {}
61
+ self.dispatch_method = DispatchMethod.from_str(dispatch_method)
62
+
63
+ self.heart_beat_thread = threading.Thread(
64
+ target=heart_beat_controller, args=(self,))
65
+ self.heart_beat_thread.start()
66
+
67
+ logger.info("Init controller")
68
+
69
+ def register_worker(self, worker_name: str, check_heart_beat: bool,
70
+ worker_status: dict):
71
+ if worker_name not in self.worker_info:
72
+ logger.info(f"Register a new worker: {worker_name}")
73
+ else:
74
+ logger.info(f"Register an existing worker: {worker_name}")
75
+
76
+ if not worker_status:
77
+ worker_status = self.get_worker_status(worker_name)
78
+ if not worker_status:
79
+ return False
80
+
81
+ self.worker_info[worker_name] = WorkerInfo(
82
+ worker_status["model_names"], worker_status["speed"], worker_status["queue_length"],
83
+ check_heart_beat, time.time())
84
+
85
+ logger.info(f"Register done: {worker_name}, {worker_status}")
86
+ return True
87
+
88
+ def get_worker_status(self, worker_name: str):
89
+ try:
90
+ r = requests.post(worker_name + "/worker_get_status", timeout=5)
91
+ except requests.exceptions.RequestException as e:
92
+ logger.error(f"Get status fails: {worker_name}, {e}")
93
+ return None
94
+
95
+ if r.status_code != 200:
96
+ logger.error(f"Get status fails: {worker_name}, {r}")
97
+ return None
98
+
99
+ return r.json()
100
+
101
+ def remove_worker(self, worker_name: str):
102
+ del self.worker_info[worker_name]
103
+
104
+ def refresh_all_workers(self):
105
+ old_info = dict(self.worker_info)
106
+ self.worker_info = {}
107
+
108
+ for w_name, w_info in old_info.items():
109
+ if not self.register_worker(w_name, w_info.check_heart_beat, None):
110
+ logger.info(f"Remove stale worker: {w_name}")
111
+
112
+ def list_models(self):
113
+ model_names = set()
114
+
115
+ for w_name, w_info in self.worker_info.items():
116
+ model_names.update(w_info.model_names)
117
+
118
+ return list(model_names)
119
+
120
+ def get_worker_address(self, model_name: str):
121
+ if self.dispatch_method == DispatchMethod.LOTTERY:
122
+ worker_names = []
123
+ worker_speeds = []
124
+ for w_name, w_info in self.worker_info.items():
125
+ if model_name in w_info.model_names:
126
+ worker_names.append(w_name)
127
+ worker_speeds.append(w_info.speed)
128
+ worker_speeds = np.array(worker_speeds, dtype=np.float32)
129
+ norm = np.sum(worker_speeds)
130
+ if norm < 1e-4:
131
+ return ""
132
+ worker_speeds = worker_speeds / norm
133
+ if True: # Directly return address
134
+ pt = np.random.choice(np.arange(len(worker_names)),
135
+ p=worker_speeds)
136
+ worker_name = worker_names[pt]
137
+ return worker_name
138
+
139
+ # Check status before returning
140
+ while True:
141
+ pt = np.random.choice(np.arange(len(worker_names)),
142
+ p=worker_speeds)
143
+ worker_name = worker_names[pt]
144
+
145
+ if self.get_worker_status(worker_name):
146
+ break
147
+ else:
148
+ self.remove_worker(worker_name)
149
+ worker_speeds[pt] = 0
150
+ norm = np.sum(worker_speeds)
151
+ if norm < 1e-4:
152
+ return ""
153
+ worker_speeds = worker_speeds / norm
154
+ continue
155
+ return worker_name
156
+ elif self.dispatch_method == DispatchMethod.SHORTEST_QUEUE:
157
+ worker_names = []
158
+ worker_qlen = []
159
+ for w_name, w_info in self.worker_info.items():
160
+ if model_name in w_info.model_names:
161
+ worker_names.append(w_name)
162
+ worker_qlen.append(w_info.queue_length / w_info.speed)
163
+ if len(worker_names) == 0:
164
+ return ""
165
+ min_index = np.argmin(worker_qlen)
166
+ w_name = worker_names[min_index]
167
+ self.worker_info[w_name].queue_length += 1
168
+ logger.info(f"names: {worker_names}, queue_lens: {worker_qlen}, ret: {w_name}")
169
+ return w_name
170
+ else:
171
+ raise ValueError(f"Invalid dispatch method: {self.dispatch_method}")
172
+
173
+ def receive_heart_beat(self, worker_name: str, queue_length: int):
174
+ if worker_name not in self.worker_info:
175
+ logger.info(f"Receive unknown heart beat. {worker_name}")
176
+ return False
177
+
178
+ self.worker_info[worker_name].queue_length = queue_length
179
+ self.worker_info[worker_name].last_heart_beat = time.time()
180
+ logger.info(f"Receive heart beat. {worker_name}")
181
+ return True
182
+
183
+ def remove_stable_workers_by_expiration(self):
184
+ expire = time.time() - CONTROLLER_HEART_BEAT_EXPIRATION
185
+ to_delete = []
186
+ for worker_name, w_info in self.worker_info.items():
187
+ if w_info.check_heart_beat and w_info.last_heart_beat < expire:
188
+ to_delete.append(worker_name)
189
+
190
+ for worker_name in to_delete:
191
+ self.remove_worker(worker_name)
192
+
193
+ def worker_api_generate_stream(self, params):
194
+ worker_addr = self.get_worker_address(params["model"])
195
+ if not worker_addr:
196
+ logger.info(f"no worker: {params['model']}")
197
+ ret = {
198
+ "text": server_error_msg,
199
+ "error_code": 2,
200
+ }
201
+ yield json.dumps(ret).encode() + b"\0"
202
+
203
+ try:
204
+ response = requests.post(worker_addr + "/worker_generate_stream",
205
+ json=params, stream=True, timeout=5)
206
+ for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
207
+ if chunk:
208
+ yield chunk + b"\0"
209
+ except requests.exceptions.RequestException as e:
210
+ logger.info(f"worker timeout: {worker_addr}")
211
+ ret = {
212
+ "text": server_error_msg,
213
+ "error_code": 3,
214
+ }
215
+ yield json.dumps(ret).encode() + b"\0"
216
+
217
+
218
+ # Let the controller act as a worker to achieve hierarchical
219
+ # management. This can be used to connect isolated sub networks.
220
+ def worker_api_get_status(self):
221
+ model_names = set()
222
+ speed = 0
223
+ queue_length = 0
224
+
225
+ for w_name in self.worker_info:
226
+ worker_status = self.get_worker_status(w_name)
227
+ if worker_status is not None:
228
+ model_names.update(worker_status["model_names"])
229
+ speed += worker_status["speed"]
230
+ queue_length += worker_status["queue_length"]
231
+
232
+ return {
233
+ "model_names": list(model_names),
234
+ "speed": speed,
235
+ "queue_length": queue_length,
236
+ }
237
+
238
+
239
+ app = FastAPI()
240
+
241
+
242
+ @app.post("/register_worker")
243
+ async def register_worker(request: Request):
244
+ data = await request.json()
245
+ controller.register_worker(
246
+ data["worker_name"], data["check_heart_beat"],
247
+ data.get("worker_status", None))
248
+
249
+
250
+ @app.post("/refresh_all_workers")
251
+ async def refresh_all_workers():
252
+ models = controller.refresh_all_workers()
253
+
254
+
255
+ @app.post("/list_models")
256
+ async def list_models():
257
+ models = controller.list_models()
258
+ return {"models": models}
259
+
260
+
261
+ @app.post("/get_worker_address")
262
+ async def get_worker_address(request: Request):
263
+ data = await request.json()
264
+ addr = controller.get_worker_address(data["model"])
265
+ return {"address": addr}
266
+
267
+
268
+ @app.post("/receive_heart_beat")
269
+ async def receive_heart_beat(request: Request):
270
+ data = await request.json()
271
+ exist = controller.receive_heart_beat(
272
+ data["worker_name"], data["queue_length"])
273
+ return {"exist": exist}
274
+
275
+
276
+ @app.post("/worker_generate_stream")
277
+ async def worker_api_generate_stream(request: Request):
278
+ params = await request.json()
279
+ generator = controller.worker_api_generate_stream(params)
280
+ return StreamingResponse(generator)
281
+
282
+
283
+ @app.post("/worker_get_status")
284
+ async def worker_api_get_status(request: Request):
285
+ return controller.worker_api_get_status()
286
+
287
+
288
+ if __name__ == "__main__":
289
+ parser = argparse.ArgumentParser()
290
+ parser.add_argument("--host", type=str, default="localhost")
291
+ parser.add_argument("--port", type=int, default=21001)
292
+ parser.add_argument("--dispatch-method", type=str, choices=[
293
+ "lottery", "shortest_queue"], default="shortest_queue")
294
+ args = parser.parse_args()
295
+ logger.info(f"args: {args}")
296
+
297
+ controller = Controller(args.dispatch_method)
298
+ uvicorn.run(app, host=args.host, port=args.port, log_level="info")
examples/desert.jpg ADDED
examples/extreme_ironing.jpg ADDED
examples/sample_demo_1.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fc6562a172eb9cb3c760a3c9992349c1faa2c793c112b7b9e50bd5cb17c2164d
3
+ size 1549315
examples/sample_demo_13.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:13384915331bf749fa31e2f4cbbd85ca90439b81b2390b4b512bd24b0dbd8bae
3
+ size 19356822
examples/sample_demo_22.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dcde24b3e67ff23aafd4b69854dbc7e2485eae65999c86c1beb9160d53fa2a11
3
+ size 1505931
examples/sample_demo_3.mp4 ADDED
Binary file (464 kB). View file
 
examples/sample_demo_8.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:618bb02562c769303b797ae3c29a66e15dcc0134d673747e8cf90582369c59a2
3
+ size 29771700
examples/sample_demo_9.mp4 ADDED
Binary file (632 kB). View file
 
examples/sample_img_13.png ADDED
examples/sample_img_22.png ADDED
examples/sample_img_8.png ADDED

Git LFS Details

  • SHA256: 4455fa94baf3f7dcbc9e547adb2ab98cbaf5671922d4fac297feed270eef4dd1
  • Pointer size: 132 Bytes
  • Size of remote file: 5.2 MB
examples/waterview.jpg ADDED
gradio_utils.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import TextStreamer
3
+
4
+ import os
5
+ import sys
6
+ sys.path.insert(0, os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "Evaluation"))
7
+ from llava.constants import IMAGE_TOKEN_INDEX
8
+ from llava.conversation import conv_templates, SeparatorStyle
9
+ from llava.mm_utils import get_model_name_from_path, KeywordsStoppingCriteria, tokenizer_image_token
10
+ from llava.model.builder import load_pretrained_model
11
+ from llava.utils import disable_torch_init
12
+ import shutil
13
+
14
+ # <a href="https://github.com/SNUMPR/vlm-rlaif.git" style="margin-right: 20px; text-decoration: none; display: flex; align-items: center;">
15
+ # <img src="https://z1.ax1x.com/2023/11/07/pil4sqH.png" alt="VLM-RLAIF" style="max-width: 120px; height: auto;">
16
+ # </a>
17
+
18
+ cur_dir = os.path.dirname(os.path.abspath(__file__))
19
+ title_markdown = ("""
20
+ <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
21
+ <img src="/dataset/dcahn/yura/vlm-rlaif/asset/Model.png" alt="VLM-RLAIF" style="max-width: 120px; height: auto;">
22
+ <img src="file:/dataset/dcahn/yura/vlm-rlaif/asset/Model.png" alt="VLM-RLAIF" style="max-width: 120px; height: auto;">
23
+ <div>
24
+ <h1 >VLM-RLAIF: Tuning Large Multimodal Models for Videos using Reinforcement Learning from AI Feedback (ACL 2024 Oral) </h1>
25
+ <h5 style="margin: 0;">If you like our project, please give us a star ✨ on Github for the latest update.</h5>
26
+ </div>
27
+ </div>
28
+
29
+
30
+ <div align="center">
31
+ <div style="display:flex; gap: 0.25rem;" align="center">
32
+ <a href='https://github.com/SNUMPR/vlm-rlaif'><img src='https://img.shields.io/badge/Github-Code-blue'></a>
33
+ <a href="https://arxiv.org/abs/2402.03746"><img src="https://img.shields.io/badge/Paper-arxiv-green"></a>
34
+ </div>
35
+ </div>
36
+ """)
37
+ # <a href='https://github.com/PKU-YuanGroup/Video-LLaVA/stargazers'><img src='https://img.shields.io/github/stars/PKU-YuanGroup/Video-LLaVA.svg?style=social'></a> # arXiv 버튼 옆에 추가?
38
+
39
+ block_css = """
40
+ #buttons button {
41
+ min-width: min(120px,100%);
42
+ }
43
+ """
44
+
45
+ tos_markdown = ("""
46
+ ### Terms of use
47
+ By using this service, users are required to agree to the following terms:
48
+ The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
49
+ Please click the "Flag" button if you get any inappropriate answer! We will collect those to keep improving our moderator.
50
+ For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.
51
+ """)
52
+
53
+ learn_more_markdown = ("""
54
+ ### License
55
+ The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA, [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation.
56
+ """)
57
+
58
+
59
+ class Chat:
60
+ def __init__(self, model_path, conv_mode, model_base=None, load_8bit=False, load_4bit=False, device='cuda', cache_dir=None):
61
+ # model_base = '/dataset/yura/vlm-rlaif/pretrained/final_models/Video_LLaVA_SFT'
62
+ # model_base='/dataset/yura/vlm-rlaif/pretrained/llava-v1.5-7b-lora_w_lora_16_sftv2_short1632_and_then_long_rank32_alpha32_lr1e4_allmodels/SFT_merged'
63
+ # model_path = '/dataset/yura/vlm-rlaif/pretrained/LLaVA_Video-RL-Fact-RLHF-7b_SFTv2_RM_13b_v1_40k-v1.5-336-lora-padding/checkpoint-180/adapter_model/lora_policy'
64
+
65
+ disable_torch_init()
66
+ model_name = get_model_name_from_path(model_path)
67
+ # self.tokenizer, self.model, image_processor, context_len = load_pretrained_model(model_path, model_base, model_name,
68
+ # load_8bit, load_4bit,
69
+ # device=device, cache_dir=cache_dir)
70
+ is_rlhf_checkpoint = 'rlhf' in model_path.lower()
71
+ print("MODEL_PATH", model_path)
72
+ print("RLHF Checkpoint: ", is_rlhf_checkpoint)
73
+ if not model_base or model_base == "none": model_base = None
74
+ if is_rlhf_checkpoint:
75
+ model_name = model_path
76
+ print("Config?", os.path.exists(os.path.join(model_path, "config.json")))
77
+ if not os.path.exists(os.path.join(model_path, "config.json")):
78
+ print("Copying")
79
+ shutil.copy(os.path.join(model_base, "config.json"), os.path.join(model_path, "config.json")) # Copy SFT model's config -> to RLHF folder
80
+ print("Listed", os.listdir(model_path))
81
+ print("Copying done")
82
+ # return(model_name)
83
+ # return
84
+ # self.tokenizer, self.model, image_processor, context_len = load_pretrained_model(model_path, model_base, model_name, load_8bit, load_4bit, device=device)
85
+ self.tokenizer, self.model, image_processor, context_len = load_pretrained_model(model_path, model_base, model_name, False, False, device=device)
86
+
87
+
88
+
89
+ self.image_processor = image_processor
90
+ # self.image_processor = processor['image']
91
+ # self.video_processor = processor['video']
92
+ self.conv_mode = conv_mode
93
+ self.conv = conv_templates[conv_mode].copy()
94
+ self.device = self.model.device
95
+ print(self.model)
96
+
97
+ def get_prompt(self, qs, state):
98
+ state.append_message(state.roles[0], qs)
99
+ state.append_message(state.roles[1], None)
100
+ return state
101
+
102
+ def _get_latest_prompt(self, state):
103
+ new_state = state.copy()
104
+ new_state.messages = state.messages[-2:]
105
+ return new_state
106
+
107
+ @torch.inference_mode()
108
+ # def generate(self, images_tensor: list, prompt: str, first_run: bool, state):
109
+ def generate(self, images_tensor: torch.Tensor, prompt: str, first_run: bool, state):
110
+ tokenizer, model, image_processor = self.tokenizer, self.model, self.image_processor
111
+
112
+ state = self.get_prompt(prompt, state)
113
+ # prompt = state.get_prompt()
114
+ latest_state = self._get_latest_prompt(state)
115
+ prompt = latest_state.get_prompt()
116
+
117
+ # print('\n\n\n')
118
+ # print(prompt)
119
+
120
+ input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.device)
121
+
122
+ temperature = 0.2
123
+
124
+ max_new_tokens = 1024
125
+
126
+ stop_str = self.conv.sep if self.conv.sep_style != SeparatorStyle.TWO else self.conv.sep2
127
+ keywords = [stop_str]
128
+ stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
129
+ streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
130
+ print(prompt, input_ids.shape, images_tensor.shape)
131
+ # print(images_tensor)
132
+ with torch.inference_mode():
133
+ output_ids = model.generate(
134
+ input_ids,
135
+ images=images_tensor,
136
+ do_sample=True,
137
+ temperature=temperature,
138
+ max_new_tokens=max_new_tokens,
139
+ streamer=streamer,
140
+ use_cache=True,
141
+ stopping_criteria=[stopping_criteria])
142
+
143
+ input_token_len = input_ids.shape[1]
144
+ n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
145
+ if n_diff_input_output > 0:
146
+ print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
147
+ outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
148
+ outputs = outputs.strip()
149
+ outputs = outputs.replace("QA_GT_caption_based_noisy", "")
150
+ if outputs.endswith(stop_str):
151
+ outputs = outputs[:-len(stop_str)]
152
+ outputs = outputs.strip()
153
+
154
+ print('response', outputs)
155
+ return outputs, state
gradio_web_server copy.py ADDED
@@ -0,0 +1,227 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import shutil
2
+ import subprocess
3
+
4
+ import torch
5
+ import gradio as gr
6
+ from fastapi import FastAPI
7
+ import os
8
+ os.environ['CUDA_VISIBLE_DEVICES'] = '0'
9
+ from PIL import Image
10
+ import tempfile
11
+ from decord import VideoReader, cpu
12
+ from transformers import TextStreamer
13
+ import argparse
14
+
15
+ import sys
16
+ sys.path.insert(0, os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "Evaluation"))
17
+ from llava.constants import DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
18
+ from llava.conversation import conv_templates, SeparatorStyle, Conversation
19
+ from llava.mm_utils import process_images
20
+
21
+ from Evaluation.infer_utils import load_video_into_frames
22
+ from serve.utils import load_image, image_ext, video_ext
23
+ from serve.gradio_utils import Chat, tos_markdown, learn_more_markdown, title_markdown, block_css
24
+
25
+
26
+
27
+ def save_image_to_local(image):
28
+ filename = os.path.join('temp', next(tempfile._get_candidate_names()) + '.jpg')
29
+ image = Image.open(image)
30
+ image.save(filename)
31
+ # print(filename)
32
+ return filename
33
+
34
+
35
+ def save_video_to_local(video_path):
36
+ filename = os.path.join('temp', next(tempfile._get_candidate_names()) + '.mp4')
37
+ shutil.copyfile(video_path, filename)
38
+ return filename
39
+
40
+
41
+ def generate(image1, video, textbox_in, first_run, state, state_, images_tensor, num_frames=50):
42
+ # ======= manually clear the conversation
43
+ # state = conv_templates[conv_mode].copy()
44
+ # state_ = conv_templates[conv_mode].copy()
45
+ # # =======
46
+ flag = 1
47
+ if not textbox_in:
48
+ if len(state_.messages) > 0:
49
+ textbox_in = state_.messages[-1][1]
50
+ state_.messages.pop(-1)
51
+ flag = 0
52
+ else:
53
+ return "Please enter instruction"
54
+ print("Video", video) # 잘 들어감
55
+ print("Images_tensor", images_tensor) # None
56
+ print("Textbox_IN", textbox_in) # 잘 들어감
57
+ print("State", state) # None
58
+ print("State_", state_) # None
59
+ # print(len(state_.messages))
60
+
61
+ video = video if video else "none"
62
+
63
+ if type(state) is not Conversation:
64
+ state = conv_templates[conv_mode].copy()
65
+ state_ = conv_templates[conv_mode].copy()
66
+ images_tensor = []
67
+
68
+ first_run = False if len(state.messages) > 0 else True
69
+
70
+ text_en_in = textbox_in.replace("picture", "image")
71
+
72
+ image_processor = handler.image_processor
73
+ assert os.path.exists(video)
74
+ if os.path.splitext(video)[-1].lower() in video_ext: # video extension
75
+ video_decode_backend = 'opencv'
76
+ elif os.path.splitext(os.listdir(video)[0]).lower() in image_ext: # frames folder
77
+ video_decode_backend = 'frames'
78
+ else:
79
+ raise ValueError(f'Support video of {video_ext} and frames of {image_ext}, but found {os.path.splitext(video)[-1].lower()}')
80
+
81
+ frames = load_video_into_frames(video, video_decode_backend=video_decode_backend, num_frames=num_frames)
82
+ tensor = process_images(frames, image_processor, argparse.Namespace(image_aspect_ratio='pad'))
83
+ # tensor = video_processor(video, return_tensors='pt')['pixel_values'][0]
84
+ # print(tensor.shape)
85
+ tensor = tensor.to(handler.model.device, dtype=dtype)
86
+ # images_tensor.append(tensor)
87
+ images_tensor = tensor
88
+
89
+ if handler.model.config.mm_use_im_start_end:
90
+ text_en_in = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + text_en_in
91
+ else:
92
+ text_en_in = DEFAULT_IMAGE_TOKEN + '\n' + text_en_in
93
+ text_en_out, state_ = handler.generate(images_tensor, text_en_in, first_run=first_run, state=state_)
94
+ state_.messages[-1] = (state_.roles[1], text_en_out)
95
+
96
+ text_en_out = text_en_out.split('#')[0]
97
+ textbox_out = text_en_out
98
+
99
+ show_images = ""
100
+ if os.path.exists(video):
101
+ filename = save_video_to_local(video)
102
+ show_images += f'<video controls playsinline width="500" style="display: inline-block;" src="./file={filename}"></video>'
103
+ if flag:
104
+ state.append_message(state.roles[0], textbox_in + "\n" + show_images)
105
+ state.append_message(state.roles[1], textbox_out)
106
+
107
+ return (state, state_, state.to_gradio_chatbot(), False, gr.update(value=None, interactive=True), images_tensor, gr.update(value=image1 if os.path.exists(video) else None, interactive=True), gr.update(value=video if os.path.exists(video) else None, interactive=True))
108
+
109
+
110
+ def regenerate(state, state_):
111
+ state.messages.pop(-1)
112
+ state_.messages.pop(-1)
113
+ if len(state.messages) > 0:
114
+ return state, state_, state.to_gradio_chatbot(), False
115
+ return (state, state_, state.to_gradio_chatbot(), True)
116
+
117
+
118
+ def clear_history(state, state_):
119
+ state = conv_templates[conv_mode].copy()
120
+ state_ = conv_templates[conv_mode].copy()
121
+ return (gr.update(value=None, interactive=True),
122
+ gr.update(value=None, interactive=True), \
123
+ gr.update(value=None, interactive=True), \
124
+ True, state, state_, state.to_gradio_chatbot(), [])
125
+
126
+
127
+ # ==== CHANGE HERE ====
128
+ # conv_mode = "llava_v1"
129
+ # model_path = 'LanguageBind/Video-LLaVA-7B'
130
+ # FIXME!!!
131
+
132
+ conv_mode = "llava_v0"
133
+ model_path = 'SNUMPR/vlm_rlaif_video_llava_7b'
134
+ # model_path = '/dataset/yura/vlm-rlaif/pretrained/final_models/Video_LLaVA_VLM_RLAIF_merged'
135
+ cache_dir = './cache_dir'
136
+ device = 'cuda'
137
+ # device = 'cpu'
138
+ load_8bit = True
139
+ load_4bit = False
140
+ dtype = torch.float16
141
+ # =============
142
+
143
+ handler = Chat(model_path, conv_mode=conv_mode, load_8bit=load_8bit, load_4bit=load_8bit, device=device, cache_dir=cache_dir)
144
+ # handler.model.to(dtype=dtype)
145
+ if not os.path.exists("temp"):
146
+ os.makedirs("temp")
147
+
148
+ app = FastAPI()
149
+
150
+
151
+ textbox = gr.Textbox(
152
+ show_label=False, placeholder="Enter text and press ENTER", container=False
153
+ )
154
+ with gr.Blocks(title='VLM-RLAIF', theme=gr.themes.Default(), css=block_css) as demo:
155
+ gr.Markdown(title_markdown)
156
+ state = gr.State()
157
+ state_ = gr.State()
158
+ first_run = gr.State()
159
+ images_tensor = gr.State()
160
+
161
+ image1 = gr.Image(label="Input Image", type="filepath")
162
+ with gr.Row():
163
+ with gr.Column(scale=3):
164
+ video = gr.Video(label="Input Video")
165
+
166
+ cur_dir = os.path.dirname(os.path.abspath(__file__))
167
+ gr.Examples(
168
+ examples=[
169
+ [
170
+ f"{cur_dir}/examples/sample_demo_1.mp4",
171
+ "Why is this video funny?",
172
+ ],
173
+ [
174
+ f"{cur_dir}/examples/sample_demo_3.mp4",
175
+ "Can you identify any safety hazards in this video?"
176
+ ],
177
+ [
178
+ f"{cur_dir}/examples/sample_demo_9.mp4",
179
+ "Describe the video.",
180
+ ],
181
+ [
182
+ f"{cur_dir}/examples/sample_demo_22.mp4",
183
+ "Describe the activity in the video.",
184
+ ],
185
+ ],
186
+ inputs=[video, textbox],
187
+ )
188
+
189
+ with gr.Column(scale=7):
190
+ chatbot = gr.Chatbot(label="VLM_RLAIF", bubble_full_width=True).style(height=750)
191
+ with gr.Row():
192
+ with gr.Column(scale=8):
193
+ textbox.render()
194
+ with gr.Column(scale=1, min_width=50):
195
+ submit_btn = gr.Button(
196
+ value="Send", variant="primary", interactive=True
197
+ )
198
+ with gr.Row(elem_id="buttons") as button_row:
199
+ upvote_btn = gr.Button(value="👍 Upvote", interactive=True)
200
+ downvote_btn = gr.Button(value="👎 Downvote", interactive=True)
201
+ flag_btn = gr.Button(value="⚠️ Flag", interactive=True)
202
+ # stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False)
203
+ regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=True)
204
+ # clear_btn = gr.Button(value="🗑️ Clear history", interactive=True)
205
+
206
+ gr.Markdown(tos_markdown)
207
+ gr.Markdown(learn_more_markdown)
208
+
209
+ submit_btn.click(generate, [image1, video, textbox, first_run, state, state_, images_tensor],
210
+ [state, state_, chatbot, first_run, textbox, images_tensor, image1, video])
211
+ # submit_btn.click(generate, [video, textbox, first_run, state, state_, images_tensor],
212
+ # [state, state_, chatbot, first_run, textbox, images_tensor, video])
213
+
214
+ regenerate_btn.click(regenerate, [state, state_], [state, state_, chatbot, first_run]).then(
215
+ generate, [image1, video, textbox, first_run, state, state_, images_tensor], [state, state_, chatbot, first_run, textbox, images_tensor, image1, video])
216
+ # generate, [video, textbox, first_run, state, state_, images_tensor], [state, state_, chatbot, first_run, textbox, images_tensor, video])
217
+
218
+ # clear_btn.click(clear_history, [state, state_],
219
+ # [image1, video, textbox, first_run, state, state_, chatbot, images_tensor])
220
+ # [video, textbox, first_run, state, state_, chatbot, images_tensor])
221
+
222
+ # app = gr.mount_gradio_app(app, demo, path="/")
223
+ # demo.launch(share=True)
224
+ demo.launch()
225
+
226
+ # uvicorn videollava.serve.gradio_web_server:app
227
+ # python -m videollava.serve.gradio_web_server
gradio_web_server.py ADDED
@@ -0,0 +1,234 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import shutil
2
+ import subprocess
3
+
4
+ import torch
5
+ import gradio as gr
6
+ from fastapi import FastAPI
7
+ import os
8
+ os.environ['CUDA_VISIBLE_DEVICES'] = '0'
9
+ from PIL import Image
10
+ import tempfile
11
+ from decord import VideoReader, cpu
12
+ from transformers import TextStreamer
13
+ import argparse
14
+
15
+ import sys
16
+ sys.path.insert(0, os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "Evaluation"))
17
+ from llava.constants import DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
18
+ from llava.conversation import conv_templates, SeparatorStyle, Conversation
19
+ from llava.mm_utils import process_images
20
+
21
+ from Evaluation.infer_utils import load_video_into_frames
22
+ from serve.utils import load_image, image_ext, video_ext
23
+ from serve.gradio_utils import Chat, tos_markdown, learn_more_markdown, title_markdown, block_css
24
+
25
+
26
+
27
+ def save_image_to_local(image):
28
+ filename = os.path.join('temp', next(tempfile._get_candidate_names()) + '.jpg')
29
+ image = Image.open(image)
30
+ image.save(filename)
31
+ # print(filename)
32
+ return filename
33
+
34
+
35
+ def save_video_to_local(video_path):
36
+ filename = os.path.join('temp', next(tempfile._get_candidate_names()) + '.mp4')
37
+ shutil.copyfile(video_path, filename)
38
+ return filename
39
+
40
+
41
+ def generate(video, textbox_in, first_run, state, state_, images_tensor, num_frames=50):
42
+ # ======= manually clear the conversation
43
+ # state = conv_templates[conv_mode].copy()
44
+ # state_ = conv_templates[conv_mode].copy()
45
+ # # =======
46
+ flag = 1
47
+ if not textbox_in:
48
+ if len(state_.messages) > 0:
49
+ textbox_in = state_.messages[-1][1]
50
+ state_.messages.pop(-1)
51
+ flag = 0
52
+ else:
53
+ return "Please enter instruction"
54
+ # else:
55
+ # if state is not None and state_ is not None:
56
+ # # reset conversations
57
+ # state.messages = []
58
+ # state_.messages = []
59
+
60
+ print("Video", video) # 잘 들어감
61
+ print("Images_tensor", images_tensor) # None
62
+ print("Textbox_IN", textbox_in) # 잘 들어감
63
+ print("State", state) # None
64
+ print("State_", state_) # None
65
+ # print(len(state_.messages))
66
+
67
+ video = video if video else "none"
68
+
69
+ if type(state) is not Conversation:
70
+ state = conv_templates[conv_mode].copy()
71
+ state_ = conv_templates[conv_mode].copy()
72
+ images_tensor = []
73
+
74
+ first_run = False if len(state.messages) > 0 else True
75
+
76
+ text_en_in = textbox_in.replace("picture", "image")
77
+
78
+ image_processor = handler.image_processor
79
+ assert os.path.exists(video)
80
+ if os.path.splitext(video)[-1].lower() in video_ext: # video extension
81
+ video_decode_backend = 'opencv'
82
+ elif os.path.splitext(os.listdir(video)[0]).lower() in image_ext: # frames folder
83
+ video_decode_backend = 'frames'
84
+ else:
85
+ raise ValueError(f'Support video of {video_ext} and frames of {image_ext}, but found {os.path.splitext(video)[-1].lower()}')
86
+
87
+ frames = load_video_into_frames(video, video_decode_backend=video_decode_backend, num_frames=num_frames)
88
+ tensor = process_images(frames, image_processor, argparse.Namespace(image_aspect_ratio='pad'))
89
+ # tensor = video_processor(video, return_tensors='pt')['pixel_values'][0]
90
+ # print(tensor.shape)
91
+ tensor = tensor.to(handler.model.device, dtype=dtype)
92
+ # images_tensor.append(tensor)
93
+ images_tensor = tensor
94
+
95
+ if handler.model.config.mm_use_im_start_end:
96
+ text_en_in = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + text_en_in
97
+ else:
98
+ text_en_in = DEFAULT_IMAGE_TOKEN + '\n' + text_en_in
99
+ text_en_out, state_ = handler.generate(images_tensor, text_en_in, first_run=first_run, state=state_)
100
+ state_.messages[-1] = (state_.roles[1], text_en_out)
101
+
102
+ text_en_out = text_en_out.split('#')[0]
103
+ textbox_out = text_en_out
104
+
105
+ show_images = ""
106
+ if os.path.exists(video):
107
+ filename = save_video_to_local(video)
108
+ show_images += f'<video controls playsinline width="500" style="display: inline-block;" src="./file={filename}"></video>'
109
+ if flag:
110
+ state.append_message(state.roles[0], textbox_in + "\n" + show_images)
111
+ state.append_message(state.roles[1], textbox_out)
112
+
113
+ return (state, state_, state.to_gradio_chatbot(), False, gr.update(value=None, interactive=True), images_tensor, \
114
+ gr.update(value=video if os.path.exists(video) else None, interactive=True))
115
+
116
+
117
+ def regenerate(state, state_):
118
+ state.messages.pop(-1)
119
+ state_.messages.pop(-1)
120
+ if len(state.messages) > 0:
121
+ return state, state_, state.to_gradio_chatbot(), False
122
+ return (state, state_, state.to_gradio_chatbot(), True)
123
+
124
+
125
+ def clear_history(state, state_):
126
+ state = conv_templates[conv_mode].copy()
127
+ state_ = conv_templates[conv_mode].copy()
128
+ return (gr.update(value=None, interactive=True),
129
+ gr.update(value=None, interactive=True), \
130
+ gr.update(value=None, interactive=True), \
131
+ True, state, state_, state.to_gradio_chatbot(), [])
132
+
133
+
134
+ # ==== CHANGE HERE ====
135
+ # conv_mode = "llava_v1"
136
+ # model_path = 'LanguageBind/Video-LLaVA-7B'
137
+ # FIXME!!!
138
+
139
+ conv_mode = "llava_v0"
140
+ model_path = 'SNUMPR/vlm_rlaif_video_llava_7b'
141
+ # model_path = '/dataset/yura/vlm-rlaif/pretrained/final_models/Video_LLaVA_VLM_RLAIF_merged'
142
+ cache_dir = './cache_dir'
143
+ device = 'cuda'
144
+ # device = 'cpu'
145
+ load_8bit = True
146
+ load_4bit = False
147
+ dtype = torch.float16
148
+ # =============
149
+
150
+ handler = Chat(model_path, conv_mode=conv_mode, load_8bit=load_8bit, load_4bit=load_8bit, device=device, cache_dir=cache_dir)
151
+ # handler.model.to(dtype=dtype)
152
+ if not os.path.exists("temp"):
153
+ os.makedirs("temp")
154
+
155
+ app = FastAPI()
156
+
157
+
158
+ textbox = gr.Textbox(
159
+ show_label=False, placeholder="Enter text and press ENTER", container=False
160
+ )
161
+ with gr.Blocks(title='VLM-RLAIF', theme=gr.themes.Default(), css=block_css) as demo:
162
+ gr.Markdown(title_markdown)
163
+ state = gr.State()
164
+ state_ = gr.State()
165
+ first_run = gr.State()
166
+ images_tensor = gr.State()
167
+
168
+ # image1 = gr.Image(label="Input Image", type="filepath")
169
+ with gr.Row():
170
+ with gr.Column(scale=3):
171
+ video = gr.Video(label="Input Video")
172
+
173
+ cur_dir = os.path.dirname(os.path.abspath(__file__))
174
+ gr.Examples(
175
+ examples=[
176
+ [
177
+ f"{cur_dir}/examples/sample_demo_1.mp4",
178
+ "Why is this video funny?",
179
+ ],
180
+ [
181
+ f"{cur_dir}/examples/sample_demo_3.mp4",
182
+ "Can you identify any safety hazards in this video?"
183
+ ],
184
+ [
185
+ f"{cur_dir}/examples/sample_demo_9.mp4",
186
+ "Describe the video.",
187
+ ],
188
+ [
189
+ f"{cur_dir}/examples/sample_demo_22.mp4",
190
+ "Describe the activity in the video.",
191
+ ],
192
+ ],
193
+ inputs=[video, textbox],
194
+ )
195
+
196
+ with gr.Column(scale=7):
197
+ chatbot = gr.Chatbot(label="VLM_RLAIF", bubble_full_width=True).style(height=750)
198
+ with gr.Row():
199
+ with gr.Column(scale=8):
200
+ textbox.render()
201
+ with gr.Column(scale=1, min_width=50):
202
+ submit_btn = gr.Button(
203
+ value="Send", variant="primary", interactive=True
204
+ )
205
+ with gr.Row(elem_id="buttons") as button_row:
206
+ upvote_btn = gr.Button(value="👍 Upvote", interactive=True)
207
+ downvote_btn = gr.Button(value="👎 Downvote", interactive=True)
208
+ flag_btn = gr.Button(value="⚠️ Flag", interactive=True)
209
+ # stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False)
210
+ regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=True)
211
+ # clear_btn = gr.Button(value="🗑️ Clear history", interactive=True)
212
+
213
+ gr.Markdown(tos_markdown)
214
+ gr.Markdown(learn_more_markdown)
215
+
216
+ submit_btn.click(generate, [video, textbox, first_run, state, state_, images_tensor],
217
+ [state, state_, chatbot, first_run, textbox, images_tensor, video])
218
+ # submit_btn.click(generate, [video, textbox, first_run, state, state_, images_tensor],
219
+ # [state, state_, chatbot, first_run, textbox, images_tensor, video])
220
+
221
+ regenerate_btn.click(regenerate, [state, state_], [state, state_, chatbot, first_run]).then(
222
+ generate, [video, textbox, first_run, state, state_, images_tensor], [state, state_, chatbot, first_run, textbox, images_tensor, video])
223
+ # generate, [video, textbox, first_run, state, state_, images_tensor], [state, state_, chatbot, first_run, textbox, images_tensor, video])
224
+
225
+ # clear_btn.click(clear_history, [state, state_],
226
+ # [image1, video, textbox, first_run, state, state_, chatbot, images_tensor])
227
+ # [video, textbox, first_run, state, state_, chatbot, images_tensor])
228
+
229
+ # app = gr.mount_gradio_app(app, demo, path="/")
230
+ demo.launch(share=True)
231
+ # demo.launch()
232
+
233
+ # uvicorn videollava.serve.gradio_web_server:app
234
+ # python -m videollava.serve.gradio_web_server
model_worker.py ADDED
@@ -0,0 +1,285 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ A model worker executes the model.
3
+ """
4
+ import argparse
5
+ import asyncio
6
+ import json
7
+ import time
8
+ import threading
9
+ import uuid
10
+
11
+ from fastapi import FastAPI, Request, BackgroundTasks
12
+ from fastapi.responses import StreamingResponse
13
+ import requests
14
+ import torch
15
+ import uvicorn
16
+ from functools import partial
17
+
18
+ from videollava.constants import WORKER_HEART_BEAT_INTERVAL
19
+ from videollava.utils import (build_logger, server_error_msg,
20
+ pretty_print_semaphore)
21
+ from videollava.model.builder import load_pretrained_model
22
+ from videollava.mm_utils import process_images, load_image_from_base64, tokenizer_image_token, KeywordsStoppingCriteria
23
+ from videollava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
24
+ from transformers import TextIteratorStreamer
25
+ from threading import Thread
26
+
27
+
28
+ GB = 1 << 30
29
+
30
+ worker_id = str(uuid.uuid4())[:6]
31
+ logger = build_logger("model_worker", f"model_worker_{worker_id}.log")
32
+ global_counter = 0
33
+
34
+ model_semaphore = None
35
+
36
+
37
+ def heart_beat_worker(controller):
38
+
39
+ while True:
40
+ time.sleep(WORKER_HEART_BEAT_INTERVAL)
41
+ controller.send_heart_beat()
42
+
43
+
44
+ class ModelWorker:
45
+ def __init__(self, controller_addr, worker_addr,
46
+ worker_id, no_register,
47
+ model_path, model_base, model_name,
48
+ load_8bit, load_4bit, device):
49
+ self.controller_addr = controller_addr
50
+ self.worker_addr = worker_addr
51
+ self.worker_id = worker_id
52
+ if model_path.endswith("/"):
53
+ model_path = model_path[:-1]
54
+ if model_name is None:
55
+ model_paths = model_path.split("/")
56
+ if model_paths[-1].startswith('checkpoint-'):
57
+ self.model_name = model_paths[-2] + "_" + model_paths[-1]
58
+ else:
59
+ self.model_name = model_paths[-1]
60
+ else:
61
+ self.model_name = model_name
62
+
63
+ self.device = device
64
+ logger.info(f"Loading the model {self.model_name} on worker {worker_id} ...")
65
+ self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
66
+ model_path, model_base, self.model_name, load_8bit, load_4bit, device=self.device)
67
+ self.is_multimodal = 'llava' in self.model_name.lower()
68
+
69
+ if not no_register:
70
+ self.register_to_controller()
71
+ self.heart_beat_thread = threading.Thread(
72
+ target=heart_beat_worker, args=(self,))
73
+ self.heart_beat_thread.start()
74
+
75
+ def register_to_controller(self):
76
+ logger.info("Register to controller")
77
+
78
+ url = self.controller_addr + "/register_worker"
79
+ data = {
80
+ "worker_name": self.worker_addr,
81
+ "check_heart_beat": True,
82
+ "worker_status": self.get_status()
83
+ }
84
+ r = requests.post(url, json=data)
85
+ assert r.status_code == 200
86
+
87
+ def send_heart_beat(self):
88
+ logger.info(f"Send heart beat. Models: {[self.model_name]}. "
89
+ f"Semaphore: {pretty_print_semaphore(model_semaphore)}. "
90
+ f"global_counter: {global_counter}")
91
+
92
+ url = self.controller_addr + "/receive_heart_beat"
93
+
94
+ while True:
95
+ try:
96
+ ret = requests.post(url, json={
97
+ "worker_name": self.worker_addr,
98
+ "queue_length": self.get_queue_length()}, timeout=5)
99
+ exist = ret.json()["exist"]
100
+ break
101
+ except requests.exceptions.RequestException as e:
102
+ logger.error(f"heart beat error: {e}")
103
+ time.sleep(5)
104
+
105
+ if not exist:
106
+ self.register_to_controller()
107
+
108
+ def get_queue_length(self):
109
+ if model_semaphore is None:
110
+ return 0
111
+ else:
112
+ return args.limit_model_concurrency - model_semaphore._value + (len(
113
+ model_semaphore._waiters) if model_semaphore._waiters is not None else 0)
114
+
115
+ def get_status(self):
116
+ return {
117
+ "model_names": [self.model_name],
118
+ "speed": 1,
119
+ "queue_length": self.get_queue_length(),
120
+ }
121
+
122
+ @torch.inference_mode()
123
+ def generate_stream(self, params):
124
+ tokenizer, model, image_processor = self.tokenizer, self.model, self.image_processor
125
+
126
+ prompt = params["prompt"]
127
+ ori_prompt = prompt
128
+ images = params.get("images", None)
129
+ num_image_tokens = 0
130
+ if images is not None and len(images) > 0 and self.is_multimodal:
131
+ if len(images) > 0:
132
+ if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN):
133
+ raise ValueError("Number of images does not match number of <image> tokens in prompt")
134
+
135
+ images = [load_image_from_base64(image) for image in images]
136
+ images = process_images(images, image_processor, model.config)
137
+
138
+ if type(images) is list:
139
+ images = [image.to(self.model.device, dtype=torch.float16) for image in images]
140
+ else:
141
+ images = images.to(self.model.device, dtype=torch.float16)
142
+
143
+ replace_token = DEFAULT_IMAGE_TOKEN
144
+ if getattr(self.model.config, 'mm_use_im_start_end', False):
145
+ replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
146
+ prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)
147
+
148
+ num_image_tokens = prompt.count(replace_token) * model.get_vision_tower().num_patches
149
+ else:
150
+ images = None
151
+ image_args = {"images": images}
152
+ else:
153
+ images = None
154
+ image_args = {}
155
+
156
+ temperature = float(params.get("temperature", 1.0))
157
+ top_p = float(params.get("top_p", 1.0))
158
+ max_context_length = getattr(model.config, 'max_position_embeddings', 2048)
159
+ max_new_tokens = min(int(params.get("max_new_tokens", 256)), 1024)
160
+ stop_str = params.get("stop", None)
161
+ do_sample = True if temperature > 0.001 else False
162
+
163
+ input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.device)
164
+ keywords = [stop_str]
165
+ stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
166
+ streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15)
167
+
168
+ max_new_tokens = min(max_new_tokens, max_context_length - input_ids.shape[-1] - num_image_tokens)
169
+
170
+ if max_new_tokens < 1:
171
+ yield json.dumps({"text": ori_prompt + "Exceeds max token length. Please start a new conversation, thanks.", "error_code": 0}).encode() + b"\0"
172
+ return
173
+
174
+ thread = Thread(target=model.generate, kwargs=dict(
175
+ inputs=input_ids,
176
+ do_sample=do_sample,
177
+ temperature=temperature,
178
+ top_p=top_p,
179
+ max_new_tokens=max_new_tokens,
180
+ streamer=streamer,
181
+ stopping_criteria=[stopping_criteria],
182
+ use_cache=True,
183
+ **image_args
184
+ ))
185
+ thread.start()
186
+
187
+ generated_text = ori_prompt
188
+ for new_text in streamer:
189
+ generated_text += new_text
190
+ if generated_text.endswith(stop_str):
191
+ generated_text = generated_text[:-len(stop_str)]
192
+ yield json.dumps({"text": generated_text, "error_code": 0}).encode() + b"\0"
193
+
194
+ def generate_stream_gate(self, params):
195
+ try:
196
+ for x in self.generate_stream(params):
197
+ yield x
198
+ except ValueError as e:
199
+ print("Caught ValueError:", e)
200
+ ret = {
201
+ "text": server_error_msg,
202
+ "error_code": 1,
203
+ }
204
+ yield json.dumps(ret).encode() + b"\0"
205
+ except torch.cuda.CudaError as e:
206
+ print("Caught torch.cuda.CudaError:", e)
207
+ ret = {
208
+ "text": server_error_msg,
209
+ "error_code": 1,
210
+ }
211
+ yield json.dumps(ret).encode() + b"\0"
212
+ except Exception as e:
213
+ print("Caught Unknown Error", e)
214
+ ret = {
215
+ "text": server_error_msg,
216
+ "error_code": 1,
217
+ }
218
+ yield json.dumps(ret).encode() + b"\0"
219
+
220
+
221
+ app = FastAPI()
222
+
223
+
224
+ def release_model_semaphore(fn=None):
225
+ model_semaphore.release()
226
+ if fn is not None:
227
+ fn()
228
+
229
+
230
+ @app.post("/worker_generate_stream")
231
+ async def generate_stream(request: Request):
232
+ global model_semaphore, global_counter
233
+ global_counter += 1
234
+ params = await request.json()
235
+
236
+ if model_semaphore is None:
237
+ model_semaphore = asyncio.Semaphore(args.limit_model_concurrency)
238
+ await model_semaphore.acquire()
239
+ worker.send_heart_beat()
240
+ generator = worker.generate_stream_gate(params)
241
+ background_tasks = BackgroundTasks()
242
+ background_tasks.add_task(partial(release_model_semaphore, fn=worker.send_heart_beat))
243
+ return StreamingResponse(generator, background=background_tasks)
244
+
245
+
246
+ @app.post("/worker_get_status")
247
+ async def get_status(request: Request):
248
+ return worker.get_status()
249
+
250
+
251
+ if __name__ == "__main__":
252
+ parser = argparse.ArgumentParser()
253
+ parser.add_argument("--host", type=str, default="localhost")
254
+ parser.add_argument("--port", type=int, default=21002)
255
+ parser.add_argument("--worker-address", type=str,
256
+ default="http://localhost:21002")
257
+ parser.add_argument("--controller-address", type=str,
258
+ default="http://localhost:21001")
259
+ parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
260
+ parser.add_argument("--model-base", type=str, default=None)
261
+ parser.add_argument("--model-name", type=str)
262
+ parser.add_argument("--device", type=str, default="cuda")
263
+ parser.add_argument("--multi-modal", action="store_true", help="Multimodal mode is automatically detected with model name, please make sure `llava` is included in the model path.")
264
+ parser.add_argument("--limit-model-concurrency", type=int, default=5)
265
+ parser.add_argument("--stream-interval", type=int, default=1)
266
+ parser.add_argument("--no-register", action="store_true")
267
+ parser.add_argument("--load-8bit", action="store_true")
268
+ parser.add_argument("--load-4bit", action="store_true")
269
+ args = parser.parse_args()
270
+ logger.info(f"args: {args}")
271
+
272
+ if args.multi_modal:
273
+ logger.warning("Multimodal mode is automatically detected with model name, please make sure `llava` is included in the model path.")
274
+
275
+ worker = ModelWorker(args.controller_address,
276
+ args.worker_address,
277
+ worker_id,
278
+ args.no_register,
279
+ args.model_path,
280
+ args.model_base,
281
+ args.model_name,
282
+ args.load_8bit,
283
+ args.load_4bit,
284
+ args.device)
285
+ uvicorn.run(app, host=args.host, port=args.port, log_level="info")
processing_utils.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import TextStreamer
3
+ import numpy as np
4
+ import os
5
+ import json
6
+ import torch
7
+
8
+ import numpy as np
9
+ import base64
10
+ from PIL import Image
11
+ from io import BytesIO
12
+ import matplotlib.pyplot as plt
13
+ from torchvision.transforms import Compose, Lambda, ToTensor
14
+ from torchvision import transforms
15
+ from transformers import ProcessorMixin, BatchEncoding
16
+ from transformers.image_processing_utils import BatchFeature
17
+ from pytorchvideo.data.encoded_video import EncodedVideo
18
+ from torchvision.transforms import Compose, Lambda, ToTensor
19
+ from torchvision.transforms._transforms_video import NormalizeVideo, RandomCropVideo, RandomHorizontalFlipVideo, CenterCropVideo
20
+ from pytorchvideo.transforms import ApplyTransformToKey, ShortSideScale, UniformTemporalSubsample
21
+
22
+
23
+ def load_frames(frames_dir):
24
+ results = []
25
+ frame_names = os.listdir(frames_dir)
26
+ frame_names.sort()
27
+ for frame_name in frame_names:
28
+ image_path = f"{frames_dir}/{frame_name}"
29
+ results.append(image_path)
30
+ return results
31
+
32
+ def sample_frames(frames, num_segments):
33
+ duration = len(frames)
34
+ frame_id_array = np.linspace(0, duration-1, num_segments, dtype=int)
35
+ frame_id_list = frame_id_array.tolist()
36
+
37
+ sampled_frames = []
38
+ for frame_idx in frame_id_list:
39
+ single_frame_path = frames[frame_idx]
40
+ sampled_frames.append(single_frame_path)
41
+ return sampled_frames
42
+
43
+
44
+ class VideoProcessor:
45
+ def __init__(self, image_transform):
46
+ self.image_transform = image_transform
47
+
48
+ def __call__(self, video_path, transform=None,
49
+ video_decode_backend='opencv',
50
+ clip_start_sec=0.0, clip_end_sec=None,
51
+ num_frames=50, **kwargs):
52
+ if transform is None: transform = self.image_transform
53
+ if video_decode_backend == 'pytorchvideo':
54
+ # decord pyav
55
+ video = EncodedVideo.from_path(video_path, decoder="decord", decode_audio=False)
56
+ duration = video.duration
57
+ start_sec = clip_start_sec # secs
58
+ end_sec = clip_end_sec if clip_end_sec is not None else duration # secs
59
+ video_data = video.get_clip(start_sec=start_sec, end_sec=end_sec)
60
+ video_outputs = transform(video_data)
61
+
62
+ elif video_decode_backend == 'decord':
63
+ import decord
64
+ from decord import VideoReader, cpu
65
+ decord.bridge.set_bridge('torch')
66
+ decord_vr = VideoReader(video_path, ctx=cpu(0))
67
+ ori_duration = len(decord_vr)
68
+ # frame_id_list = np.linspace(0, duration-1, num_frames, dtype=int)
69
+ fps_vid = decord_vr.get_avg_fps()
70
+ valid_duration = min(int(fps_vid * 10), ori_duration)
71
+ frame_id_list = np.linspace(0, valid_duration-1, num_frames, dtype=int)
72
+ video_data = decord_vr.get_batch(frame_id_list)
73
+ video_data = video_data.permute(3, 0, 1, 2) # (T, H, W, C) -> (C, T, H, W)
74
+ video_outputs = transform(video_data)
75
+
76
+ elif video_decode_backend == 'opencv':
77
+ import cv2
78
+ cv2_vr = cv2.VideoCapture(video_path)
79
+ duration = int(cv2_vr.get(cv2.CAP_PROP_FRAME_COUNT))
80
+ frame_id_list = np.linspace(0, duration-1, num_frames, dtype=int)
81
+
82
+ video_data = []
83
+ for frame_idx in frame_id_list:
84
+ cv2_vr.set(1, frame_idx)
85
+ _, frame = cv2_vr.read()
86
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
87
+ video_data.append(torch.from_numpy(frame).permute(2, 0, 1))
88
+ cv2_vr.release()
89
+ video_data = torch.stack(video_data, dim=1)
90
+ video_outputs = transform(video_data)
91
+
92
+ elif video_decode_backend == 'frames':
93
+ # FIXME does not input start and end clip timestamps. Require duration info to deal with.
94
+ frames = load_frames(video_path)
95
+ frames = sample_frames(frames, num_frames)
96
+ to_tensor = ToTensor()
97
+ video_data = torch.stack([to_tensor(_) for _ in frames]).permute(1, 0, 2, 3) # (T, C, H, W) -> (C, T, H, W)
98
+ else:
99
+ raise NameError('video_decode_backend should specify in (pytorchvideo, decord, opencv, frames)')
register_worker.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Manually register workers.
3
+
4
+ Usage:
5
+ python3 -m fastchat.serve.register_worker --controller http://localhost:21001 --worker-name http://localhost:21002
6
+ """
7
+
8
+ import argparse
9
+
10
+ import requests
11
+
12
+ if __name__ == "__main__":
13
+ parser = argparse.ArgumentParser()
14
+ parser.add_argument("--controller-address", type=str)
15
+ parser.add_argument("--worker-name", type=str)
16
+ parser.add_argument("--check-heart-beat", action="store_true")
17
+ args = parser.parse_args()
18
+
19
+ url = args.controller_address + "/register_worker"
20
+ data = {
21
+ "worker_name": args.worker_name,
22
+ "check_heart_beat": args.check_heart_beat,
23
+ "worker_status": None,
24
+ }
25
+ r = requests.post(url, json=data)
26
+ assert r.status_code == 200
test_message.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+
4
+ import requests
5
+
6
+ from videollava.conversation import default_conversation
7
+
8
+
9
+ def main():
10
+ if args.worker_address:
11
+ worker_addr = args.worker_address
12
+ else:
13
+ controller_addr = args.controller_address
14
+ ret = requests.post(controller_addr + "/refresh_all_workers")
15
+ ret = requests.post(controller_addr + "/list_models")
16
+ models = ret.json()["models"]
17
+ models.sort()
18
+ print(f"Models: {models}")
19
+
20
+ ret = requests.post(controller_addr + "/get_worker_address",
21
+ json={"model": args.model_name})
22
+ worker_addr = ret.json()["address"]
23
+ print(f"worker_addr: {worker_addr}")
24
+
25
+ if worker_addr == "":
26
+ return
27
+
28
+ conv = default_conversation.copy()
29
+ conv.append_message(conv.roles[0], args.message)
30
+ prompt = conv.get_prompt()
31
+
32
+ headers = {"User-Agent": "LLaVA Client"}
33
+ pload = {
34
+ "model": args.model_name,
35
+ "prompt": prompt,
36
+ "max_new_tokens": args.max_new_tokens,
37
+ "temperature": 0.7,
38
+ "stop": conv.sep,
39
+ }
40
+ response = requests.post(worker_addr + "/worker_generate_stream", headers=headers,
41
+ json=pload, stream=True)
42
+
43
+ print(prompt.replace(conv.sep, "\n"), end="")
44
+ for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"):
45
+ if chunk:
46
+ data = json.loads(chunk.decode("utf-8"))
47
+ output = data["text"].split(conv.sep)[-1]
48
+ print(output, end="\r")
49
+ print("")
50
+
51
+
52
+ if __name__ == "__main__":
53
+ parser = argparse.ArgumentParser()
54
+ parser.add_argument("--controller-address", type=str, default="http://localhost:21001")
55
+ parser.add_argument("--worker-address", type=str)
56
+ parser.add_argument("--model-name", type=str, default="facebook/opt-350m")
57
+ parser.add_argument("--max-new-tokens", type=int, default=32)
58
+ parser.add_argument("--message", type=str, default=
59
+ "Tell me a story with more than 1000 words.")
60
+ args = parser.parse_args()
61
+
62
+ main()
utils.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from io import BytesIO
2
+
3
+ import requests
4
+ from PIL import Image
5
+
6
+
7
+ def load_image(image_file):
8
+ if image_file.startswith('http://') or image_file.startswith('https://'):
9
+ response = requests.get(image_file)
10
+ image = Image.open(BytesIO(response.content)).convert('RGB')
11
+ else:
12
+ image = Image.open(image_file).convert('RGB')
13
+ return image
14
+
15
+ video_ext = ['.mp4', '.mov', '.mkv', '.avi']
16
+ image_ext = ['.jpg', '.png', '.bmp', '.jpeg']