Update app.py
app.py
CHANGED
@@ -184,9 +184,17 @@ def generate_caption(text_model, tokenizer, image_features, prompt_str: str, max
 
     return tokenizer.batch_decode(generate_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False)[0].strip()
 
+# Add a dropdown menu for model selection
+model_selection = gr.Dropdown(
+    choices=["llama", "Qwen/Qwen2.5-7B-Instruct"],
+    label="Model Selection",
+    value="llama",
+)
+
+# Update the stream_chat function to accept the selected model
 @spaces.GPU()
 @torch.no_grad()
-def stream_chat(input_image: Image.Image, caption_type: str, caption_tone: str, caption_length: str | int, lens_type: str = "", film_stock: str = "", composition_style: str = "", lighting_aspect: str = "", special_technique: str = "", color_effect: str = "") -> str:
+def stream_chat(input_image: Image.Image, caption_type: str, caption_tone: str, caption_length: str | int, model_selection: str, lens_type: str = "", film_stock: str = "", composition_style: str = "", lighting_aspect: str = "", special_technique: str = "", color_effect: str = "") -> str:
     """
     Generate a caption or style prompt based on the input image and parameters.
     """
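Note: the new model_selection dropdown is constructed at module level, while the page layout is built further down inside `with gr.Blocks(theme="Hev832/Applio", css=css) as demo:`. In Gradio, a component created outside a Blocks context is not attached to any page; one way to wire it in (a minimal sketch against the existing layout, not the only option) is to render it inside that block:

    # Sketch: attach the module-level dropdown to the existing UI.
    # `demo`, `css`, and the theme come from app.py; .render() places a
    # component that was constructed outside the Blocks context.
    with gr.Blocks(theme="Hev832/Applio", css=css) as demo:
        model_selection.render()
        # ... rest of the existing layout ...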
@@ -210,13 +218,14 @@ def stream_chat(input_image: Image.Image, caption_type: str, caption_tone: str,
 
     if caption_type == "style_prompt":
         prompt_str += f" Lens type: {lens_type} ({lens_types_info[lens_type]}). "
-        prompt_str += f"Film stock: {film_stock} ({film_stocks_info[film_stock]}). "
-        prompt_str += f"Composition style: {composition_style} ({composition_styles_info[composition_style]}). "
-        prompt_str += f"Lighting aspect: {lighting_aspect} ({lighting_aspects_info[lighting_aspect]}). "
-        prompt_str += f"Special technique: {special_technique} ({special_techniques_info[special_technique]}). "
-        prompt_str += f"Color effect: {color_effect} ({color_effects_info[color_effect]}). "
 
-
+        prompt_str += f"Film stock: {film_stocks_info[film_stock]}). "
+        prompt_str += f"Composition style: {composition_styles_info[composition_style]}). "
+        prompt_str += f"Lighting aspect: {lighting_aspects_info[lighting_aspect]}). "
+        prompt_str += f"Special technique: {special_techniques_info[special_technique]}). "
+        prompt_str += f"Color effect: {color_effects_info[color_effect]})."
 
+    # Debugging: Print the constructed prompt string
+    print(f"Constructed Prompt: {prompt_str}")
 
     pixel_values = preprocess_image(input_image)
 
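Note: the replacement f-strings keep the closing `). ` from the old pattern but drop the selected value and the opening parenthesis, so the prompt renders as e.g. `Film stock: <description>). `. If the `value (description)` pattern from the lens-type line above is what's intended, the lines would look like this (a sketch assuming the *_info dicts map each option to its description string):

    # Sketch: mirror the lens-type line's "value (description)" pattern.
    prompt_str += f"Film stock: {film_stock} ({film_stocks_info[film_stock]}). "
    prompt_str += f"Composition style: {composition_style} ({composition_styles_info[composition_style]}). "
    prompt_str += f"Lighting aspect: {lighting_aspect} ({lighting_aspects_info[lighting_aspect]}). "
    prompt_str += f"Special technique: {special_technique} ({special_techniques_info[special_technique]}). "
    prompt_str += f"Color effect: {color_effect} ({color_effects_info[color_effect]})."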
@@ -226,6 +235,17 @@ def stream_chat(input_image: Image.Image, caption_type: str, caption_tone: str,
         embedded_images = image_adapter(image_features)
         embedded_images = embedded_images.to('cuda')
 
+    # Load the selected model
+    if model_selection == "llama":
+        text_model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, device_map="auto", torch_dtype=torch.bfloat16)
+    else:
+        text_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-7B-Instruct", device_map="auto", torch_dtype=torch.bfloat16)
+
+    text_model.eval()
+
+    # Debugging: Print the prompt string before passing to generate_caption
+    print(f"Prompt passed to generate_caption: {prompt_str}")
+
     caption = generate_caption(text_model, tokenizer, embedded_images, prompt_str)
 
     return caption
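Note: stream_chat now calls AutoModelForCausalLM.from_pretrained on every invocation, so each button click reloads a 7B checkpoint from disk (or re-downloads it on a fresh Space). A minimal sketch of caching the loaded models instead — get_text_model and _loaded_models are hypothetical helpers, and MODEL_PATH is the constant already used in app.py:

    import torch
    from transformers import AutoModelForCausalLM

    _loaded_models = {}  # hypothetical cache: checkpoint name -> loaded model

    def get_text_model(model_selection: str):
        # Resolve the dropdown choice to a checkpoint, loading it at most once.
        name = MODEL_PATH if model_selection == "llama" else "Qwen/Qwen2.5-7B-Instruct"
        if name not in _loaded_models:
            model = AutoModelForCausalLM.from_pretrained(name, device_map="auto", torch_dtype=torch.bfloat16)
            model.eval()
            _loaded_models[name] = model
        return _loaded_models[name]

stream_chat would then call text_model = get_text_model(model_selection) in place of the if/else block above.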
@@ -492,7 +512,7 @@ with gr.Blocks(theme="Hev832/Applio", css=css) as demo:
 
     caption_type.change(update_style_options, inputs=[caption_type], outputs=[lens_type, film_stock, composition_style, lighting_aspect, special_technique, color_effect])
 
-    run_button.click(fn=stream_chat, inputs=[input_image, caption_type, caption_tone, caption_length, lens_type, film_stock, composition_style, lighting_aspect, special_technique, color_effect], outputs=[output_caption])
+    run_button.click(fn=stream_chat, inputs=[input_image, caption_type, caption_tone, caption_length, model_selection, lens_type, film_stock, composition_style, lighting_aspect, special_technique, color_effect], outputs=[output_caption])
 
 
 if __name__ == "__main__":
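Note: Gradio passes the components listed in `inputs` to `fn` positionally, so model_selection correctly sits fifth in the click wiring to match its position in the new stream_chat signature; the two lists must stay in sync if further parameters are added.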