dstars committed on
Commit
12afd35
·
verified ·
1 Parent(s): 4bebcaf

Upload 10 files

Browse files
app.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ import numpy as np
4
+ import torch
5
+ import torch.backends.cudnn as cudnn
6
+ import gradio as gr
7
+
8
+ from utils import load_json, init_logger
9
+ from demo import ConversationalAgent, CustomTheme
10
+
11
# JSON file mapping cuisine names to example image paths for the demo UI.
FOOD_EXAMPLES = "./demo/food_for_demo.json"
# MODEL_PATH = "/root/share/new_models/OpenGVLab/InternVL2-2B"
# Fine-tuned checkpoint directory (presumably lr 3e-5 / 10 epochs, per the
# xtuner config — TODO confirm the naming).
MODEL_PATH = "./lr35_ep10"
# Root directory for run artifacts: logs/ and uploaded/ images.
OUTPUT_PATH = "./outputs"
15
+
16
def setup_seeds(seed: int = 42):
    """Seed every RNG used by the app and force deterministic cuDNN.

    Args:
        seed: Seed applied to ``random``, NumPy and PyTorch. Defaults to 42
            (the original hard-coded value), so existing callers are
            unaffected; exposed as a parameter so runs can be re-seeded.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Trade cuDNN autotuner speed for run-to-run reproducibility.
    cudnn.benchmark = False
    cudnn.deterministic = True
25
+
26
+
27
def main():
    """Build and launch the Gradio food-chat demo around InternVL2.

    Wires a ConversationalAgent to a Blocks UI: chat controls, generation
    sliders, example prompts and per-cuisine image shortcut galleries.
    """
    setup_seeds()
    # Logging goes to stdout and to a timestamped file under OUTPUT_PATH.
    init_logger(OUTPUT_PATH)
    # Cuisine name -> example image paths, used by the shortcut rows below.
    food_examples = load_json(FOOD_EXAMPLES)

    agent = ConversationalAgent(model_path=MODEL_PATH,
                                outputs_dir=OUTPUT_PATH)

    theme = CustomTheme()

    # NOTE: the comma between the two strings is required; without it Python
    # implicitly concatenates them into a single title (the original bug).
    titles = [
        """<center><B><font face="Comic Sans MS" size=10>书生大模型实战营</font></B></center>""",  ## Kalam:wght@700
        """<center><B><font face="Courier" size=5>「进阶岛」InternVL 多模态模型部署微调实践</font></B></center>"""
    ]

    language = """Language: 中文 and English"""
    with gr.Blocks(theme=theme) as demo_chatbot:
        for title in titles:
            gr.Markdown(title)
        gr.Markdown(language)

        with gr.Row():
            with gr.Column(scale=3):
                start_btn = gr.Button("Start Chat", variant="primary", interactive=True)
                clear_btn = gr.Button("Clear Context", interactive=False)
                image = gr.Image(type="pil", interactive=False)
                upload_btn = gr.Button("🖼️ Upload Image", interactive=False)

                with gr.Accordion("Generation Settings"):
                    top_p = gr.Slider(minimum=0, maximum=1, step=0.1,
                                      value=0.8,
                                      interactive=True,
                                      label='top-p value',
                                      visible=True)

                    temperature = gr.Slider(minimum=0, maximum=1.5, step=0.1,
                                            value=0.8,
                                            interactive=True,
                                            label='temperature',
                                            visible=True)

            with gr.Column(scale=7):
                chat_state = gr.State()
                chatbot = gr.Chatbot(
                    label='InternVL2', height=800,
                    avatar_images=(os.path.join(os.path.dirname(__file__), 'demo/user.png'),
                                   os.path.join(os.path.dirname(__file__), "demo/bot.png")))
                text_input = gr.Textbox(label='User', placeholder="Please click the <Start Chat> button to start chat!", interactive=False)
                gr.Markdown("### 输入示例")

                # Re-enable the textbox as soon as its content changes
                # (e.g. after an example prompt is clicked into it).
                def on_text_change(text):
                    return gr.update(interactive=True)
                text_input.change(fn=on_text_change, inputs=text_input, outputs=text_input)

                gr.Examples(
                    examples=[["图片中的食物通常属于哪个菜系?"],
                              ["如果让你简单形容一下品尝图片中的食物的滋味,你会描述它"],
                              ["去哪个地方游玩时应该品尝当地的特色美食图片中的食物?"],
                              ["食用图片中的食物时,一般它上菜或摆盘时的特点是?"]],
                    inputs=[text_input]
                )

        with gr.Row():
            gr.Markdown("### 食物快捷栏")
        with gr.Row():
            example_xinjiang_food = gr.Examples(examples=food_examples["新疆菜"], inputs=image, label="新疆菜")
            example_sichuan_food = gr.Examples(examples=food_examples["川菜(四川,重庆)"], inputs=image, label="川菜(四川,重庆)")
            example_xibei_food = gr.Examples(examples=food_examples["西北菜 (陕西,甘肃等地)"], inputs=image, label="西北菜 (陕西,甘肃等地)")
        with gr.Row():
            example_guizhou_food = gr.Examples(examples=food_examples["黔菜 (贵州)"], inputs=image, label="黔菜 (贵州)")
            example_jiangsu_food = gr.Examples(examples=food_examples["苏菜(江苏)"], inputs=image, label="苏菜(江苏)")
            example_guangdong_food = gr.Examples(examples=food_examples["粤菜(广东等地)"], inputs=image, label="粤菜(广东等地)")
        with gr.Row():
            example_hunan_food = gr.Examples(examples=food_examples["湘菜(湖南)"], inputs=image, label="湘菜(湖南)")
            example_fujian_food = gr.Examples(examples=food_examples["闽菜(福建)"], inputs=image, label="闽菜(福建)")
            # Label fixed: the original contained mojibake ("浙菜(��江)");
            # restored from the intact dict key on the same line.
            example_zhejiang_food = gr.Examples(examples=food_examples["浙菜(浙江)"], inputs=image, label="浙菜(浙江)")
        with gr.Row():
            example_dongbei_food = gr.Examples(examples=food_examples["东北菜 (黑龙江等地)"], inputs=image, label="东北菜 (黑龙江等地)")

        start_btn.click(agent.start_chat, [chat_state], [text_input, start_btn, clear_btn, image, upload_btn, chat_state])
        clear_btn.click(agent.restart_chat, [chat_state], [chatbot, text_input, start_btn, clear_btn, image, upload_btn, chat_state], queue=False)
        upload_btn.click(agent.upload_image, [image, chatbot, chat_state], [image, chatbot, chat_state])
        text_input.submit(
            agent.respond,
            inputs=[text_input, image, chatbot, top_p, temperature, chat_state],
            outputs=[text_input, image, chatbot, chat_state]
        )

    # queue() must be enabled BEFORE launch(): launch() blocks the process,
    # so the original post-launch queue() call never took effect.
    demo_chatbot.queue()
    demo_chatbot.launch(share=True, server_name="127.0.0.1", server_port=1096, allowed_paths=['./'])


if __name__ == "__main__":
    main()
outputs/logs/Jan14_09-50-48.txt ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 2025-01-14 09:50:53,500 configuration_internvl_chat.py[line:68] INFO || vision_select_layer: -1
2
+ 2025-01-14 09:50:53,502 configuration_internvl_chat.py[line:69] INFO || ps_version: v2
3
+ 2025-01-14 09:50:53,502 configuration_internvl_chat.py[line:70] INFO || min_dynamic_patch: 1
4
+ 2025-01-14 09:50:53,502 configuration_internvl_chat.py[line:71] INFO || max_dynamic_patch: 12
5
+ 2025-01-14 09:50:53,537 configuration_internvl_chat.py[line:68] INFO || vision_select_layer: -1
6
+ 2025-01-14 09:50:53,537 configuration_internvl_chat.py[line:69] INFO || ps_version: v2
7
+ 2025-01-14 09:50:53,537 configuration_internvl_chat.py[line:70] INFO || min_dynamic_patch: 1
8
+ 2025-01-14 09:50:53,537 configuration_internvl_chat.py[line:71] INFO || max_dynamic_patch: 12
9
+ 2025-01-14 09:50:54,641 configuration_internvl_chat.py[line:68] INFO || vision_select_layer: -1
10
+ 2025-01-14 09:50:54,642 configuration_internvl_chat.py[line:69] INFO || ps_version: v2
11
+ 2025-01-14 09:50:54,642 configuration_internvl_chat.py[line:70] INFO || min_dynamic_patch: 1
12
+ 2025-01-14 09:50:54,642 configuration_internvl_chat.py[line:71] INFO || max_dynamic_patch: 12
13
+ 2025-01-14 09:51:27,929 modeling_internvl_chat.py[line:54] INFO || num_image_token: 256
14
+ 2025-01-14 09:51:27,930 modeling_internvl_chat.py[line:55] INFO || ps_version: v2
15
+ 2025-01-14 09:53:54,644 configuration_internvl_chat.py[line:68] INFO || vision_select_layer: -1
16
+ 2025-01-14 09:53:54,645 configuration_internvl_chat.py[line:69] INFO || ps_version: v2
17
+ 2025-01-14 09:53:54,646 configuration_internvl_chat.py[line:70] INFO || min_dynamic_patch: 1
18
+ 2025-01-14 09:53:54,646 configuration_internvl_chat.py[line:71] INFO || max_dynamic_patch: 12
19
+ 2025-01-14 09:53:55,190 configuration_internvl_chat.py[line:68] INFO || vision_select_layer: -1
20
+ 2025-01-14 09:53:55,190 configuration_internvl_chat.py[line:69] INFO || ps_version: v2
21
+ 2025-01-14 09:53:55,190 configuration_internvl_chat.py[line:70] INFO || min_dynamic_patch: 1
22
+ 2025-01-14 09:53:55,190 configuration_internvl_chat.py[line:71] INFO || max_dynamic_patch: 12
23
+ 2025-01-14 09:53:55,199 configuration_internvl_chat.py[line:68] INFO || vision_select_layer: -1
24
+ 2025-01-14 09:53:55,199 configuration_internvl_chat.py[line:69] INFO || ps_version: v2
25
+ 2025-01-14 09:53:55,199 configuration_internvl_chat.py[line:70] INFO || min_dynamic_patch: 1
26
+ 2025-01-14 09:53:55,199 configuration_internvl_chat.py[line:71] INFO || max_dynamic_patch: 12
27
+ 2025-01-14 09:53:55,220 configuration_internvl_chat.py[line:68] INFO || vision_select_layer: -1
28
+ 2025-01-14 09:53:55,220 configuration_internvl_chat.py[line:69] INFO || ps_version: v2
29
+ 2025-01-14 09:53:55,220 configuration_internvl_chat.py[line:70] INFO || min_dynamic_patch: 1
30
+ 2025-01-14 09:53:55,220 configuration_internvl_chat.py[line:71] INFO || max_dynamic_patch: 12
31
+ 2025-01-14 09:53:55,233 configuration_internvl_chat.py[line:68] INFO || vision_select_layer: -1
32
+ 2025-01-14 09:53:55,233 configuration_internvl_chat.py[line:69] INFO || ps_version: v2
33
+ 2025-01-14 09:53:55,233 configuration_internvl_chat.py[line:70] INFO || min_dynamic_patch: 1
34
+ 2025-01-14 09:53:55,233 configuration_internvl_chat.py[line:71] INFO || max_dynamic_patch: 12
35
+ 2025-01-14 09:53:55,241 configuration_internvl_chat.py[line:68] INFO || vision_select_layer: -1
36
+ 2025-01-14 09:53:55,242 configuration_internvl_chat.py[line:69] INFO || ps_version: v2
37
+ 2025-01-14 09:53:55,242 configuration_internvl_chat.py[line:70] INFO || min_dynamic_patch: 1
38
+ 2025-01-14 09:53:55,242 configuration_internvl_chat.py[line:71] INFO || max_dynamic_patch: 12
39
+ 2025-01-14 09:53:55,251 configuration_internvl_chat.py[line:68] INFO || vision_select_layer: -1
40
+ 2025-01-14 09:53:55,251 configuration_internvl_chat.py[line:69] INFO || ps_version: v2
41
+ 2025-01-14 09:53:55,251 configuration_internvl_chat.py[line:70] INFO || min_dynamic_patch: 1
42
+ 2025-01-14 09:53:55,251 configuration_internvl_chat.py[line:71] INFO || max_dynamic_patch: 12
43
+ 2025-01-14 09:54:03,863 _client.py[line:1025] INFO || HTTP Request: GET https://checkip.amazonaws.com/ "HTTP/1.1 200 "
44
+ 2025-01-14 09:54:03,950 _client.py[line:1025] INFO || HTTP Request: GET https://api.gradio.app/pkg-version "HTTP/1.1 200 OK"
45
+ 2025-01-14 09:54:11,354 _client.py[line:1025] INFO || HTTP Request: GET http://127.0.0.1:1096/startup-events "HTTP/1.1 200 OK"
46
+ 2025-01-14 09:54:11,700 _client.py[line:1025] INFO || HTTP Request: HEAD http://127.0.0.1:1096/ "HTTP/1.1 200 OK"
47
+ 2025-01-14 09:54:12,493 _client.py[line:1025] INFO || HTTP Request: GET https://api.gradio.app/v2/tunnel-request "HTTP/1.1 200 OK"
48
+ 2025-01-14 10:07:49,759 agent.py[line:30] INFO || ==============================Start Chat==============================
49
+ 2025-01-14 10:08:03,453 agent.py[line:84] INFO || Time: Jan14-10:08:03
50
+ 2025-01-14 10:08:03,453 agent.py[line:85] INFO || User: 去哪个地方游玩时应该品尝当地的特色美食图片中的食物?
51
+ 2025-01-14 10:08:03,706 agent.py[line:91] INFO || image save path: /root/InternVL2-Tutorial/outputs/uploaded/0.jpg
52
+ 2025-01-14 10:08:21,808 agent.py[line:103] INFO || generated text =
53
+ 广东,图中的菜是鸡蛋肠粉
outputs/logs/Jan14_11-23-28.txt ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 2025-01-14 11:23:29,817 configuration_internvl_chat.py[line:68] INFO || vision_select_layer: -1
2
+ 2025-01-14 11:23:29,817 configuration_internvl_chat.py[line:69] INFO || ps_version: v2
3
+ 2025-01-14 11:23:29,817 configuration_internvl_chat.py[line:70] INFO || min_dynamic_patch: 1
4
+ 2025-01-14 11:23:29,817 configuration_internvl_chat.py[line:71] INFO || max_dynamic_patch: 12
5
+ 2025-01-14 11:23:29,838 configuration_internvl_chat.py[line:68] INFO || vision_select_layer: -1
6
+ 2025-01-14 11:23:29,838 configuration_internvl_chat.py[line:69] INFO || ps_version: v2
7
+ 2025-01-14 11:23:29,838 configuration_internvl_chat.py[line:70] INFO || min_dynamic_patch: 1
8
+ 2025-01-14 11:23:29,838 configuration_internvl_chat.py[line:71] INFO || max_dynamic_patch: 12
9
+ 2025-01-14 11:23:30,712 configuration_internvl_chat.py[line:68] INFO || vision_select_layer: -1
10
+ 2025-01-14 11:23:30,713 configuration_internvl_chat.py[line:69] INFO || ps_version: v2
11
+ 2025-01-14 11:23:30,713 configuration_internvl_chat.py[line:70] INFO || min_dynamic_patch: 1
12
+ 2025-01-14 11:23:30,713 configuration_internvl_chat.py[line:71] INFO || max_dynamic_patch: 12
13
+ 2025-01-14 11:23:38,556 modeling_internvl_chat.py[line:54] INFO || num_image_token: 256
14
+ 2025-01-14 11:23:38,556 modeling_internvl_chat.py[line:55] INFO || ps_version: v2
15
+ 2025-01-14 11:23:43,589 configuration_internvl_chat.py[line:68] INFO || vision_select_layer: -1
16
+ 2025-01-14 11:23:43,593 configuration_internvl_chat.py[line:69] INFO || ps_version: v2
17
+ 2025-01-14 11:23:43,593 configuration_internvl_chat.py[line:70] INFO || min_dynamic_patch: 1
18
+ 2025-01-14 11:23:43,593 configuration_internvl_chat.py[line:71] INFO || max_dynamic_patch: 12
19
+ 2025-01-14 11:23:43,778 configuration_internvl_chat.py[line:68] INFO || vision_select_layer: -1
20
+ 2025-01-14 11:23:43,778 configuration_internvl_chat.py[line:69] INFO || ps_version: v2
21
+ 2025-01-14 11:23:43,778 configuration_internvl_chat.py[line:70] INFO || min_dynamic_patch: 1
22
+ 2025-01-14 11:23:43,778 configuration_internvl_chat.py[line:71] INFO || max_dynamic_patch: 12
23
+ 2025-01-14 11:23:43,787 configuration_internvl_chat.py[line:68] INFO || vision_select_layer: -1
24
+ 2025-01-14 11:23:43,787 configuration_internvl_chat.py[line:69] INFO || ps_version: v2
25
+ 2025-01-14 11:23:43,787 configuration_internvl_chat.py[line:70] INFO || min_dynamic_patch: 1
26
+ 2025-01-14 11:23:43,787 configuration_internvl_chat.py[line:71] INFO || max_dynamic_patch: 12
27
+ 2025-01-14 11:23:43,798 configuration_internvl_chat.py[line:68] INFO || vision_select_layer: -1
28
+ 2025-01-14 11:23:43,798 configuration_internvl_chat.py[line:69] INFO || ps_version: v2
29
+ 2025-01-14 11:23:43,798 configuration_internvl_chat.py[line:70] INFO || min_dynamic_patch: 1
30
+ 2025-01-14 11:23:43,798 configuration_internvl_chat.py[line:71] INFO || max_dynamic_patch: 12
31
+ 2025-01-14 11:23:43,817 configuration_internvl_chat.py[line:68] INFO || vision_select_layer: -1
32
+ 2025-01-14 11:23:43,817 configuration_internvl_chat.py[line:69] INFO || ps_version: v2
33
+ 2025-01-14 11:23:43,818 configuration_internvl_chat.py[line:70] INFO || min_dynamic_patch: 1
34
+ 2025-01-14 11:23:43,818 configuration_internvl_chat.py[line:71] INFO || max_dynamic_patch: 12
35
+ 2025-01-14 11:23:43,834 configuration_internvl_chat.py[line:68] INFO || vision_select_layer: -1
36
+ 2025-01-14 11:23:43,834 configuration_internvl_chat.py[line:69] INFO || ps_version: v2
37
+ 2025-01-14 11:23:43,834 configuration_internvl_chat.py[line:70] INFO || min_dynamic_patch: 1
38
+ 2025-01-14 11:23:43,834 configuration_internvl_chat.py[line:71] INFO || max_dynamic_patch: 12
39
+ 2025-01-14 11:23:43,846 configuration_internvl_chat.py[line:68] INFO || vision_select_layer: -1
40
+ 2025-01-14 11:23:43,846 configuration_internvl_chat.py[line:69] INFO || ps_version: v2
41
+ 2025-01-14 11:23:43,846 configuration_internvl_chat.py[line:70] INFO || min_dynamic_patch: 1
42
+ 2025-01-14 11:23:43,846 configuration_internvl_chat.py[line:71] INFO || max_dynamic_patch: 12
43
+ 2025-01-14 11:23:51,555 _client.py[line:1025] INFO || HTTP Request: GET https://checkip.amazonaws.com/ "HTTP/1.1 200 "
44
+ 2025-01-14 11:23:52,073 _client.py[line:1025] INFO || HTTP Request: GET https://api.gradio.app/pkg-version "HTTP/1.1 200 OK"
45
+ 2025-01-14 11:23:57,909 _client.py[line:1025] INFO || HTTP Request: GET http://127.0.0.1:1096/startup-events "HTTP/1.1 200 OK"
46
+ 2025-01-14 11:23:58,092 _client.py[line:1025] INFO || HTTP Request: HEAD http://127.0.0.1:1096/ "HTTP/1.1 200 OK"
47
+ 2025-01-14 11:23:58,906 _client.py[line:1025] INFO || HTTP Request: GET https://api.gradio.app/v2/tunnel-request "HTTP/1.1 200 OK"
outputs/logs/Jan14_11-28-57.txt ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 2025-01-14 11:28:58,580 configuration_internvl_chat.py[line:68] INFO || vision_select_layer: -1
2
+ 2025-01-14 11:28:58,580 configuration_internvl_chat.py[line:69] INFO || ps_version: v2
3
+ 2025-01-14 11:28:58,580 configuration_internvl_chat.py[line:70] INFO || min_dynamic_patch: 1
4
+ 2025-01-14 11:28:58,580 configuration_internvl_chat.py[line:71] INFO || max_dynamic_patch: 12
5
+ 2025-01-14 11:28:58,594 configuration_internvl_chat.py[line:68] INFO || vision_select_layer: -1
6
+ 2025-01-14 11:28:58,594 configuration_internvl_chat.py[line:69] INFO || ps_version: v2
7
+ 2025-01-14 11:28:58,594 configuration_internvl_chat.py[line:70] INFO || min_dynamic_patch: 1
8
+ 2025-01-14 11:28:58,594 configuration_internvl_chat.py[line:71] INFO || max_dynamic_patch: 12
9
+ 2025-01-14 11:28:59,395 configuration_internvl_chat.py[line:68] INFO || vision_select_layer: -1
10
+ 2025-01-14 11:28:59,395 configuration_internvl_chat.py[line:69] INFO || ps_version: v2
11
+ 2025-01-14 11:28:59,395 configuration_internvl_chat.py[line:70] INFO || min_dynamic_patch: 1
12
+ 2025-01-14 11:28:59,395 configuration_internvl_chat.py[line:71] INFO || max_dynamic_patch: 12
13
+ 2025-01-14 11:29:07,303 modeling_internvl_chat.py[line:54] INFO || num_image_token: 256
14
+ 2025-01-14 11:29:07,304 modeling_internvl_chat.py[line:55] INFO || ps_version: v2
15
+ 2025-01-14 11:29:11,003 configuration_internvl_chat.py[line:68] INFO || vision_select_layer: -1
16
+ 2025-01-14 11:29:11,003 configuration_internvl_chat.py[line:69] INFO || ps_version: v2
17
+ 2025-01-14 11:29:11,003 configuration_internvl_chat.py[line:70] INFO || min_dynamic_patch: 1
18
+ 2025-01-14 11:29:11,003 configuration_internvl_chat.py[line:71] INFO || max_dynamic_patch: 12
19
+ 2025-01-14 11:29:11,211 configuration_internvl_chat.py[line:68] INFO || vision_select_layer: -1
20
+ 2025-01-14 11:29:11,211 configuration_internvl_chat.py[line:69] INFO || ps_version: v2
21
+ 2025-01-14 11:29:11,211 configuration_internvl_chat.py[line:70] INFO || min_dynamic_patch: 1
22
+ 2025-01-14 11:29:11,211 configuration_internvl_chat.py[line:71] INFO || max_dynamic_patch: 12
23
+ 2025-01-14 11:29:11,225 configuration_internvl_chat.py[line:68] INFO || vision_select_layer: -1
24
+ 2025-01-14 11:29:11,225 configuration_internvl_chat.py[line:69] INFO || ps_version: v2
25
+ 2025-01-14 11:29:11,225 configuration_internvl_chat.py[line:70] INFO || min_dynamic_patch: 1
26
+ 2025-01-14 11:29:11,225 configuration_internvl_chat.py[line:71] INFO || max_dynamic_patch: 12
27
+ 2025-01-14 11:29:11,240 configuration_internvl_chat.py[line:68] INFO || vision_select_layer: -1
28
+ 2025-01-14 11:29:11,240 configuration_internvl_chat.py[line:69] INFO || ps_version: v2
29
+ 2025-01-14 11:29:11,240 configuration_internvl_chat.py[line:70] INFO || min_dynamic_patch: 1
30
+ 2025-01-14 11:29:11,240 configuration_internvl_chat.py[line:71] INFO || max_dynamic_patch: 12
31
+ 2025-01-14 11:29:11,254 configuration_internvl_chat.py[line:68] INFO || vision_select_layer: -1
32
+ 2025-01-14 11:29:11,254 configuration_internvl_chat.py[line:69] INFO || ps_version: v2
33
+ 2025-01-14 11:29:11,254 configuration_internvl_chat.py[line:70] INFO || min_dynamic_patch: 1
34
+ 2025-01-14 11:29:11,254 configuration_internvl_chat.py[line:71] INFO || max_dynamic_patch: 12
35
+ 2025-01-14 11:29:11,265 configuration_internvl_chat.py[line:68] INFO || vision_select_layer: -1
36
+ 2025-01-14 11:29:11,265 configuration_internvl_chat.py[line:69] INFO || ps_version: v2
37
+ 2025-01-14 11:29:11,265 configuration_internvl_chat.py[line:70] INFO || min_dynamic_patch: 1
38
+ 2025-01-14 11:29:11,265 configuration_internvl_chat.py[line:71] INFO || max_dynamic_patch: 12
39
+ 2025-01-14 11:29:11,291 configuration_internvl_chat.py[line:68] INFO || vision_select_layer: -1
40
+ 2025-01-14 11:29:11,291 configuration_internvl_chat.py[line:69] INFO || ps_version: v2
41
+ 2025-01-14 11:29:11,291 configuration_internvl_chat.py[line:70] INFO || min_dynamic_patch: 1
42
+ 2025-01-14 11:29:11,291 configuration_internvl_chat.py[line:71] INFO || max_dynamic_patch: 12
43
+ 2025-01-14 11:29:18,842 _client.py[line:1025] INFO || HTTP Request: GET https://checkip.amazonaws.com/ "HTTP/1.1 200 "
44
+ 2025-01-14 11:29:19,239 _client.py[line:1025] INFO || HTTP Request: GET https://api.gradio.app/pkg-version "HTTP/1.1 200 OK"
45
+ 2025-01-14 11:29:19,866 _client.py[line:1025] INFO || HTTP Request: GET http://127.0.0.1:1096/startup-events "HTTP/1.1 200 OK"
46
+ 2025-01-14 11:29:19,979 _client.py[line:1025] INFO || HTTP Request: HEAD http://127.0.0.1:1096/ "HTTP/1.1 200 OK"
47
+ 2025-01-14 11:29:20,806 _client.py[line:1025] INFO || HTTP Request: GET https://api.gradio.app/v2/tunnel-request "HTTP/1.1 200 OK"
outputs/uploaded/0.jpg ADDED
process_food.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import json

# Location of sivqa_tidy.json
input_path = "/root/huggingface/FoodieQA/FoodieQA/sivqa_tidy.json"
# Where the converted file is written
output_path = "/root/huggingface/FoodieQA/FoodieQA/sivqa_llava.json"

with open(input_path, 'r', encoding='utf-8') as f:
    foodqa = json.load(f)

# Convert every FoodieQA record into the LLaVA conversation format:
# one "human" turn (question plus an image placeholder token) and one
# "gpt" turn (the correct choice followed by the dish name).
llava_format = [
    {
        "image": record['food_meta']['food_file'],
        "conversations": [
            {
                "from": "human",
                "value": record['question'] + "\n<image>"
            },
            {
                "from": "gpt",
                "value": record['choices'][int(record['answer'])] + ",图中的菜是" + record['food_meta']['food_name']
            }
        ]
    }
    for record in foodqa
]

with open(output_path, 'w', encoding='utf-8') as f:
    json.dump(llava_format, f, indent=4, ensure_ascii=False)
requirements.txt ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ accelerate==0.33.0
2
+ addict==2.4.0
3
+ aiofiles==23.2.1
4
+ aiohttp==3.9.5
5
+ aiosignal==1.3.1
6
+ altair==5.3.0
7
+ annotated-types==0.7.0
8
+ anyio==4.4.0
9
+ asttokens==2.4.1
10
+ async-timeout==4.0.3
11
+ attrs==23.2.0
12
+ bitsandbytes==0.43.1
13
+ blinker==1.8.2
14
+ cachetools==5.4.0
15
+ click==8.1.7
16
+ comm==0.2.2
17
+ contourpy==1.2.1
18
+ cycler==0.12.1
19
+ datasets==2.19.2
20
+ debugpy==1.8.1
21
+ decorator==5.1.1
22
+ dill==0.3.8
23
+ distro==1.9.0
24
+ dnspython==2.6.1
25
+ einops==0.6.1
26
+ einops-exts==0.0.4
27
+ email_validator==2.1.1
28
+ exceptiongroup==1.2.1
29
+ executing==2.0.1
30
+ fastapi==0.115.6
31
+ fastapi-cli==0.0.4
32
+ ffmpy==0.3.2
33
+ fire==0.6.0
34
+ fonttools==4.53.0
35
+ frozenlist==1.4.1
36
+ fsspec==2024.3.1
37
+ gitdb==4.0.11
38
+ GitPython==3.1.43
39
+ gradio==4.44.1
40
+ gradio_client==1.3.0
41
+ grpcio==1.64.1
42
+ h11==0.14.0
43
+ httpcore==1.0.7
44
+ httptools==0.6.1
45
+ httpx==0.28.1
46
+ huggingface-hub==0.26.5
47
+ importlib_metadata==7.1.0
48
+ importlib_resources==6.4.0
49
+ ipykernel==6.29.4
50
+ ipython==8.25.0
51
+ jedi==0.19.1
52
+ jiter==0.5.0
53
+ joblib==1.4.2
54
+ jsonschema==4.22.0
55
+ jsonschema-specifications==2023.12.1
56
+ jupyter_client==8.6.2
57
+ jupyter_core==5.7.2
58
+ kiwisolver==1.4.5
59
+ lmdeploy==0.5.3
60
+ markdown-it-py==3.0.0
61
+ markdown2==2.4.13
62
+ matplotlib==3.9.0
63
+ matplotlib-inline==0.1.7
64
+ mdurl==0.1.2
65
+ mkl-service==2.4.0
66
+ mmengine-lite==0.10.4
67
+ multidict==6.0.5
68
+ multiprocess==0.70.16
69
+ nest-asyncio==1.6.0
70
+ nvidia-cublas-cu12==12.5.2.13
71
+ nvidia-cuda-runtime-cu12==12.5.39
72
+ nvidia-curand-cu12==10.3.6.39
73
+ nvidia-nccl-cu12==2.21.5
74
+ openai==1.58.1
75
+ orjson==3.10.5
76
+ packaging==24.1
77
+ pandas==2.2.2
78
+ parso==0.8.4
79
+ peft==0.9.0
80
+ pexpect==4.9.0
81
+ platformdirs==4.2.2
82
+ prompt_toolkit==3.0.47
83
+ protobuf==4.25.3
84
+ psutil==5.9.8
85
+ ptyprocess==0.7.0
86
+ pure-eval==0.2.2
87
+ pyarrow==16.1.0
88
+ pyarrow-hotfix==0.6
89
+ pybind11==2.12.0
90
+ pydantic==2.7.4
91
+ pydantic_core==2.18.4
92
+ pydeck==0.9.1
93
+ pydub==0.25.1
94
+ Pygments==2.18.0
95
+ pynvml==11.5.0
96
+ pyparsing==3.1.2
97
+ python-dateutil==2.9.0.post0
98
+ python-dotenv==1.0.1
99
+ python-multipart==0.0.19
100
+ python-rapidjson==1.17
101
+ pytz==2024.1
102
+ pyzmq==26.0.3
103
+ referencing==0.35.1
104
+ regex==2024.5.15
105
+ requests==2.32.3
106
+ rich==13.7.1
107
+ rpds-py==0.18.1
108
+ ruff==0.4.9
109
+ safehttpx==0.1.6
110
+ safetensors==0.4.3
111
+ scikit-learn==1.2.2
112
+ scipy==1.13.1
113
+ semantic-version==2.10.0
114
+ sentencepiece==0.1.99
115
+ shellingham==1.5.4
116
+ shortuuid==1.0.13
117
+ six==1.16.0
118
+ smmap==5.0.1
119
+ sniffio==1.3.1
120
+ stack-data==0.6.3
121
+ starlette==0.41.3
122
+ streamlit==1.37.0
123
+ svgwrite==1.4.3
124
+ tenacity==8.5.0
125
+ termcolor==2.4.0
126
+ threadpoolctl==3.5.0
127
+ tiktoken==0.7.0
128
+ timm==1.0.8
129
+ tokenizers==0.15.1
130
+ toml==0.10.2
131
+ tomli==2.0.1
132
+ tomlkit==0.12.0
133
+ toolz==0.12.1
134
+ torch==2.1.2
135
+ torchaudio==2.1.2
136
+ torchvision==0.16.2
137
+ tornado==6.4.1
138
+ tqdm==4.66.4
139
+ traitlets==5.14.3
140
+ transformers==4.39.3
141
+ transformers-stream-generator==0.0.5
142
+ triton==2.1.0
143
+ tritonclient==2.46.0
144
+ typer==0.12.3
145
+ typing_extensions==4.12.2
146
+ tzdata==2024.1
147
+ ujson==5.10.0
148
+ uvicorn==0.30.1
149
+ uvloop==0.19.0
150
+ watchdog==4.0.1
151
+ watchfiles==0.22.0
152
+ wavedrom==2.0.3.post3
153
+ wcwidth==0.2.13
154
+ websockets==11.0.3
155
+ xxhash==3.4.1
156
+ yapf==0.40.2
157
+ yarl==1.9.4
158
+ zipp==3.19.2
utils.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import logging
4
+ from datetime import datetime
5
+
6
+
7
def load_json(file_name: str):
    """Load and return the contents of a JSON file.

    Args:
        file_name: Path to a ``.json`` file.

    Returns:
        The deserialized JSON content (dict/list/scalar).

    Raises:
        ValueError: If ``file_name`` is not a string ending in ``.json``.
    """
    # ".json" (with the dot) rather than "json": the original suffix test
    # accepted any path merely ending in the letters "json" (e.g. "datajson").
    if isinstance(file_name, str) and file_name.endswith(".json"):
        # Explicit encoding so parsing does not depend on the platform default.
        with open(file_name, 'r', encoding='utf-8') as file:
            data = json.load(file)
    else:
        raise ValueError("The file path you passed in is not a json file path.")

    return data
15
+
16
def init_logger(outputs_dir):
    """Configure root logging to stream to stderr and to a timestamped file.

    A ``logs`` sub-directory is created under ``outputs_dir`` if missing; the
    log file is named after the current time, e.g. ``Jan14_09-50-48.txt``.
    """
    log_dir = os.path.join(outputs_dir, "logs")
    os.makedirs(log_dir, exist_ok=True)
    stamp = datetime.now().strftime("%b%d_%H-%M-%S")
    log_file = os.path.join(log_dir, f"{stamp}.txt")
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s || %(message)s",
        handlers=[logging.StreamHandler(), logging.FileHandler(log_file)],
    )
25
+
xtuner_config/.ipynb_checkpoints/internvl_v2_internlm2_2b_lora_finetune_food-checkpoint.py ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
3
+ LoggerHook, ParamSchedulerHook)
4
+ from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
5
+ from peft import LoraConfig
6
+ from torch.optim import AdamW
7
+ from transformers import AutoTokenizer
8
+
9
+ from xtuner.dataset import InternVL_V1_5_Dataset
10
+ from xtuner.dataset.collate_fns import default_collate_fn
11
+ from xtuner.dataset.samplers import LengthGroupedSampler
12
+ from xtuner.engine.hooks import DatasetInfoHook
13
+ from xtuner.engine.runner import TrainLoop
14
+ from xtuner.model import InternVL_V1_5
15
+ from xtuner.utils import PROMPT_TEMPLATE
16
+
17
#######################################################################
#                          PART 1  Settings                           #
#######################################################################
# Model: base InternVL2-2B checkpoint to fine-tune with LoRA.
path = '/root/share/new_models/OpenGVLab/InternVL2-2B'

# Data
data_root = '/root/share/datasets/FoodieQA/'  # your data path
data_path = data_root + 'sivqa_llava.json'  # LLaVA-format QA pairs (matches process_food.py output name)
image_folder = data_root  # your image folder path
prompt_template = PROMPT_TEMPLATE.internlm2_chat
max_length = 8192  # maximum token length per training sample

# Scheduler & Optimizer
batch_size = 4  # per_device
accumulative_counts = 2  # gradient accumulation steps; effective per-device batch = batch_size * accumulative_counts
dataloader_num_workers = 4
max_epochs = 10
optim_type = AdamW
# official 1024 -> 4e-5
# lr = 1e-6
lr = 3e-5
betas = (0.9, 0.999)
weight_decay = 0.05
max_norm = 1  # grad clip
warmup_ratio = 0.03  # fraction of training (in epochs) spent in linear warmup

# Save
save_steps = 64
save_total_limit = -1  # Maximum checkpoints to keep (-1 means unlimited)

#######################################################################
#            PART 2  Model & Tokenizer & Image Processor              #
#######################################################################
# Both the LLM and the vision tower are frozen; only the LoRA adapter on
# the LLM is trained.
model = dict(
    type=InternVL_V1_5,
    model_path=path,
    freeze_llm=True,
    freeze_visual_encoder=True,
    # comment the following lines if you don't want to use LoRA in the llm
    llm_lora=dict(
        type=LoraConfig,
        r=128,
        lora_alpha=256,
        lora_dropout=0.05,
        target_modules=None,
        task_type='CAUSAL_LM'),
    # uncomment the following lines if you also want LoRA on the visual encoder # noqa
    # visual_encoder_lora=dict(
    #     type=LoraConfig, r=64, lora_alpha=16, lora_dropout=0.05,
    #     target_modules=['attn.qkv', 'attn.proj', 'mlp.fc1', 'mlp.fc2'])
)

#######################################################################
#                      PART 3  Dataset & Dataloader                   #
#######################################################################
llava_dataset = dict(
    type=InternVL_V1_5_Dataset,
    model_path=path,
    data_paths=data_path,
    image_folders=image_folder,
    template=prompt_template,
    max_length=max_length)

train_dataloader = dict(
    batch_size=batch_size,
    num_workers=dataloader_num_workers,
    dataset=llava_dataset,
    # Group samples of similar length to reduce padding waste per batch.
    sampler=dict(
        type=LengthGroupedSampler,
        length_property='modality_length',
        per_device_batch_size=batch_size * accumulative_counts),
    collate_fn=dict(type=default_collate_fn))

#######################################################################
#                    PART 4  Scheduler & Optimizer                    #
#######################################################################
# optimizer (mixed precision with dynamic loss scaling)
optim_wrapper = dict(
    type=AmpOptimWrapper,
    optimizer=dict(
        type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
    clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
    accumulative_counts=accumulative_counts,
    loss_scale='dynamic',
    dtype='float16')

# learning policy
# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501
param_scheduler = [
    # Linear warmup for the first `warmup_ratio * max_epochs` epochs.
    dict(
        type=LinearLR,
        start_factor=1e-5,
        by_epoch=True,
        begin=0,
        end=warmup_ratio * max_epochs,
        convert_to_iter_based=True),
    # Cosine decay towards 0 for the remainder of training.
    dict(
        type=CosineAnnealingLR,
        eta_min=0.0,
        by_epoch=True,
        begin=warmup_ratio * max_epochs,
        end=max_epochs,
        convert_to_iter_based=True)
]

# train, val, test setting
train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)

#######################################################################
#                           PART 5  Runtime                           #
#######################################################################
# Log the dialogue periodically during the training process, optional
tokenizer = dict(
    type=AutoTokenizer.from_pretrained,
    pretrained_model_name_or_path=path,
    trust_remote_code=True)

custom_hooks = [
    dict(type=DatasetInfoHook, tokenizer=tokenizer),
]

# configure default hooks
default_hooks = dict(
    # record the time of every iteration.
    timer=dict(type=IterTimerHook),
    # print log every 10 iterations.
    logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
    # enable the parameter scheduler.
    param_scheduler=dict(type=ParamSchedulerHook),
    # save checkpoint per `save_steps`.
    checkpoint=dict(
        type=CheckpointHook,
        save_optimizer=False,
        by_epoch=False,
        interval=save_steps,
        max_keep_ckpts=save_total_limit),
    # set sampler seed in distributed environment.
    sampler_seed=dict(type=DistSamplerSeedHook),
)

# configure environment
env_cfg = dict(
    # whether to enable cudnn benchmark
    cudnn_benchmark=False,
    # set multi process parameters
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
    # set distributed parameters
    dist_cfg=dict(backend='nccl'),
)

# set visualizer
visualizer = None

# set log level
log_level = 'INFO'

# load from which checkpoint
load_from = None

# whether to resume training from the loaded checkpoint
resume = False

# Defaults to use random seed and disable `deterministic`
randomness = dict(seed=None, deterministic=False)

# set log processor
log_processor = dict(by_epoch=False)
xtuner_config/internvl_v2_internlm2_2b_lora_finetune_food.py ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# Copyright (c) OpenMMLab. All rights reserved.
"""xtuner config: LoRA fine-tuning of InternVL2-2B on the FoodieQA dataset.

