File size: 14,472 Bytes
c3b0d93
 
 
 
4f234f9
 
 
 
 
6a4537e
4f234f9
 
 
 
 
 
 
 
 
 
6a4537e
4f234f9
 
 
 
 
 
 
 
6a4537e
4f234f9
6a4537e
4f234f9
 
 
 
 
 
a30b8db
4f234f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6a4537e
 
4f234f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6a4537e
4f234f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6a4537e
4f234f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
import os
import subprocess
import sys

# Pin gradio to 3.43.2 before it is imported below. Presumably this is
# because the UI relies on the 3.x API (e.g. gr.Button.update, removed in
# Gradio 4) — confirm before bumping.
# pip is invoked through the *current* interpreter (sys.executable -m pip)
# so the pinned version lands in the same environment that will run
# `import gradio`; a bare `pip` found on PATH may belong to another Python.
# Like the previous os.system calls, failures are not raised here.
subprocess.run([sys.executable, "-m", "pip", "uninstall", "-y", "gradio"])
subprocess.run([sys.executable, "-m", "pip", "install", "gradio==3.43.2"])

import argparse
import sys
import os
# import cv2
import glob
import gradio as gr
import numpy as np
import json
from PIL import Image
from tqdm import tqdm
from pathlib import Path
import uvicorn
from fastapi.staticfiles import StaticFiles
import random
import time
import requests

from fastapi import FastAPI
from conversation import SeparatorStyle, conv_templates, default_conversation
from utils import (
    build_logger,
    moderation_msg,
    server_error_msg,
)
from config import cur_conv

logger = build_logger("gradio_web_server", "gradio_web_server.log")

# JSON headers sent with every backend request made by http_bot.
headers = {"Content-Type": "application/json"}

# create a FastAPI app
app = FastAPI()
# Directory exposed under the /static URL prefix, configurable through the
# STATIC_DIR environment variable.
# SECURITY NOTE(review): the default "/" publishes the *entire* server
# filesystem (any readable file, e.g. /static/etc/passwd). Point STATIC_DIR
# at the narrowest directory that actually holds the media to serve.
# A previously used value was:
# /data/Multimodal-RAG/GenerativeAIExamples/ChatQnA/langchain/redis/chips-making-deals/
static_dir = Path(os.environ.get("STATIC_DIR", "/"))

# mount FastAPI StaticFiles server
app.mount("/static", StaticFiles(directory=static_dir), name="static")

# Shades common to both hues; only c600, c700 and c950 differ between them.
_shared_shades = {
    "c50": "#eff6ff",
    "c100": "#dbeafe",
    "c200": "#bfdbfe",
    "c300": "#93c5fd",
    "c400": "#60a5fa",
    "c500": "#0054ae",
    "c800": "#1e40af",
    "c900": "#1e3a8a",
}
# Theme variable overrides (dark-mode fills/text, accent borders, block
# border widths, primary button colors).
_theme_settings = {
    "body_background_fill_dark": "*primary_950",
    "body_text_color_dark": "*neutral_300",
    "border_color_accent": "*primary_700",
    "border_color_accent_dark": "*neutral_800",
    "block_background_fill_dark": "*primary_950",
    "block_border_width": "2px",
    "block_border_width_dark": "2px",
    "button_primary_background_fill_dark": "*primary_500",
    "button_primary_border_color_dark": "*primary_500",
}
theme = gr.themes.Base(
    primary_hue=gr.themes.Color(
        **_shared_shades, c600="#00377c", c700="#00377c", c950="#0a0c2b"
    ),
    secondary_hue=gr.themes.Color(
        **_shared_shades, c600="#0054ae", c700="#0054ae", c950="#1d3660"
    ),
).set(**_theme_settings)

# Extra stylesheet passed to gr.Blocks(css=...): declares the "IntelOne"
# @font-face from a .ttf referenced through Gradio's "file/" URL prefix
# (assumes the server runs from the directory containing assets/ — confirm).
css='''
    @font-face {
        font-family: IntelOne;
        src: url("file/assets/intelone-bodytext-font-family-regular.ttf");
    }
'''

