saicharan1234 commited on
Commit
8d7815c
·
verified ·
1 Parent(s): 3129acd

Create main.py

Browse files
Files changed (1) hide show
  1. main.py +216 -0
main.py ADDED
@@ -0,0 +1,216 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import random
4
+ import io
5
+ from PIL import Image
6
+ import numpy as np
7
+ import torch
8
+ import torch.backends.cudnn as cudnn
9
+ import gradio as gr
10
+
11
+ from minigpt4.common.config import Config
12
+ from minigpt4.common.dist_utils import get_rank
13
+ from minigpt4.common.registry import registry
14
+ from minigpt4.conversation.conversation import Chat, CONV_VISION
15
+ from fastapi import FastAPI, HTTPException, File, UploadFile,Form
16
+ from fastapi.responses import RedirectResponse
17
+ from fastapi.middleware.cors import CORSMiddleware
18
+ from pydantic import BaseModel
19
+ from PIL import Image
20
+ import io
21
+ import uvicorn
22
+ # imports modules for registration
23
+ from minigpt4.datasets.builders import *
24
+ from minigpt4.models import *
25
+ from minigpt4.processors import *
26
+ from minigpt4.runners import *
27
+ from minigpt4.tasks import *
28
+
29
+
30
+ def parse_args():
31
+ parser = argparse.ArgumentParser(description="Demo")
32
+ parser.add_argument("--cfg-path", type=str, default='eval_configs/minigpt4.yaml',
33
+ help="path to configuration file.")
34
+ parser.add_argument(
35
+ "--options",
36
+ nargs="+",
37
+ help="override some settings in the used config, the key-value pair "
38
+ "in xxx=yyy format will be merged into config file (deprecate), "
39
+ "change to --cfg-options instead.",
40
+ )
41
+ args = parser.parse_args()
42
+ return args
43
+
44
+
45
+ def setup_seeds(config):
46
+ seed = config.run_cfg.seed + get_rank()
47
+
48
+ random.seed(seed)
49
+ np.random.seed(seed)
50
+ torch.manual_seed(seed)
51
+
52
+ cudnn.benchmark = False
53
+ cudnn.deterministic = True
54
+
55
+
56
+ # ========================================
57
+ # Model Initialization
58
+ # ========================================
59
+
60
+ SHARED_UI_WARNING = f'''### [NOTE] It is possible that you are waiting in a lengthy queue.
61
+
62
+ You can duplicate and use it with a paid private GPU.
63
+
64
+ <a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/Vision-CAIR/minigpt4?duplicate=true"><img style="margin-top:0;margin-bottom:0" src="https://huggingface.co/datasets/huggingface/badges/raw/main/duplicate-this-space-xl-dark.svg" alt="Duplicate Space"></a>
65
+
66
+ Alternatively, you can also use the demo on our [project page](https://minigpt-4.github.io).
67
+ '''
68
+
69
+ print('Initializing Chat')
70
+ cfg = Config(parse_args())
71
+
72
+ model_config = cfg.model_cfg
73
+ model_cls = registry.get_model_class(model_config.arch)
74
+ model = model_cls.from_config(model_config).to('cuda:0')
75
+
76
+ vis_processor_cfg = cfg.datasets_cfg.cc_align.vis_processor.train
77
+ vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
78
+ chat = Chat(model, vis_processor)
79
+ print('Initialization Finished')
80
+
81
+ # ========================================
82
+ # Gradio Setting
83
+ # ========================================
84
+
85
+ app = FastAPI()
86
+ app.add_middleware(
87
+ CORSMiddleware,
88
+ allow_origins=["*"], # Replace "*" with your frontend domain
89
+ allow_credentials=True,
90
+ allow_methods=["GET", "POST"],
91
+ allow_headers=["*"],
92
+ )
93
+
94
+
95
+ class Item(BaseModel):
96
+ gr_img: UploadFile = File(..., description="Image file")
97
+ text_input: str = None
98
+
99
+ @app.get("/")
100
+ async def root():
101
+ return RedirectResponse(url="/docs")
102
+
103
+ @app.post("/process/")
104
+ async def process_item(
105
+ file: UploadFile = File(...),
106
+ prompt: str = Form(...),
107
+ ):
108
+ chat_state = CONV_VISION.copy()
109
+ img_list = []
110
+ chatbot=[]
111
+ pil_image = Image.open(io.BytesIO(await file.read()))
112
+ chat.upload_img(pil_image, chat_state, img_list)
113
+ chat.ask(prompt, chat_state)
114
+ chatbot = chatbot + [[prompt, None]]
115
+ llm_message = chat.answer(conv=chat_state, img_list=img_list, max_new_tokens=300, num_beams=1, temperature=temperature,
116
+ max_length=2000)[0]
117
+ chatbot[-1][1] = llm_message
118
+ return chatbot, chat_state, img_list
119
+
120
+
121
+ # if __name__ == "__main__":
122
+ # # Run the FastAPI app with Uvicorn
123
+ # uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True)
124
+
125
+
126
+
127
+
128
+ # def gradio_reset(chat_state, img_list):
129
+ # if chat_state is not None:
130
+ # chat_state.messages = []
131
+ # if img_list is not None:
132
+ # img_list = []
133
+ # return None, gr.update(value=None, interactive=True), gr.update(placeholder='Please upload your image first',
134
+ # interactive=False), gr.update(
135
+ # value="Upload & Start Chat", interactive=True), chat_state, img_list
136
+ #
137
+ #
138
+ # def upload_img(gr_img, text_input, chat_state):
139
+ # if gr_img is None:
140
+ # return None, None, gr.update(interactive=True), chat_state, None
141
+ # chat_state = CONV_VISION.copy()
142
+ # img_list = []
143
+ # llm_message = chat.upload_img(gr_img, chat_state, img_list)
144
+ # return gr.update(interactive=False), gr.update(interactive=True, placeholder='Type and press Enter'), gr.update(
145
+ # value="Start Chatting", interactive=False), chat_state, img_list
146
+ #
147
+ #
148
+ # def gradio_ask(user_message, chatbot, chat_state):
149
+ # if len(user_message) == 0:
150
+ # return gr.update(interactive=True, placeholder='Input should not be empty!'), chatbot, chat_state
151
+ # chat.ask(user_message, chat_state)
152
+ # chatbot = chatbot + [[user_message, None]]
153
+ # return '', chatbot, chat_state
154
+ #
155
+ #
156
+ # def gradio_answer(chatbot, chat_state, img_list, num_beams, temperature):
157
+ # llm_message = \
158
+ # chat.answer(conv=chat_state, img_list=img_list, max_new_tokens=300, num_beams=1, temperature=temperature,
159
+ # max_length=2000)[0]
160
+ # chatbot[-1][1] = llm_message
161
+ # return chatbot, chat_state, img_list
162
+ #
163
+ #
164
+ # title = """<h1 align="center">Demo of MiniGPT-4</h1>"""
165
+ # description = """<h3>This is the demo of MiniGPT-4. Upload your images and start chatting!</h3>"""
166
+ # article = """<div style='display:flex; gap: 0.25rem; '><a href='https://minigpt-4.github.io'><img src='https://img.shields.io/badge/Project-Page-Green'></a><a href='https://github.com/Vision-CAIR/MiniGPT-4'><img src='https://img.shields.io/badge/Github-Code-blue'></a><a href='https://github.com/TsuTikgiau/blip2-llm/blob/release_prepare/MiniGPT_4.pdf'><img src='https://img.shields.io/badge/Paper-PDF-red'></a></div>
167
+ # """
168
+ #
169
+ # # TODO show examples below
170
+ #
171
+ # with gr.Blocks() as demo:
172
+ # gr.Markdown(title)
173
+ # gr.Markdown(SHARED_UI_WARNING)
174
+ # gr.Markdown(description)
175
+ # gr.Markdown(article)
176
+ #
177
+ # with gr.Row():
178
+ # with gr.Column(scale=0.5):
179
+ # image = gr.Image(type="pil")
180
+ # upload_button = gr.Button(value="Upload & Start Chat", interactive=True, variant="primary")
181
+ # clear = gr.Button("Restart")
182
+ #
183
+ # num_beams = gr.Slider(
184
+ # minimum=1,
185
+ # maximum=5,
186
+ # value=1,
187
+ # step=1,
188
+ # interactive=True,
189
+ # label="beam search numbers)",
190
+ # )
191
+ #
192
+ # temperature = gr.Slider(
193
+ # minimum=0.1,
194
+ # maximum=2.0,
195
+ # value=1.0,
196
+ # step=0.1,
197
+ # interactive=True,
198
+ # label="Temperature",
199
+ # )
200
+ #
201
+ # with gr.Column():
202
+ # chat_state = gr.State()
203
+ # img_list = gr.State()
204
+ # chatbot = gr.Chatbot(label='MiniGPT-4')
205
+ # text_input = gr.Textbox(label='User', placeholder='Please upload your image first', interactive=False)
206
+ #
207
+ # upload_button.click(upload_img, [image, text_input, chat_state],
208
+ # [image, text_input, upload_button, chat_state, img_list])
209
+ #
210
+ # text_input.submit(gradio_ask, [text_input, chatbot, chat_state], [text_input, chatbot, chat_state]).then(
211
+ # gradio_answer, [chatbot, chat_state, img_list, num_beams, temperature], [chatbot, chat_state, img_list]
212
+ # )
213
+ # clear.click(gradio_reset, [chat_state, img_list], [chatbot, image, text_input, upload_button, chat_state, img_list],
214
+ # queue=False)
215
+ #
216
+ # demo.launch(enable_queue=True)