The visual encoder and the base LLM are frozen; only a LoRA adapter on the
LLM is trained.  Training runs for `max_epochs` epochs with fp16 AMP and
gradient accumulation.
"""
from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
                            LoggerHook, ParamSchedulerHook)
from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
from peft import LoraConfig
from torch.optim import AdamW
from transformers import AutoTokenizer

from xtuner.dataset import InternVL_V1_5_Dataset
from xtuner.dataset.collate_fns import default_collate_fn
from xtuner.dataset.samplers import LengthGroupedSampler
from xtuner.engine.hooks import DatasetInfoHook
from xtuner.engine.runner import TrainLoop
from xtuner.model import InternVL_V1_5
from xtuner.utils import PROMPT_TEMPLATE

#######################################################################
#                          PART 1  Settings                           #
#######################################################################
# Model
path = '/root/share/new_models/OpenGVLab/InternVL2-2B'

# Data
data_root = '/root/share/datasets/FoodieQA/'  # your data path
data_path = data_root + 'sivqa_llava.json'
image_folder = data_root  # your image folder path
prompt_template = PROMPT_TEMPLATE.internlm2_chat
max_length = 8192

# Scheduler & Optimizer
batch_size = 4  # per_device
accumulative_counts = 2
dataloader_num_workers = 4
max_epochs = 10
optim_type = AdamW
# official 1024 -> 4e-5
lr = 3e-5
betas = (0.9, 0.999)
weight_decay = 0.05
max_norm = 1  # grad clip
warmup_ratio = 0.03

# Save
save_steps = 64
save_total_limit = -1  # Maximum checkpoints to keep (-1 means unlimited)

#######################################################################
#            PART 2  Model & Tokenizer & Image Processor              #
#######################################################################
model = dict(
    type=InternVL_V1_5,
    model_path=path,
    freeze_llm=True,
    freeze_visual_encoder=True,
    # comment the following lines if you don't want to use Lora in llm
    llm_lora=dict(
        type=LoraConfig,
        r=128,
        lora_alpha=256,
        lora_dropout=0.05,
        target_modules=None,
        task_type='CAUSAL_LM'),
    # uncomment the following lines if you don't want to use Lora in visual encoder # noqa
    # visual_encoder_lora=dict(
    #     type=LoraConfig, r=64, lora_alpha=16, lora_dropout=0.05,
    #     target_modules=['attn.qkv', 'attn.proj', 'mlp.fc1', 'mlp.fc2'])
)

#######################################################################
#                      PART 3  Dataset & Dataloader                   #
#######################################################################
llava_dataset = dict(
    type=InternVL_V1_5_Dataset,
    model_path=path,
    data_paths=data_path,
    image_folders=image_folder,
    template=prompt_template,
    max_length=max_length)

train_dataloader = dict(
    batch_size=batch_size,
    num_workers=dataloader_num_workers,
    dataset=llava_dataset,
    # Group samples of similar length into the same (accumulated) batch to
    # reduce padding waste.
    sampler=dict(
        type=LengthGroupedSampler,
        length_property='modality_length',
        per_device_batch_size=batch_size * accumulative_counts),
    collate_fn=dict(type=default_collate_fn))

#######################################################################
#                    PART 4  Scheduler & Optimizer                    #
#######################################################################
# optimizer: AdamW wrapped in fp16 AMP with dynamic loss scaling and
# gradient accumulation of `accumulative_counts` steps.
optim_wrapper = dict(
    type=AmpOptimWrapper,
    optimizer=dict(
        type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
    clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
    accumulative_counts=accumulative_counts,
    loss_scale='dynamic',
    dtype='float16')

# learning policy: linear warmup for the first `warmup_ratio` of training,
# then cosine annealing to zero.
# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501
param_scheduler = [
    dict(
        type=LinearLR,
        start_factor=1e-5,
        by_epoch=True,
        begin=0,
        end=warmup_ratio * max_epochs,
        convert_to_iter_based=True),
    dict(
        type=CosineAnnealingLR,
        eta_min=0.0,
        by_epoch=True,
        begin=warmup_ratio * max_epochs,
        end=max_epochs,
        convert_to_iter_based=True)
]

# train, val, test setting
train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)

#######################################################################
#                           PART 5  Runtime                           #
#######################################################################
# Tokenizer is only needed so DatasetInfoHook can decode and log sample
# dialogues periodically during training; it is optional for training itself.
tokenizer = dict(
    type=AutoTokenizer.from_pretrained,
    pretrained_model_name_or_path=path,
    trust_remote_code=True)

custom_hooks = [
    dict(type=DatasetInfoHook, tokenizer=tokenizer),
]

# configure default hooks
default_hooks = dict(
    # record the time of every iteration.
    timer=dict(type=IterTimerHook),
    # print log every 10 iterations.
    logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
    # enable the parameter scheduler.
    param_scheduler=dict(type=ParamSchedulerHook),
    # save checkpoint per `save_steps`.
    checkpoint=dict(
        type=CheckpointHook,
        save_optimizer=False,
        by_epoch=False,
        interval=save_steps,
        max_keep_ckpts=save_total_limit),
    # set sampler seed in distributed environment.
    sampler_seed=dict(type=DistSamplerSeedHook),
)

# configure environment
env_cfg = dict(
    # whether to enable cudnn benchmark
    cudnn_benchmark=False,
    # set multi process parameters
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
    # set distributed parameters
    dist_cfg=dict(backend='nccl'),
)

# set visualizer
visualizer = None

# set log level
log_level = 'INFO'

# load from which checkpoint
load_from = None

# whether to resume training from the loaded checkpoint
resume = False

# Defaults to use random seed and disable `deterministic`
randomness = dict(seed=None, deterministic=False)

# set log processor
log_processor = dict(by_epoch=False)