# Disabled logo cell, kept for reference:
##     <td style="border-bottom:0"><img src="file/assets/DCAI_logo.png" height="300" width="300"></td>
# Page header rendered by gr.HTML: Intel Labs logo, the two-line title in the
# IntelOne font (declared in `css`, with Georgia/sans-serif fallbacks), and the
# Gaudi, Xeon and IDC7 images, all loaded via Gradio's "file/" prefix.
html_title = '''
<table>
<tr style="height:150px">
    <td style="border-bottom:0"><img src="file/assets/intel-labs.png" height="100" width="100"></td>
    <td style="border-bottom:0; vertical-align:bottom"> 
    <p style="font-size:xx-large;font-family:IntelOne, Georgia, sans-serif;color: white;">
     Cognitive AI:
     <br>
     Multimodal RAG on Videos
    </p>
    </td>
    <td style="border-bottom:0;"><img src="file/assets/gaudi.png" width="100" height="100"></td>
    <td style="border-bottom:0;"><img src="file/assets/xeon.png" width="100" height="100"></td>
    <td style="border-bottom:0;"><img src="file/assets/IDC7.png" width="400" height="350"></td> 
</tr>
</table>

'''

# Module-wide switch for print_debug output.
debug = False


def print_debug(t):
    """Print ``t`` to stdout, but only while the module-level ``debug`` flag is truthy."""
    if not debug:
        return
    print(t)

# https://stackoverflow.com/a/57781047
# Resizes an image while maintaining its aspect ratio
# def maintain_aspect_ratio_resize(image, width=None, height=None, inter=cv2.INTER_AREA):
#     # Grab the image size and initialize dimensions
#     dim = None
#     (h, w) = image.shape[:2]

#     # Return original image if no need to resize
#     if width is None and height is None:
#         return image

#     # We are resizing height if width is none
#     if width is None:
#         # Calculate the ratio of the height and construct the dimensions
#         r = height / float(h)
#         dim = (int(w * r), height)
#     # We are resizing width if height is none
#     else:
#         # Calculate the ratio of the width and construct the dimensions
#         r = width / float(w)
#         dim = (width, int(h * r))

#     # Return the resized image
#     return cv2.resize(image, dim, interpolation=inter)

def time_to_frame(time, fps):
    '''
        Convert a time in seconds into a frame number.

        Computes ``int(time * fps - 1)`` and clamps the result at 0, so
        time 0 (or anything before the first frame) maps to frame 0. Without
        the clamp, time 0 gave -1, which as a Python index silently selects
        the *last* frame.

        Note: the parameter name ``time`` shadows the ``time`` module inside
        this function; it is kept for keyword-argument compatibility.
    '''
    return max(int(time * fps - 1), 0)

def str2time(strtime):
    '''
        Convert a timestamp string into a number of seconds (float).

        Accepts ``HH:MM:SS``, ``MM:SS`` or ``SS``, optionally wrapped in
        double quotes. Seconds may be fractional, using ``.`` or ``,``
        (SRT style) as the decimal separator.

        Raises:
            ValueError: if there are more than three ``:``-separated fields
                or a field is not numeric.
    '''
    strtime = strtime.strip('"')
    fields = strtime.replace(',', '.').split(':')
    if len(fields) > 3:
        raise ValueError(f"invalid time string: {strtime!r}")
    # Left-pad the missing hour/minute fields with zeros so "MM:SS" and "SS"
    # reuse the same arithmetic as the full "HH:MM:SS" form.
    hrs, mins, seconds = [0.0] * (3 - len(fields)) + [float(c) for c in fields]

    total_seconds = hrs * 60**2 + mins * 60 + seconds

    return total_seconds

def get_iframe(video_path: str, start: int = -1, end: int = -1):
    """Return an HTML5 ``<video>`` player snippet (540x310) for ``video_path``.

    Despite the name, this emits a ``<video>`` element, not an ``<iframe>``.
    The path is HTML-escaped, so characters such as ``"``, ``&`` or ``<``
    cannot break out of the ``src`` attribute.

    Args:
        video_path: URL or path placed in the ``src`` attribute.
        start: accepted for call compatibility; currently unused.
        end: accepted for call compatibility; currently unused.
    """
    import html

    src = html.escape(video_path, quote=True)
    return f"""<video controls="controls" preload="metadata" src="{src}" width="540" height="310"></video>"""

