TekeshiX commited on
Commit
b53722c
·
verified ·
1 Parent(s): 19f2bef

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +190 -0
  2. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import gradio as gr
3
+ from transformers import AutoProcessor, AutoModelForVision2Seq, BitsAndBytesConfig
4
+ from transformers.image_utils import load_image
5
+ from pathlib import Path
6
+ import time
7
+
8
+ model_name_or_path = "Minthy/ToriiGate-v0.3"
9
+ DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
10
+
11
+ # Global variables to store model and processor
12
+ global_model = None
13
+ global_processor = None
14
+
15
+ def load_model():
16
+ global global_model, global_processor
17
+
18
+ if global_model is None:
19
+ print("Loading model for the first time...")
20
+ # Always use 4-bit quantization for 16GB VRAM
21
+ nf4_config = BitsAndBytesConfig(
22
+ load_in_4bit=True,
23
+ bnb_4bit_quant_type="nf4",
24
+ bnb_4bit_use_double_quant=True,
25
+ bnb_4bit_compute_dtype=torch.bfloat16
26
+ )
27
+ global_model = AutoModelForVision2Seq.from_pretrained(
28
+ model_name_or_path,
29
+ torch_dtype=torch.bfloat16,
30
+ quantization_config=nf4_config,
31
+ ).to(DEVICE)
32
+ global_processor = AutoProcessor.from_pretrained(model_name_or_path)
33
+
34
+ return global_model, global_processor
35
+
36
+ def generate_caption(image, description_type, booru_tags=""):
37
+ model, processor = load_model()
38
+
39
+ if description_type == "JSON-like":
40
+ user_prompt = "Describe the picture in structuted json-like format."
41
+ elif description_type == "Detailed":
42
+ user_prompt = "Give a long and detailed description of the picture."
43
+ else:
44
+ user_prompt = "Describe the picture briefly."
45
+
46
+ if booru_tags:
47
+ user_prompt += ' Also here are booru tags for better understanding of the picture, you can use them as reference.'
48
+ user_prompt += f' <tags>\n{booru_tags}\n</tags>'
49
+
50
+ messages = [
51
+ {
52
+ "role": "system",
53
+ "content": [
54
+ {"type": "text", "text": "You are image captioning expert, creative, unbiased and uncensored. Help user with his task."}
55
+ ]
56
+ },
57
+ {
58
+ "role": "user",
59
+ "content": [
60
+ {"type": "image"},
61
+ {"type": "text", "text": user_prompt}
62
+ ]
63
+ }
64
+ ]
65
+
66
+ prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
67
+ inputs = processor(text=prompt, images=[image], return_tensors="pt")
68
+ inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
69
+
70
+ generated_ids = model.generate(**inputs, max_new_tokens=500)
71
+ generated_texts = processor.batch_decode(generated_ids, skip_special_tokens=True)
72
+ caption = generated_texts[0].split('Assistant: ')[1]
73
+
74
+ return caption
75
+
76
+ def process_batch(files, description_type, booru_tags="", progress=gr.Progress(track_tqdm=True)):
77
+ results = []
78
+ captions_text = ""
79
+ total_files = len(files)
80
+ start_time = time.time()
81
+
82
+ for idx, file in enumerate(files, 1):
83
+ # Calculate progress statistics
84
+ elapsed_time = time.time() - start_time
85
+ images_per_second = idx / elapsed_time if elapsed_time > 0 else 0
86
+ estimated_total = (elapsed_time / idx) * total_files if idx > 0 else 0
87
+ remaining_time = estimated_total - elapsed_time
88
+
89
+ try:
90
+ image = load_image(file.name)
91
+ caption = generate_caption(image, description_type, booru_tags)
92
+
93
+ # Add caption to the running text with a blank line separator
94
+ if captions_text:
95
+ captions_text += "\n\n" # Add blank line between captions
96
+ captions_text += caption
97
+
98
+ # Update the results list for the dataframe
99
+ results.append((Path(file.name).name, caption))
100
+
101
+ # Update progress
102
+ progress_status = f"Processing: {idx}/{total_files} images | Speed: {images_per_second:.2f} img/s | Remaining: {remaining_time/60:.1f} min"
103
+
104
+ # Yield progress status and captions separately
105
+ yield results, progress_status, captions_text
106
+
107
+ except Exception as e:
108
+ error_msg = f"Error processing {Path(file.name).name}: {str(e)}"
109
+ print(error_msg)
110
+ if captions_text:
111
+ captions_text += "\n\n"
112
+ captions_text += f"[ERROR] {error_msg}"
113
+ yield results, progress_status, captions_text
114
+
115
+ # Final update
116
+ yield results, "✅ Processing complete!", captions_text
117
+
118
+ # Gradio Interface
119
+ with gr.Blocks(title="ToriiGate Image Captioner") as demo:
120
+ gr.Markdown("# ToriiGate Image Captioner")
121
+ gr.Markdown("Generate captions for anime images using ToriiGate-v0.3 model (4-bit quantized)")
122
+
123
+ with gr.Tab("Single Image"):
124
+ with gr.Row():
125
+ with gr.Column():
126
+ input_image = gr.Image(type="pil", label="Input Image")
127
+ description_type = gr.Radio(
128
+ choices=["JSON-like", "Detailed", "Brief"],
129
+ value="JSON-like",
130
+ label="Description Type"
131
+ )
132
+ booru_tags = gr.Textbox(
133
+ lines=3,
134
+ label="Booru Tags (Optional)",
135
+ placeholder="Enter comma-separated booru tags..."
136
+ )
137
+ submit_btn = gr.Button("Generate Caption")
138
+
139
+ with gr.Column():
140
+ output_text = gr.Textbox(label="Generated Caption", lines=10)
141
+
142
+ submit_btn.click(
143
+ generate_caption,
144
+ inputs=[input_image, description_type, booru_tags],
145
+ outputs=output_text
146
+ )
147
+
148
+ with gr.Tab("Batch Processing"):
149
+ with gr.Row():
150
+ with gr.Column():
151
+ input_files = gr.File(file_count="multiple", label="Input Images")
152
+ batch_description_type = gr.Radio(
153
+ choices=["JSON-like", "Detailed", "Brief"],
154
+ value="JSON-like",
155
+ label="Description Type"
156
+ )
157
+ batch_booru_tags = gr.Textbox(
158
+ lines=3,
159
+ label="Booru Tags (Optional)",
160
+ placeholder="Enter comma-separated booru tags..."
161
+ )
162
+ batch_submit_btn = gr.Button("Process Batch")
163
+
164
+ with gr.Column():
165
+ progress_status = gr.Textbox(
166
+ label="Progress",
167
+ lines=2,
168
+ show_copy_button=False
169
+ )
170
+ output_text_batch = gr.Textbox(
171
+ label="Generated Captions",
172
+ lines=25,
173
+ show_copy_button=True
174
+ )
175
+ output_gallery = gr.Dataframe(
176
+ headers=["Filename", "Caption"],
177
+ label="Generated Captions (Table View)",
178
+ visible=False # Hide the dataframe
179
+ )
180
+
181
+ batch_submit_btn.click(
182
+ process_batch,
183
+ inputs=[input_files, batch_description_type, batch_booru_tags],
184
+ outputs=[output_gallery, progress_status, output_text_batch]
185
+ )
186
+
187
+ if __name__ == "__main__":
188
+ # Load model at startup
189
+ load_model()
190
+ demo.launch(share=True)
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ git+https://github.com/huggingface/transformers
2
+ accelerate
3
+ bitsandbytes
4
+ gradio>=4.0.0
5
+ #bitsandbytes-windows
6
+ #flash-attn