Upload 10 files
- app.py +119 -0
- outputs/logs/Jan14_09-50-48.txt +53 -0
- outputs/logs/Jan14_11-23-28.txt +47 -0
- outputs/logs/Jan14_11-28-57.txt +47 -0
- outputs/uploaded/0.jpg +0 -0
- process_food.py +25 -0
- requirements.txt +158 -0
- utils.py +25 -0
- xtuner_config/.ipynb_checkpoints/internvl_v2_internlm2_2b_lora_finetune_food-checkpoint.py +184 -0
- xtuner_config/internvl_v2_internlm2_2b_lora_finetune_food.py +184 -0
app.py
ADDED
@@ -0,0 +1,119 @@
import os
import random
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import gradio as gr

from utils import load_json, init_logger
from demo import ConversationalAgent, CustomTheme

FOOD_EXAMPLES = "./demo/food_for_demo.json"
# MODEL_PATH = "/root/share/new_models/OpenGVLab/InternVL2-2B"
MODEL_PATH = "./lr35_ep10"
OUTPUT_PATH = "./outputs"


def setup_seeds():
    seed = 42

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    cudnn.benchmark = False
    cudnn.deterministic = True


def main():
    setup_seeds()
    # logging
    init_logger(OUTPUT_PATH)
    # food examples
    food_examples = load_json(FOOD_EXAMPLES)

    agent = ConversationalAgent(model_path=MODEL_PATH,
                                outputs_dir=OUTPUT_PATH)

    theme = CustomTheme()

    titles = [
        """<center><B><font face="Comic Sans MS" size=10>书生大模型实战营</font></B></center>""",  ## Kalam:wght@700
        """<center><B><font face="Courier" size=5>「进阶岛」InternVL 多模态模型部署微调实践</font></B></center>"""
    ]

    language = """Language: 中文 and English"""
    with gr.Blocks(theme) as demo_chatbot:
        for title in titles:
            gr.Markdown(title)
        # gr.Markdown(article)
        gr.Markdown(language)

        with gr.Row():
            with gr.Column(scale=3):
                start_btn = gr.Button("Start Chat", variant="primary", interactive=True)
                clear_btn = gr.Button("Clear Context", interactive=False)
                image = gr.Image(type="pil", interactive=False)
                upload_btn = gr.Button("🖼️ Upload Image", interactive=False)

                with gr.Accordion("Generation Settings"):
                    top_p = gr.Slider(minimum=0, maximum=1, step=0.1,
                                      value=0.8,
                                      interactive=True,
                                      label='top-p value',
                                      visible=True)

                    temperature = gr.Slider(minimum=0, maximum=1.5, step=0.1,
                                            value=0.8,
                                            interactive=True,
                                            label='temperature',
                                            visible=True)

            with gr.Column(scale=7):
                chat_state = gr.State()
                chatbot = gr.Chatbot(label='InternVL2', height=800, avatar_images=((os.path.join(os.path.dirname(__file__), 'demo/user.png')), (os.path.join(os.path.dirname(__file__), "demo/bot.png"))))
                text_input = gr.Textbox(label='User', placeholder="Please click the <Start Chat> button to start chat!", interactive=False)
                gr.Markdown("### 输入示例")

                def on_text_change(text):
                    return gr.update(interactive=True)
                text_input.change(fn=on_text_change, inputs=text_input, outputs=text_input)
                gr.Examples(
                    examples=[["图片中的食物通常属于哪个菜系?"],
                              ["如果让你简单形容一下品尝图片中的食物的滋味,你会描述它"],
                              ["去哪个地方游玩时应该品尝当地的特色美食图片中的食物?"],
                              ["食用图片中的食物时,一般它上菜或摆盘时的特点是?"]],
                    inputs=[text_input]
                )

        with gr.Row():
            gr.Markdown("### 食物快捷栏")
        with gr.Row():
            example_xinjiang_food = gr.Examples(examples=food_examples["新疆菜"], inputs=image, label="新疆菜")
            example_sichuan_food = gr.Examples(examples=food_examples["川菜(四川,重庆)"], inputs=image, label="川菜(四川,重庆)")
            example_xibei_food = gr.Examples(examples=food_examples["西北菜 (陕西,甘肃等地)"], inputs=image, label="西北菜 (陕西,甘肃等地)")
        with gr.Row():
            example_guizhou_food = gr.Examples(examples=food_examples["黔菜 (贵州)"], inputs=image, label="黔菜 (贵州)")
            example_jiangsu_food = gr.Examples(examples=food_examples["苏菜(江苏)"], inputs=image, label="苏菜(江苏)")
            example_guangdong_food = gr.Examples(examples=food_examples["粤菜(广东等地)"], inputs=image, label="粤菜(广东等地)")
        with gr.Row():
            example_hunan_food = gr.Examples(examples=food_examples["湘菜(湖南)"], inputs=image, label="湘菜(湖南)")
            example_fujian_food = gr.Examples(examples=food_examples["闽菜(福建)"], inputs=image, label="闽菜(福建)")
            example_zhejiang_food = gr.Examples(examples=food_examples["浙菜(浙江)"], inputs=image, label="浙菜(浙江)")
        with gr.Row():
            example_dongbei_food = gr.Examples(examples=food_examples["东北菜 (黑龙江等地)"], inputs=image, label="东北菜 (黑龙江等地)")

        start_btn.click(agent.start_chat, [chat_state], [text_input, start_btn, clear_btn, image, upload_btn, chat_state])
        clear_btn.click(agent.restart_chat, [chat_state], [chatbot, text_input, start_btn, clear_btn, image, upload_btn, chat_state], queue=False)
        upload_btn.click(agent.upload_image, [image, chatbot, chat_state], [image, chatbot, chat_state])
        text_input.submit(
            agent.respond,
            inputs=[text_input, image, chatbot, top_p, temperature, chat_state],
            outputs=[text_input, image, chatbot, chat_state]
        )

    # enable request queuing before launching; calling queue() after launch() never takes effect
    demo_chatbot.queue()
    demo_chatbot.launch(share=True, server_name="127.0.0.1", server_port=1096, allowed_paths=['./'])


if __name__ == "__main__":
    main()
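Note: app.py imports ConversationalAgent and CustomTheme from a local demo package and loads weights from ./lr35_ep10, neither of which appears among the 10 uploaded files. The sketch below is therefore only a hypothetical stub of the interface app.py expects; all class and method names are inferred from the call sites above, not taken from the real demo package.

# Hypothetical stub of the `demo` package interface used by app.py (assumption, not part of this upload).
import gradio as gr

class CustomTheme(gr.themes.Base):
    """Placeholder theme; the real one presumably customises fonts and colours."""

class ConversationalAgent:
    def __init__(self, model_path: str, outputs_dir: str):
        self.model_path = model_path    # e.g. "./lr35_ep10" (fine-tuned InternVL2 weights)
        self.outputs_dir = outputs_dir  # e.g. "./outputs" (logs and uploaded images)

    def start_chat(self, chat_state):
        """Return updates for: text_input, start_btn, clear_btn, image, upload_btn, chat_state."""

    def restart_chat(self, chat_state):
        """Return updates for: chatbot, text_input, start_btn, clear_btn, image, upload_btn, chat_state."""

    def upload_image(self, image, chatbot, chat_state):
        """Return updates for: image, chatbot, chat_state."""

    def respond(self, text, image, chatbot, top_p, temperature, chat_state):
        """Run inference and return updates for: text_input, image, chatbot, chat_state."""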
outputs/logs/Jan14_09-50-48.txt
ADDED
@@ -0,0 +1,53 @@
2025-01-14 09:50:53,500 configuration_internvl_chat.py[line:68] INFO || vision_select_layer: -1
2025-01-14 09:50:53,502 configuration_internvl_chat.py[line:69] INFO || ps_version: v2
2025-01-14 09:50:53,502 configuration_internvl_chat.py[line:70] INFO || min_dynamic_patch: 1
2025-01-14 09:50:53,502 configuration_internvl_chat.py[line:71] INFO || max_dynamic_patch: 12
2025-01-14 09:50:53,537 configuration_internvl_chat.py[line:68] INFO || vision_select_layer: -1
2025-01-14 09:50:53,537 configuration_internvl_chat.py[line:69] INFO || ps_version: v2
2025-01-14 09:50:53,537 configuration_internvl_chat.py[line:70] INFO || min_dynamic_patch: 1
2025-01-14 09:50:53,537 configuration_internvl_chat.py[line:71] INFO || max_dynamic_patch: 12
2025-01-14 09:50:54,641 configuration_internvl_chat.py[line:68] INFO || vision_select_layer: -1
2025-01-14 09:50:54,642 configuration_internvl_chat.py[line:69] INFO || ps_version: v2
2025-01-14 09:50:54,642 configuration_internvl_chat.py[line:70] INFO || min_dynamic_patch: 1
2025-01-14 09:50:54,642 configuration_internvl_chat.py[line:71] INFO || max_dynamic_patch: 12
2025-01-14 09:51:27,929 modeling_internvl_chat.py[line:54] INFO || num_image_token: 256
2025-01-14 09:51:27,930 modeling_internvl_chat.py[line:55] INFO || ps_version: v2
2025-01-14 09:53:54,644 configuration_internvl_chat.py[line:68] INFO || vision_select_layer: -1
2025-01-14 09:53:54,645 configuration_internvl_chat.py[line:69] INFO || ps_version: v2
2025-01-14 09:53:54,646 configuration_internvl_chat.py[line:70] INFO || min_dynamic_patch: 1
2025-01-14 09:53:54,646 configuration_internvl_chat.py[line:71] INFO || max_dynamic_patch: 12
2025-01-14 09:53:55,190 configuration_internvl_chat.py[line:68] INFO || vision_select_layer: -1
2025-01-14 09:53:55,190 configuration_internvl_chat.py[line:69] INFO || ps_version: v2
2025-01-14 09:53:55,190 configuration_internvl_chat.py[line:70] INFO || min_dynamic_patch: 1
2025-01-14 09:53:55,190 configuration_internvl_chat.py[line:71] INFO || max_dynamic_patch: 12
2025-01-14 09:53:55,199 configuration_internvl_chat.py[line:68] INFO || vision_select_layer: -1
2025-01-14 09:53:55,199 configuration_internvl_chat.py[line:69] INFO || ps_version: v2
2025-01-14 09:53:55,199 configuration_internvl_chat.py[line:70] INFO || min_dynamic_patch: 1
2025-01-14 09:53:55,199 configuration_internvl_chat.py[line:71] INFO || max_dynamic_patch: 12
2025-01-14 09:53:55,220 configuration_internvl_chat.py[line:68] INFO || vision_select_layer: -1
2025-01-14 09:53:55,220 configuration_internvl_chat.py[line:69] INFO || ps_version: v2
2025-01-14 09:53:55,220 configuration_internvl_chat.py[line:70] INFO || min_dynamic_patch: 1
2025-01-14 09:53:55,220 configuration_internvl_chat.py[line:71] INFO || max_dynamic_patch: 12
2025-01-14 09:53:55,233 configuration_internvl_chat.py[line:68] INFO || vision_select_layer: -1
2025-01-14 09:53:55,233 configuration_internvl_chat.py[line:69] INFO || ps_version: v2
2025-01-14 09:53:55,233 configuration_internvl_chat.py[line:70] INFO || min_dynamic_patch: 1
2025-01-14 09:53:55,233 configuration_internvl_chat.py[line:71] INFO || max_dynamic_patch: 12
2025-01-14 09:53:55,241 configuration_internvl_chat.py[line:68] INFO || vision_select_layer: -1
2025-01-14 09:53:55,242 configuration_internvl_chat.py[line:69] INFO || ps_version: v2
2025-01-14 09:53:55,242 configuration_internvl_chat.py[line:70] INFO || min_dynamic_patch: 1
2025-01-14 09:53:55,242 configuration_internvl_chat.py[line:71] INFO || max_dynamic_patch: 12
2025-01-14 09:53:55,251 configuration_internvl_chat.py[line:68] INFO || vision_select_layer: -1
2025-01-14 09:53:55,251 configuration_internvl_chat.py[line:69] INFO || ps_version: v2
2025-01-14 09:53:55,251 configuration_internvl_chat.py[line:70] INFO || min_dynamic_patch: 1
2025-01-14 09:53:55,251 configuration_internvl_chat.py[line:71] INFO || max_dynamic_patch: 12
2025-01-14 09:54:03,863 _client.py[line:1025] INFO || HTTP Request: GET https://checkip.amazonaws.com/ "HTTP/1.1 200 "
2025-01-14 09:54:03,950 _client.py[line:1025] INFO || HTTP Request: GET https://api.gradio.app/pkg-version "HTTP/1.1 200 OK"
2025-01-14 09:54:11,354 _client.py[line:1025] INFO || HTTP Request: GET http://127.0.0.1:1096/startup-events "HTTP/1.1 200 OK"
2025-01-14 09:54:11,700 _client.py[line:1025] INFO || HTTP Request: HEAD http://127.0.0.1:1096/ "HTTP/1.1 200 OK"
2025-01-14 09:54:12,493 _client.py[line:1025] INFO || HTTP Request: GET https://api.gradio.app/v2/tunnel-request "HTTP/1.1 200 OK"
2025-01-14 10:07:49,759 agent.py[line:30] INFO || ==============================Start Chat==============================
2025-01-14 10:08:03,453 agent.py[line:84] INFO || Time: Jan14-10:08:03
2025-01-14 10:08:03,453 agent.py[line:85] INFO || User: 去哪个地方游玩时应该品尝当地的特色美食图片中的食物?
2025-01-14 10:08:03,706 agent.py[line:91] INFO || image save path: /root/InternVL2-Tutorial/outputs/uploaded/0.jpg
2025-01-14 10:08:21,808 agent.py[line:103] INFO || generated text =
广东,图中的菜是鸡蛋肠粉
outputs/logs/Jan14_11-23-28.txt
ADDED
@@ -0,0 +1,47 @@
2025-01-14 11:23:29,817 configuration_internvl_chat.py[line:68] INFO || vision_select_layer: -1
2025-01-14 11:23:29,817 configuration_internvl_chat.py[line:69] INFO || ps_version: v2
2025-01-14 11:23:29,817 configuration_internvl_chat.py[line:70] INFO || min_dynamic_patch: 1
2025-01-14 11:23:29,817 configuration_internvl_chat.py[line:71] INFO || max_dynamic_patch: 12
2025-01-14 11:23:29,838 configuration_internvl_chat.py[line:68] INFO || vision_select_layer: -1
2025-01-14 11:23:29,838 configuration_internvl_chat.py[line:69] INFO || ps_version: v2
2025-01-14 11:23:29,838 configuration_internvl_chat.py[line:70] INFO || min_dynamic_patch: 1
2025-01-14 11:23:29,838 configuration_internvl_chat.py[line:71] INFO || max_dynamic_patch: 12
2025-01-14 11:23:30,712 configuration_internvl_chat.py[line:68] INFO || vision_select_layer: -1
2025-01-14 11:23:30,713 configuration_internvl_chat.py[line:69] INFO || ps_version: v2
2025-01-14 11:23:30,713 configuration_internvl_chat.py[line:70] INFO || min_dynamic_patch: 1
2025-01-14 11:23:30,713 configuration_internvl_chat.py[line:71] INFO || max_dynamic_patch: 12
2025-01-14 11:23:38,556 modeling_internvl_chat.py[line:54] INFO || num_image_token: 256
2025-01-14 11:23:38,556 modeling_internvl_chat.py[line:55] INFO || ps_version: v2
2025-01-14 11:23:43,589 configuration_internvl_chat.py[line:68] INFO || vision_select_layer: -1
2025-01-14 11:23:43,593 configuration_internvl_chat.py[line:69] INFO || ps_version: v2
2025-01-14 11:23:43,593 configuration_internvl_chat.py[line:70] INFO || min_dynamic_patch: 1
2025-01-14 11:23:43,593 configuration_internvl_chat.py[line:71] INFO || max_dynamic_patch: 12
2025-01-14 11:23:43,778 configuration_internvl_chat.py[line:68] INFO || vision_select_layer: -1
2025-01-14 11:23:43,778 configuration_internvl_chat.py[line:69] INFO || ps_version: v2
2025-01-14 11:23:43,778 configuration_internvl_chat.py[line:70] INFO || min_dynamic_patch: 1
2025-01-14 11:23:43,778 configuration_internvl_chat.py[line:71] INFO || max_dynamic_patch: 12
2025-01-14 11:23:43,787 configuration_internvl_chat.py[line:68] INFO || vision_select_layer: -1
2025-01-14 11:23:43,787 configuration_internvl_chat.py[line:69] INFO || ps_version: v2
2025-01-14 11:23:43,787 configuration_internvl_chat.py[line:70] INFO || min_dynamic_patch: 1
2025-01-14 11:23:43,787 configuration_internvl_chat.py[line:71] INFO || max_dynamic_patch: 12
2025-01-14 11:23:43,798 configuration_internvl_chat.py[line:68] INFO || vision_select_layer: -1
2025-01-14 11:23:43,798 configuration_internvl_chat.py[line:69] INFO || ps_version: v2
2025-01-14 11:23:43,798 configuration_internvl_chat.py[line:70] INFO || min_dynamic_patch: 1
2025-01-14 11:23:43,798 configuration_internvl_chat.py[line:71] INFO || max_dynamic_patch: 12
2025-01-14 11:23:43,817 configuration_internvl_chat.py[line:68] INFO || vision_select_layer: -1
2025-01-14 11:23:43,817 configuration_internvl_chat.py[line:69] INFO || ps_version: v2
2025-01-14 11:23:43,818 configuration_internvl_chat.py[line:70] INFO || min_dynamic_patch: 1
2025-01-14 11:23:43,818 configuration_internvl_chat.py[line:71] INFO || max_dynamic_patch: 12
2025-01-14 11:23:43,834 configuration_internvl_chat.py[line:68] INFO || vision_select_layer: -1
2025-01-14 11:23:43,834 configuration_internvl_chat.py[line:69] INFO || ps_version: v2
2025-01-14 11:23:43,834 configuration_internvl_chat.py[line:70] INFO || min_dynamic_patch: 1
2025-01-14 11:23:43,834 configuration_internvl_chat.py[line:71] INFO || max_dynamic_patch: 12
2025-01-14 11:23:43,846 configuration_internvl_chat.py[line:68] INFO || vision_select_layer: -1
2025-01-14 11:23:43,846 configuration_internvl_chat.py[line:69] INFO || ps_version: v2
2025-01-14 11:23:43,846 configuration_internvl_chat.py[line:70] INFO || min_dynamic_patch: 1
2025-01-14 11:23:43,846 configuration_internvl_chat.py[line:71] INFO || max_dynamic_patch: 12
2025-01-14 11:23:51,555 _client.py[line:1025] INFO || HTTP Request: GET https://checkip.amazonaws.com/ "HTTP/1.1 200 "
2025-01-14 11:23:52,073 _client.py[line:1025] INFO || HTTP Request: GET https://api.gradio.app/pkg-version "HTTP/1.1 200 OK"
2025-01-14 11:23:57,909 _client.py[line:1025] INFO || HTTP Request: GET http://127.0.0.1:1096/startup-events "HTTP/1.1 200 OK"
2025-01-14 11:23:58,092 _client.py[line:1025] INFO || HTTP Request: HEAD http://127.0.0.1:1096/ "HTTP/1.1 200 OK"
2025-01-14 11:23:58,906 _client.py[line:1025] INFO || HTTP Request: GET https://api.gradio.app/v2/tunnel-request "HTTP/1.1 200 OK"
outputs/logs/Jan14_11-28-57.txt
ADDED
@@ -0,0 +1,47 @@
2025-01-14 11:28:58,580 configuration_internvl_chat.py[line:68] INFO || vision_select_layer: -1
2025-01-14 11:28:58,580 configuration_internvl_chat.py[line:69] INFO || ps_version: v2
2025-01-14 11:28:58,580 configuration_internvl_chat.py[line:70] INFO || min_dynamic_patch: 1
2025-01-14 11:28:58,580 configuration_internvl_chat.py[line:71] INFO || max_dynamic_patch: 12
2025-01-14 11:28:58,594 configuration_internvl_chat.py[line:68] INFO || vision_select_layer: -1
2025-01-14 11:28:58,594 configuration_internvl_chat.py[line:69] INFO || ps_version: v2
2025-01-14 11:28:58,594 configuration_internvl_chat.py[line:70] INFO || min_dynamic_patch: 1
2025-01-14 11:28:58,594 configuration_internvl_chat.py[line:71] INFO || max_dynamic_patch: 12
2025-01-14 11:28:59,395 configuration_internvl_chat.py[line:68] INFO || vision_select_layer: -1
2025-01-14 11:28:59,395 configuration_internvl_chat.py[line:69] INFO || ps_version: v2
2025-01-14 11:28:59,395 configuration_internvl_chat.py[line:70] INFO || min_dynamic_patch: 1
2025-01-14 11:28:59,395 configuration_internvl_chat.py[line:71] INFO || max_dynamic_patch: 12
2025-01-14 11:29:07,303 modeling_internvl_chat.py[line:54] INFO || num_image_token: 256
2025-01-14 11:29:07,304 modeling_internvl_chat.py[line:55] INFO || ps_version: v2
2025-01-14 11:29:11,003 configuration_internvl_chat.py[line:68] INFO || vision_select_layer: -1
2025-01-14 11:29:11,003 configuration_internvl_chat.py[line:69] INFO || ps_version: v2
2025-01-14 11:29:11,003 configuration_internvl_chat.py[line:70] INFO || min_dynamic_patch: 1
2025-01-14 11:29:11,003 configuration_internvl_chat.py[line:71] INFO || max_dynamic_patch: 12
2025-01-14 11:29:11,211 configuration_internvl_chat.py[line:68] INFO || vision_select_layer: -1
2025-01-14 11:29:11,211 configuration_internvl_chat.py[line:69] INFO || ps_version: v2
2025-01-14 11:29:11,211 configuration_internvl_chat.py[line:70] INFO || min_dynamic_patch: 1
2025-01-14 11:29:11,211 configuration_internvl_chat.py[line:71] INFO || max_dynamic_patch: 12
2025-01-14 11:29:11,225 configuration_internvl_chat.py[line:68] INFO || vision_select_layer: -1
2025-01-14 11:29:11,225 configuration_internvl_chat.py[line:69] INFO || ps_version: v2
2025-01-14 11:29:11,225 configuration_internvl_chat.py[line:70] INFO || min_dynamic_patch: 1
2025-01-14 11:29:11,225 configuration_internvl_chat.py[line:71] INFO || max_dynamic_patch: 12
2025-01-14 11:29:11,240 configuration_internvl_chat.py[line:68] INFO || vision_select_layer: -1
2025-01-14 11:29:11,240 configuration_internvl_chat.py[line:69] INFO || ps_version: v2
2025-01-14 11:29:11,240 configuration_internvl_chat.py[line:70] INFO || min_dynamic_patch: 1
2025-01-14 11:29:11,240 configuration_internvl_chat.py[line:71] INFO || max_dynamic_patch: 12
2025-01-14 11:29:11,254 configuration_internvl_chat.py[line:68] INFO || vision_select_layer: -1
2025-01-14 11:29:11,254 configuration_internvl_chat.py[line:69] INFO || ps_version: v2
2025-01-14 11:29:11,254 configuration_internvl_chat.py[line:70] INFO || min_dynamic_patch: 1
2025-01-14 11:29:11,254 configuration_internvl_chat.py[line:71] INFO || max_dynamic_patch: 12
2025-01-14 11:29:11,265 configuration_internvl_chat.py[line:68] INFO || vision_select_layer: -1
2025-01-14 11:29:11,265 configuration_internvl_chat.py[line:69] INFO || ps_version: v2
2025-01-14 11:29:11,265 configuration_internvl_chat.py[line:70] INFO || min_dynamic_patch: 1
2025-01-14 11:29:11,265 configuration_internvl_chat.py[line:71] INFO || max_dynamic_patch: 12
2025-01-14 11:29:11,291 configuration_internvl_chat.py[line:68] INFO || vision_select_layer: -1
2025-01-14 11:29:11,291 configuration_internvl_chat.py[line:69] INFO || ps_version: v2
2025-01-14 11:29:11,291 configuration_internvl_chat.py[line:70] INFO || min_dynamic_patch: 1
2025-01-14 11:29:11,291 configuration_internvl_chat.py[line:71] INFO || max_dynamic_patch: 12
2025-01-14 11:29:18,842 _client.py[line:1025] INFO || HTTP Request: GET https://checkip.amazonaws.com/ "HTTP/1.1 200 "
2025-01-14 11:29:19,239 _client.py[line:1025] INFO || HTTP Request: GET https://api.gradio.app/pkg-version "HTTP/1.1 200 OK"
2025-01-14 11:29:19,866 _client.py[line:1025] INFO || HTTP Request: GET http://127.0.0.1:1096/startup-events "HTTP/1.1 200 OK"
2025-01-14 11:29:19,979 _client.py[line:1025] INFO || HTTP Request: HEAD http://127.0.0.1:1096/ "HTTP/1.1 200 OK"
2025-01-14 11:29:20,806 _client.py[line:1025] INFO || HTTP Request: GET https://api.gradio.app/v2/tunnel-request "HTTP/1.1 200 OK"
outputs/uploaded/0.jpg
ADDED
process_food.py
ADDED
@@ -0,0 +1,25 @@
import json
input_path = "/root/huggingface/FoodieQA/FoodieQA/sivqa_tidy.json"  # location of sivqa_tidy.json
output_path = "/root/huggingface/FoodieQA/FoodieQA/sivqa_llava.json"  # output file location

with open(input_path, 'r', encoding='utf-8') as f:
    foodqa = json.load(f)

llava_format = []
for data in foodqa:
    llava_format.append({
        "image": data['food_meta']['food_file'],
        "conversations": [
            {
                "from": "human",
                "value": data['question'] + "\n<image>"
            },
            {
                "from": "gpt",
                "value": data['choices'][int(data['answer'])] + ",图中的菜是" + data['food_meta']['food_name']
            }
        ]
    })

with open(output_path, 'w', encoding='utf-8') as f:
    json.dump(llava_format, f, indent=4, ensure_ascii=False)
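For reference, each record in sivqa_tidy.json is rewritten into one LLaVA-style training sample. The entry below is an illustrative sketch with hypothetical field values; the real image path, question, and answer come from the FoodieQA data.

# Illustrative shape of one converted entry (hypothetical values, not taken from the dataset).
example_entry = {
    "image": "images/example_dish.jpg",                        # data['food_meta']['food_file']
    "conversations": [
        {"from": "human",
         "value": "图片中的食物通常属于哪个菜系?\n<image>"},     # data['question'] + "\n<image>"
        {"from": "gpt",
         "value": "粤菜(广东等地),图中的菜是肠粉"},             # chosen answer + ",图中的菜是" + food_name
    ],
}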
requirements.txt
ADDED
@@ -0,0 +1,158 @@
accelerate==0.33.0
addict==2.4.0
aiofiles==23.2.1
aiohttp==3.9.5
aiosignal==1.3.1
altair==5.3.0
annotated-types==0.7.0
anyio==4.4.0
asttokens==2.4.1
async-timeout==4.0.3
attrs==23.2.0
bitsandbytes==0.43.1
blinker==1.8.2
cachetools==5.4.0
click==8.1.7
comm==0.2.2
contourpy==1.2.1
cycler==0.12.1
datasets==2.19.2
debugpy==1.8.1
decorator==5.1.1
dill==0.3.8
distro==1.9.0
dnspython==2.6.1
einops==0.6.1
einops-exts==0.0.4
email_validator==2.1.1
exceptiongroup==1.2.1
executing==2.0.1
fastapi==0.115.6
fastapi-cli==0.0.4
ffmpy==0.3.2
fire==0.6.0
fonttools==4.53.0
frozenlist==1.4.1
fsspec==2024.3.1
gitdb==4.0.11
GitPython==3.1.43
gradio==4.44.1
gradio_client==1.3.0
grpcio==1.64.1
h11==0.14.0
httpcore==1.0.7
httptools==0.6.1
httpx==0.28.1
huggingface-hub==0.26.5
importlib_metadata==7.1.0
importlib_resources==6.4.0
ipykernel==6.29.4
ipython==8.25.0
jedi==0.19.1
jiter==0.5.0
joblib==1.4.2
jsonschema==4.22.0
jsonschema-specifications==2023.12.1
jupyter_client==8.6.2
jupyter_core==5.7.2
kiwisolver==1.4.5
lmdeploy==0.5.3
markdown-it-py==3.0.0
markdown2==2.4.13
matplotlib==3.9.0
matplotlib-inline==0.1.7
mdurl==0.1.2
mkl-service==2.4.0
mmengine-lite==0.10.4
multidict==6.0.5
multiprocess==0.70.16
nest-asyncio==1.6.0
nvidia-cublas-cu12==12.5.2.13
nvidia-cuda-runtime-cu12==12.5.39
nvidia-curand-cu12==10.3.6.39
nvidia-nccl-cu12==2.21.5
openai==1.58.1
orjson==3.10.5
packaging==24.1
pandas==2.2.2
parso==0.8.4
peft==0.9.0
pexpect==4.9.0
platformdirs==4.2.2
prompt_toolkit==3.0.47
protobuf==4.25.3
psutil==5.9.8
ptyprocess==0.7.0
pure-eval==0.2.2
pyarrow==16.1.0
pyarrow-hotfix==0.6
pybind11==2.12.0
pydantic==2.7.4
pydantic_core==2.18.4
pydeck==0.9.1
pydub==0.25.1
Pygments==2.18.0
pynvml==11.5.0
pyparsing==3.1.2
python-dateutil==2.9.0.post0
python-dotenv==1.0.1
python-multipart==0.0.19
python-rapidjson==1.17
pytz==2024.1
pyzmq==26.0.3
referencing==0.35.1
regex==2024.5.15
requests==2.32.3
rich==13.7.1
rpds-py==0.18.1
ruff==0.4.9
safehttpx==0.1.6
safetensors==0.4.3
scikit-learn==1.2.2
scipy==1.13.1
semantic-version==2.10.0
sentencepiece==0.1.99
shellingham==1.5.4
shortuuid==1.0.13
six==1.16.0
smmap==5.0.1
sniffio==1.3.1
stack-data==0.6.3
starlette==0.41.3
streamlit==1.37.0
svgwrite==1.4.3
tenacity==8.5.0
termcolor==2.4.0
threadpoolctl==3.5.0
tiktoken==0.7.0
timm==1.0.8
tokenizers==0.15.1
toml==0.10.2
tomli==2.0.1
tomlkit==0.12.0
toolz==0.12.1
torch==2.1.2
torchaudio==2.1.2
torchvision==0.16.2
tornado==6.4.1
tqdm==4.66.4
traitlets==5.14.3
transformers==4.39.3
transformers-stream-generator==0.0.5
triton==2.1.0
tritonclient==2.46.0
typer==0.12.3
typing_extensions==4.12.2
tzdata==2024.1
ujson==5.10.0
uvicorn==0.30.1
uvloop==0.19.0
watchdog==4.0.1
watchfiles==0.22.0
wavedrom==2.0.3.post3
wcwidth==0.2.13
websockets==11.0.3
xxhash==3.4.1
yapf==0.40.2
yarl==1.9.4
zipp==3.19.2
utils.py
ADDED
@@ -0,0 +1,25 @@
import os
import json
import logging
from datetime import datetime


def load_json(file_name: str):
    if isinstance(file_name, str) and file_name.endswith("json"):
        with open(file_name, 'r') as file:
            data = json.load(file)
    else:
        raise ValueError("The file path you passed in is not a json file path.")

    return data

def init_logger(outputs_dir):
    current_time = datetime.now().strftime("%b%d_%H-%M-%S")
    os.makedirs(os.path.join(outputs_dir, "logs"), exist_ok=True)
    log_path = os.path.join(outputs_dir, "logs", "{}.txt".format(current_time))
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s || %(message)s",
        handlers=[logging.StreamHandler(), logging.FileHandler(log_path)],
    )
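A minimal usage sketch follows (an assumption about call order, mirroring how app.py uses these helpers; it is not an additional file in the commit). init_logger configures the root logger, so later logging calls are written both to the console and to a timestamped file under outputs/logs/, which matches the format of the log files included above.

# Minimal usage sketch; assumes ./outputs is writable and ./demo/food_for_demo.json exists.
import logging
from utils import load_json, init_logger

init_logger("./outputs")                                  # creates ./outputs/logs/<Mon><DD>_<HH-MM-SS>.txt
logging.info("logger ready")                              # goes to the console and to the log file
food_examples = load_json("./demo/food_for_demo.json")    # raises ValueError for non-json paths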
xtuner_config/.ipynb_checkpoints/internvl_v2_internlm2_2b_lora_finetune_food-checkpoint.py
ADDED
@@ -0,0 +1,184 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
                            LoggerHook, ParamSchedulerHook)
from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
from peft import LoraConfig
from torch.optim import AdamW
from transformers import AutoTokenizer

from xtuner.dataset import InternVL_V1_5_Dataset
from xtuner.dataset.collate_fns import default_collate_fn
from xtuner.dataset.samplers import LengthGroupedSampler
from xtuner.engine.hooks import DatasetInfoHook
from xtuner.engine.runner import TrainLoop
from xtuner.model import InternVL_V1_5
from xtuner.utils import PROMPT_TEMPLATE

#######################################################################
#                          PART 1  Settings                           #
#######################################################################
# Model
path = '/root/share/new_models/OpenGVLab/InternVL2-2B'

# Data
data_root = '/root/share/datasets/FoodieQA/'  # your data path
data_path = data_root + 'sivqa_llava.json'
image_folder = data_root  # your image folder path
prompt_template = PROMPT_TEMPLATE.internlm2_chat
max_length = 8192

# Scheduler & Optimizer
batch_size = 4  # per_device
accumulative_counts = 2
dataloader_num_workers = 4
max_epochs = 10
optim_type = AdamW
# official 1024 -> 4e-5
# lr = 1e-6
lr = 3e-5
betas = (0.9, 0.999)
weight_decay = 0.05
max_norm = 1  # grad clip
warmup_ratio = 0.03

# Save
save_steps = 64
save_total_limit = -1  # Maximum checkpoints to keep (-1 means unlimited)

#######################################################################
#            PART 2  Model & Tokenizer & Image Processor              #
#######################################################################
model = dict(
    type=InternVL_V1_5,
    model_path=path,
    freeze_llm=True,
    freeze_visual_encoder=True,
    # comment out the following lines if you don't want to use LoRA in the LLM
    llm_lora=dict(
        type=LoraConfig,
        r=128,
        lora_alpha=256,
        lora_dropout=0.05,
        target_modules=None,
        task_type='CAUSAL_LM'),
    # uncomment the following lines if you want to use LoRA in the visual encoder  # noqa
    # visual_encoder_lora=dict(
    #     type=LoraConfig, r=64, lora_alpha=16, lora_dropout=0.05,
    #     target_modules=['attn.qkv', 'attn.proj', 'mlp.fc1', 'mlp.fc2'])
)

#######################################################################
#                     PART 3  Dataset & Dataloader                    #
#######################################################################
llava_dataset = dict(
    type=InternVL_V1_5_Dataset,
    model_path=path,
    data_paths=data_path,
    image_folders=image_folder,
    template=prompt_template,
    max_length=max_length)

train_dataloader = dict(
    batch_size=batch_size,
    num_workers=dataloader_num_workers,
    dataset=llava_dataset,
    sampler=dict(
        type=LengthGroupedSampler,
        length_property='modality_length',
        per_device_batch_size=batch_size * accumulative_counts),
    collate_fn=dict(type=default_collate_fn))

#######################################################################
#                    PART 4  Scheduler & Optimizer                    #
#######################################################################
# optimizer
optim_wrapper = dict(
    type=AmpOptimWrapper,
    optimizer=dict(
        type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
    clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
    accumulative_counts=accumulative_counts,
    loss_scale='dynamic',
    dtype='float16')

# learning policy
# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md  # noqa: E501
param_scheduler = [
    dict(
        type=LinearLR,
        start_factor=1e-5,
        by_epoch=True,
        begin=0,
        end=warmup_ratio * max_epochs,
        convert_to_iter_based=True),
    dict(
        type=CosineAnnealingLR,
        eta_min=0.0,
        by_epoch=True,
        begin=warmup_ratio * max_epochs,
        end=max_epochs,
        convert_to_iter_based=True)
]

# train, val, test setting
train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)

#######################################################################
#                           PART 5  Runtime                           #
#######################################################################
# Log the dialogue periodically during the training process, optional
tokenizer = dict(
    type=AutoTokenizer.from_pretrained,
    pretrained_model_name_or_path=path,
    trust_remote_code=True)

custom_hooks = [
    dict(type=DatasetInfoHook, tokenizer=tokenizer),
]

# configure default hooks
default_hooks = dict(
    # record the time of every iteration.
    timer=dict(type=IterTimerHook),
    # print log every 10 iterations.
    logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
    # enable the parameter scheduler.
    param_scheduler=dict(type=ParamSchedulerHook),
    # save checkpoint per `save_steps`.
    checkpoint=dict(
        type=CheckpointHook,
        save_optimizer=False,
        by_epoch=False,
        interval=save_steps,
        max_keep_ckpts=save_total_limit),
    # set sampler seed in distributed environment.
    sampler_seed=dict(type=DistSamplerSeedHook),
)

# configure environment
env_cfg = dict(
    # whether to enable cudnn benchmark
    cudnn_benchmark=False,
    # set multi process parameters
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
    # set distributed parameters
    dist_cfg=dict(backend='nccl'),
)

# set visualizer
visualizer = None

# set log level
log_level = 'INFO'

# load from which checkpoint
load_from = None

# whether to resume training from the loaded checkpoint
resume = False

# Defaults to use random seed and disable `deterministic`
randomness = dict(seed=None, deterministic=False)

# set log processor
log_processor = dict(by_epoch=False)
xtuner_config/internvl_v2_internlm2_2b_lora_finetune_food.py
ADDED
@@ -0,0 +1,184 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
                            LoggerHook, ParamSchedulerHook)
from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
from peft import LoraConfig
from torch.optim import AdamW
from transformers import AutoTokenizer

from xtuner.dataset import InternVL_V1_5_Dataset
from xtuner.dataset.collate_fns import default_collate_fn
from xtuner.dataset.samplers import LengthGroupedSampler
from xtuner.engine.hooks import DatasetInfoHook
from xtuner.engine.runner import TrainLoop
from xtuner.model import InternVL_V1_5
from xtuner.utils import PROMPT_TEMPLATE

#######################################################################
#                          PART 1  Settings                           #
#######################################################################
# Model
path = '/root/share/new_models/OpenGVLab/InternVL2-2B'

# Data
data_root = '/root/share/datasets/FoodieQA/'  # your data path
data_path = data_root + 'sivqa_llava.json'
image_folder = data_root  # your image folder path
prompt_template = PROMPT_TEMPLATE.internlm2_chat
max_length = 8192

# Scheduler & Optimizer
batch_size = 4  # per_device
accumulative_counts = 2
dataloader_num_workers = 4
max_epochs = 10
optim_type = AdamW
# official 1024 -> 4e-5
# lr = 1e-6
lr = 3e-5
betas = (0.9, 0.999)
weight_decay = 0.05
max_norm = 1  # grad clip
warmup_ratio = 0.03

# Save
save_steps = 64
save_total_limit = -1  # Maximum checkpoints to keep (-1 means unlimited)

#######################################################################
#            PART 2  Model & Tokenizer & Image Processor              #
#######################################################################
model = dict(
    type=InternVL_V1_5,
    model_path=path,
    freeze_llm=True,
    freeze_visual_encoder=True,
    # comment out the following lines if you don't want to use LoRA in the LLM
    llm_lora=dict(
        type=LoraConfig,
        r=128,
        lora_alpha=256,
        lora_dropout=0.05,
        target_modules=None,
        task_type='CAUSAL_LM'),
    # uncomment the following lines if you want to use LoRA in the visual encoder  # noqa
    # visual_encoder_lora=dict(
    #     type=LoraConfig, r=64, lora_alpha=16, lora_dropout=0.05,
    #     target_modules=['attn.qkv', 'attn.proj', 'mlp.fc1', 'mlp.fc2'])
)

#######################################################################
#                     PART 3  Dataset & Dataloader                    #
#######################################################################
llava_dataset = dict(
    type=InternVL_V1_5_Dataset,
    model_path=path,
    data_paths=data_path,
    image_folders=image_folder,
    template=prompt_template,
    max_length=max_length)

train_dataloader = dict(
    batch_size=batch_size,
    num_workers=dataloader_num_workers,
    dataset=llava_dataset,
    sampler=dict(
        type=LengthGroupedSampler,
        length_property='modality_length',
        per_device_batch_size=batch_size * accumulative_counts),
    collate_fn=dict(type=default_collate_fn))

#######################################################################
#                    PART 4  Scheduler & Optimizer                    #
#######################################################################
# optimizer
optim_wrapper = dict(
    type=AmpOptimWrapper,
    optimizer=dict(
        type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
    clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
    accumulative_counts=accumulative_counts,
    loss_scale='dynamic',
    dtype='float16')

# learning policy
# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md  # noqa: E501
param_scheduler = [
    dict(
        type=LinearLR,
        start_factor=1e-5,
        by_epoch=True,
        begin=0,
        end=warmup_ratio * max_epochs,
        convert_to_iter_based=True),
    dict(
        type=CosineAnnealingLR,
        eta_min=0.0,
        by_epoch=True,
        begin=warmup_ratio * max_epochs,
        end=max_epochs,
        convert_to_iter_based=True)
]

# train, val, test setting
train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)

#######################################################################
#                           PART 5  Runtime                           #
#######################################################################
# Log the dialogue periodically during the training process, optional
tokenizer = dict(
    type=AutoTokenizer.from_pretrained,
    pretrained_model_name_or_path=path,
    trust_remote_code=True)

custom_hooks = [
    dict(type=DatasetInfoHook, tokenizer=tokenizer),
]

# configure default hooks
default_hooks = dict(
    # record the time of every iteration.
    timer=dict(type=IterTimerHook),
    # print log every 10 iterations.
    logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
    # enable the parameter scheduler.
    param_scheduler=dict(type=ParamSchedulerHook),
    # save checkpoint per `save_steps`.
    checkpoint=dict(
        type=CheckpointHook,
        save_optimizer=False,
        by_epoch=False,
        interval=save_steps,
        max_keep_ckpts=save_total_limit),
    # set sampler seed in distributed environment.
    sampler_seed=dict(type=DistSamplerSeedHook),
)

# configure environment
env_cfg = dict(
    # whether to enable cudnn benchmark
    cudnn_benchmark=False,
    # set multi process parameters
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
    # set distributed parameters
    dist_cfg=dict(backend='nccl'),
)

# set visualizer
visualizer = None

# set log level
log_level = 'INFO'

# load from which checkpoint
load_from = None

# whether to resume training from the loaded checkpoint
resume = False

# Defaults to use random seed and disable `deterministic`
randomness = dict(seed=None, deterministic=False)

# set log processor
log_processor = dict(by_epoch=False)
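As a reading aid (not part of the commit), the schedule values in this config combine as follows: with gradient accumulation, each optimizer step sees batch_size * accumulative_counts samples per GPU, and the LinearLR warmup covers warmup_ratio * max_epochs epochs before CosineAnnealingLR takes over. A small illustrative calculation, using the values above:

# Illustrative arithmetic only; values copied from the config above.
batch_size = 4            # per-GPU batch size
accumulative_counts = 2   # gradient accumulation steps
max_epochs = 10
warmup_ratio = 0.03

samples_per_optim_step = batch_size * accumulative_counts  # 8 samples per GPU per optimizer step
warmup_epochs = warmup_ratio * max_epochs                  # LinearLR runs for the first 0.3 epochs
print(samples_per_optim_step, warmup_epochs)               # -> 8 0.3

This config would typically be passed to XTuner's train entry point (for example, xtuner train xtuner_config/internvl_v2_internlm2_2b_lora_finetune_food.py), though no launch script is included in this upload.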