#TODO
# def place(galleries, evt: gr.SelectData):
#     print(evt.value)
#     start_time = evt.value.split('||')[0].strip()
#     print(start_time)
#     # sub_video_id = evt.value.split('|')[-1]
#     if start_time in start_time_index_map.keys():
#         sub_video_id = start_time_index_map[start_time]
#     else:
#         sub_video_id = 0
#     path_to_sub_video = f"/static/video_embeddings/mp4.keynotes23/sub-videos/keynotes23_split{sub_video_id}.mp4"
#     # return evt.value
#     return get_iframe(path_to_sub_video)
    
# def process(text_query):
#     tmp_dir = os.environ.get('VID_CACHE_DIR', os.environ.get('TMPDIR', './video_embeddings'))
#     frames, transcripts = run_query(text_query, path=tmp_dir)
#     # return video_file_path, [(image, caption) for image, caption in zip(frame_paths, transcripts)]
#     return [(frame, caption) for frame, caption in zip(frames, transcripts)], ""

description = "This Space lets you engage with multimodal RAG on a video through a chat box."

# Pre-built button updates that the event handlers return for the
# clear-history button: leave it as is, enable it, or disable it.
no_change_btn, enable_btn, disable_btn = (
    gr.Button.update(),
    gr.Button.update(interactive=True),
    gr.Button.update(interactive=False),
)

# textbox = gr.Textbox(
#         show_label=False, placeholder="Enter text and press ENTER", container=False
# )



def clear_history(request: gr.Request):
    """Reset the conversation to a fresh copy of ``cur_conv``.

    Returns five values: the new state, its chatbot rendering, an empty
    query box, ``None`` for the video player, and ``disable_btn`` for the
    clear-history button.
    """
    logger.info(f"clear_history. ip: {request.client.host}")
    fresh_state = cur_conv.copy()
    chat_view = fresh_state.to_gradio_chatbot()
    return fresh_state, chat_view, "", None, disable_btn

def add_text(state, text, request: gr.Request):
    """Append the user's query (and an empty assistant slot) to ``state``.

    Returns four values: state, chatbot rendering, the cleared query box,
    and one clear-button update. This matches the
    ``[state, chatbot, textbox] + btn_list`` outputs wired to this handler.
    An empty query sets ``state.skip_next`` so the chained http_bot call
    does nothing.
    """
    # gr.Dropdown may deliver None when nothing is selected or typed; treat
    # that as empty instead of crashing on len(None).
    text = text or ""
    logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}")
    if not text:
        state.skip_next = True
        # Exactly four values for the four outputs. The previous extra `None`
        # shifted the button update out of place.
        return (state, state.to_gradio_chatbot(), "", no_change_btn)

    text = text[:1536]  # Hard cut-off

    state.append_message(state.roles[0], text)
    # Placeholder for the assistant reply, filled in by http_bot.
    state.append_message(state.roles[1], None)
    state.skip_next = False
    return (state, state.to_gradio_chatbot(), "", disable_btn)

def _build_rag_request(state, prompt):
    """Choose the backend endpoint and JSON payload for the next request.

    If ``state`` holds no image yet, this is the first query: it goes to
    ``/v1/rag/chat`` so the backend runs retrieval. Otherwise it goes to
    ``/v1/rag/multi_turn_chat`` with the first stored image, skipping
    retrieval.

    Returns:
        ``(url, payload)`` tuple.
    """
    all_images = state.get_images(return_pil=False)
    if len(all_images) == 0:
        # First query: the backend must perform RAG retrieval.
        return worker_addr + "/v1/rag/chat", {"query": prompt}
    # Subsequent queries: reuse the already-retrieved frame.
    return worker_addr + "/v1/rag/multi_turn_chat", {
        "prompt": prompt,
        "path-to-image": all_images[0],
    }


def http_bot(
    state, request: gr.Request
):
    """Stream the assistant's reply to the last user message into ``state``.

    This is a generator used as a Gradio event handler. Every yield is
    ``(state, chatbot, video, clear_btn)``, matching the
    ``[state, chatbot, video] + btn_list`` outputs it is wired to.

    While text arrives, a trailing "▌" cursor is shown; it is removed at
    the end. On a network error, an HTTP error status, or an undecodable
    chunk, the reply is replaced with ``server_error_msg`` and the video
    output is cleared.
    """
    logger.info(f"http_bot. ip: {request.client.host}")

    if state.skip_next:
        # add_text rejected the input: re-emit the current state unchanged.
        path_to_sub_videos = state.get_path_to_subvideos()
        yield (state, state.to_gradio_chatbot(), path_to_sub_videos) + (no_change_btn,)
        return

    if len(state.messages) == state.offset + 2:
        # First round of conversation: rebuild on top of the cur_conv template.
        new_state = cur_conv.copy()
        new_state.append_message(new_state.roles[0], state.messages[-2][1])
        new_state.append_message(new_state.roles[1], None)
        state = new_state

    prompt = state.get_prompt()
    url, pload = _build_rag_request(state, prompt)
    logger.info(f"==== request ====\n{pload}")
    logger.info(f"==== url request ====\n{url}")

    # Show a typing cursor until the answer has fully streamed in.
    state.messages[-1][-1] = "▌"
    path_to_sub_videos = state.get_path_to_subvideos()
    yield (state, state.to_gradio_chatbot(), path_to_sub_videos) + (disable_btn,)

    try:
        # `with` releases the streamed connection on every exit path,
        # including an error or the generator being closed at a yield.
        with requests.post(url, headers=headers, json=pload, timeout=100, stream=True) as response:
            # An error status carries no usable JSON stream; surface it as a
            # RequestException (HTTPError).
            response.raise_for_status()
            # The backend separates JSON objects with NUL bytes.
            for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
                if not chunk:
                    continue
                res = json.loads(chunk.decode())
                # Retrieval metadata: only the first value received is kept.
                if state.path_to_img is None and 'path-to-image' in res:
                    state.path_to_img = res['path-to-image']
                if state.video_title is None and 'title' in res:
                    state.video_title = res['title']
                if 'answer' in res:
                    # Drop the cursor, append the new text, re-append the cursor.
                    state.messages[-1][-1] = state.messages[-1][-1][:-1] + res["answer"] + "▌"
                    path_to_sub_videos = state.get_path_to_subvideos()
                    yield (state, state.to_gradio_chatbot(), path_to_sub_videos) + (disable_btn,)
                time.sleep(0.03)
    except (requests.exceptions.RequestException, json.JSONDecodeError, UnicodeDecodeError) as e:
        logger.error(f"http_bot: request to {url} failed: {e}")
        state.messages[-1][-1] = server_error_msg
        yield (state, state.to_gradio_chatbot(), None) + (enable_btn,)
        return

    # Remove the typing cursor from the finished answer.
    state.messages[-1][-1] = state.messages[-1][-1][:-1]
    path_to_sub_videos = state.get_path_to_subvideos()
    logger.info(path_to_sub_videos)
    yield (state, state.to_gradio_chatbot(), path_to_sub_videos) + (enable_btn,)

    logger.info(f"{state.messages[-1][-1]}")

# Sample queries offered in the "Query" dropdown. Users can also type their
# own, because the Dropdown is created with allow_custom_value=True.
# Some entries (e.g. the one starting with "and ...") appear to be follow-up
# questions meant to be asked after the preceding one.
dropdown_list = [
    "What did Intel present at Nasdaq?",
    "From Chips Act Funding Announcement, by which year is Intel committed to Net Zero gas emissions?",
    "What percentage of renewable energy is Intel planning to use?",
    "a band playing music",
    "Which US state is Silicon Desert referred to?",
    "and which US state is Silicon Forest referred to?",
    "How do trigate fins work?",
    "What is the advantage of trigate over planar transistors?",
    "What are key objectives of transistor design?",
    "How fast can transistors switch?",
]

# Build the UI. Component creation order inside these context managers
# determines the layout: the video player in the left column; the chatbot,
# query dropdown, Send button and clear-history button in the right column.
with gr.Blocks(theme=theme, css=css) as demo:
    # gr.Markdown(description)
    # NOTE(review): the initial state is built from default_conversation,
    # while clear_history and http_bot's first round use cur_conv. Confirm
    # the two templates are meant to differ.
    state = gr.State(default_conversation.copy())
    gr.HTML(value=html_title)
    with gr.Row():
        with gr.Column(scale=4):
            video = gr.Video(height=512, width=512, elem_id="video" )
        with gr.Column(scale=7):
            chatbot = gr.Chatbot(
                        elem_id="chatbot", label="Multimodal RAG Chatbot", height=450
                )
            with gr.Row():
                with gr.Column(scale=8):
                    # textbox.render()
                    # A Dropdown doubles as the query box: it lists the sample
                    # queries and also accepts free text (allow_custom_value).
                    textbox = gr.Dropdown(
                        dropdown_list,
                        allow_custom_value=True,
                        # show_label=False,
                        # container=False,
                        label="Query",
                        info="Enter your query here or choose a sample from the dropdown list!"
                    )
                with gr.Column(scale=1, min_width=50):
                    submit_btn = gr.Button(
                        value="Send", variant="primary", interactive=True
                    )
            with gr.Row(elem_id="buttons") as button_row:
                # Starts disabled; the handlers toggle it through btn_list.
                clear_btn = gr.Button(value="🗑️  Clear history", interactive=False)
    # Register listeners
    # Buttons whose state the handlers update. Each handler returns one
    # button-update value per entry, after its other outputs.
    btn_list = [clear_btn]

    clear_btn.click(
        clear_history, None, [state, chatbot, textbox, video] + btn_list
    )

    # textbox.submit(
    #     add_text,
    #     [state, textbox],
    #     [state, chatbot, textbox,] + btn_list,
    # ).then(
    #     http_bot,
    #     [state, ],
    #     [state, chatbot, video] + btn_list,
    # )
    
    # Send: add_text appends the query, then http_bot streams the answer.
    # NOTE(review): submit_btn is not in btn_list, so it stays clickable
    # while an answer is still streaming.
    submit_btn.click(
        add_text,
        [state, textbox],
        [state, chatbot, textbox,] + btn_list,
    ).then(
        http_bot,
        [state, ],
        [state, chatbot, video] + btn_list,
    )

    print_debug('Beginning')
    # btn.click(fn=process, 
    #     inputs=[text_query],
    #     # outputs=[video_player, gallery],
    #     outputs=[gallery, html],
              
    # )
    # gallery.select(place, [gallery], [html])
# Enable Gradio's request queue, then mount the UI at "/" of the FastAPI app
# that also serves /static. The queue is presumably needed so the generator
# handler http_bot can stream partial results; confirm this against the
# pinned Gradio version. Keep the queue call before the mount unless
# verified otherwise.
demo.queue()
app = gr.mount_gradio_app(app, demo, path='/')
# NOTE(review): `share` and `enable_queue` are only referenced by the
# commented-out launch code below.
share = False
enable_queue = True
# try:
#     demo.queue(concurrency_count=3)#, enable_queue=False)
#     demo.launch(enable_queue=enable_queue, share=share, server_port=17808, server_name='0.0.0.0')
# #BATCH -w isl-gpu48 
# except:
#     demo.launch(enable_queue=False, share=share, server_port=17808, server_name='0.0.0.0')

# serve the app
if __name__ == "__main__":
    # Command-line options, in the order they are registered.
    # Note: --concurrency-count and --share are parsed (and logged) but are
    # not consumed anywhere else in this script.
    _cli_options = (
        ("--host", {"type": str, "default": "0.0.0.0"}),
        ("--port", {"type": int, "default": 7899}),
        ("--concurrency-count", {"type": int, "default": 20}),
        ("--share", {"action": "store_true"}),
        ("--worker-address", {"type": str, "default": "198.175.88.247"}),
        ("--worker-port", {"type": int, "default": 7899}),
    )
    parser = argparse.ArgumentParser()
    for flag, options in _cli_options:
        parser.add_argument(flag, **options)

    args = parser.parse_args()
    logger.info(f"args: {args}")
    # This `if` body runs at module scope, so the assignment below creates the
    # module-global `worker_addr` that http_bot reads when building its URL.
    worker_addr = "http://{}:{}".format(args.worker_address, args.worker_port)
    uvicorn.run(app, host=args.host, port=args.port)
    
# for i in examples:
#     print(f'Processing {i[0]}')
#     results = process(*i)
# print(f'{len(results[0])} results returned')