arad1367 commited on
Commit
3e5dff9
β€’
1 Parent(s): 5ba3d26

Upload 8 files

Browse files
app.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import AutoProcessor, Idefics3ForConditionalGeneration
3
+ import re
4
+ import time
5
+ from PIL import Image
6
+ import torch
7
+ import spaces
8
+ import subprocess
9
+ subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
10
+
11
+ processor = AutoProcessor.from_pretrained("HuggingFaceM4/Idefics3-8B-Llama3")
12
+ model = Idefics3ForConditionalGeneration.from_pretrained("HuggingFaceM4/Idefics3-8B-Llama3",
13
+ torch_dtype=torch.bfloat16,
14
+ trust_remote_code=True).to("cuda")
15
+
16
+ BAD_WORDS_IDS = processor.tokenizer(["<image>", "<fake_token_around_image>"], add_special_tokens=False).input_ids
17
+ EOS_WORDS_IDS = [processor.tokenizer.eos_token_id]
18
+
19
+ @spaces.GPU
20
+ def model_inference(
21
+ images, text, assistant_prefix, decoding_strategy, temperature, max_new_tokens,
22
+ repetition_penalty, top_p
23
+ ):
24
+ if text == "" and not images:
25
+ gr.Error("Please input a query and optionally image(s).")
26
+
27
+ if text == "" and images:
28
+ gr.Error("Please input a text query along the image(s).")
29
+
30
+ # Check if the input is related to marketing
31
+ marketing_keywords = ["product", "brand", "advertisement", "marketing", "strategy", "comparison", "analysis", "trend", "audience"]
32
+ if not any(keyword in text.lower() for keyword in marketing_keywords):
33
+ return "Your question is not in the marketing area. Please upload another image or try again."
34
+
35
+ if isinstance(images, Image.Image):
36
+ images = [images]
37
+
38
+ resulting_messages = [
39
+ {
40
+ "role": "user",
41
+ "content": [{"type": "image"}] + [
42
+ {"type": "text", "text": text}
43
+ ]
44
+ }
45
+ ]
46
+
47
+ if assistant_prefix:
48
+ text = f"{assistant_prefix} {text}"
49
+
50
+ prompt = processor.apply_chat_template(resulting_messages, add_generation_prompt=True)
51
+ inputs = processor(text=prompt, images=[images], return_tensors="pt")
52
+ inputs = {k: v.to("cuda") for k, v in inputs.items()}
53
+
54
+ generation_args = {
55
+ "max_new_tokens": max_new_tokens,
56
+ "repetition_penalty": repetition_penalty,
57
+ }
58
+
59
+ assert decoding_strategy in [
60
+ "Greedy",
61
+ "Top P Sampling",
62
+ ]
63
+ if decoding_strategy == "Greedy":
64
+ generation_args["do_sample"] = False
65
+ elif decoding_strategy == "Top P Sampling":
66
+ generation_args["temperature"] = temperature
67
+ generation_args["do_sample"] = True
68
+ generation_args["top_p"] = top_p
69
+
70
+ generation_args.update(inputs)
71
+
72
+ # Generate
73
+ generated_ids = model.generate(**generation_args)
74
+
75
+ generated_texts = processor.batch_decode(generated_ids[:, generation_args["input_ids"].size(1):], skip_special_tokens=True)
76
+ return generated_texts[0]
77
+
78
+ with gr.Blocks(fill_height=True) as demo:
79
+ gr.Markdown("## Marketing Vision App πŸ“ˆ")
80
+ gr.Markdown("This app uses the [HuggingFaceM4/Idefics3-8B-Llama3](https://huggingface.co/HuggingFaceM4/Idefics3-8B-Llama3) model to answer questions related to marketing. Upload an image and a text query, or try one of the examples.")
81
+ gr.Markdown("**Disclaimer:** This app may not consistently follow prompts or handle complex tasks. However, adding a prefix to the assistant's response can significantly improve the output. You could also play with the parameters such as the temperature in non-greedy mode.")
82
+ with gr.Column():
83
+ image_input = gr.Image(label="Upload your Image", type="pil", scale=1)
84
+ query_input = gr.Textbox(label="Prompt")
85
+ assistant_prefix = gr.Textbox(label="Assistant Prefix", placeholder="Let's think step by step.")
86
+
87
+ submit_btn = gr.Button("Submit")
88
+ output = gr.Textbox(label="Output")
89
+
90
+ with gr.Accordion(label="Example Inputs and Advanced Generation Parameters"):
91
+ examples=[
92
+ ["example_images/iphone_vs_samsung.jpg", "I want to buy a smartphone with features similar to the one in the photo. Suggest models and provide a comparison.", None, "Product Recommendation", 0.5, 768, 1.1, 0.85],
93
+ ["example_images/ad_analysis.png", "Analyze this advertisement and explain its marketing strategy.", None, "Advertisement Analysis", 0.5, 768, 1.1, 0.85],
94
+ ["example_images/market_trends.png", "Analyze this marketing chart and explain the market trends it represents.", None, "Market Trend Analysis", 0.5, 768, 1.1, 0.85],
95
+ ["example_images/social_media_post.png", "Critique this social media post about product quality and suggest improvements.", None, "Social Media Post Analysis", 0.5, 768, 1.1, 0.85],
96
+ ["example_images/brand_comparison.jpg", "Compare these three brands based on their price strategies.", None, "Brand Comparison", 0.5, 768, 1.1, 0.85],
97
+ ["example_images/target_audience.jpg", "Analyze this image and suggest a target audience for the product.", None, "Target Audience Analysis", 0.5, 768, 1.1, 0.85]
98
+ ]
99
+
100
+ # Hyper-parameters for generation
101
+ max_new_tokens = gr.Slider(
102
+ minimum=8,
103
+ maximum=1024,
104
+ value=512,
105
+ step=1,
106
+ interactive=True,
107
+ label="Maximum number of new tokens to generate",
108
+ )
109
+ repetition_penalty = gr.Slider(
110
+ minimum=0.01,
111
+ maximum=5.0,
112
+ value=1.2,
113
+ step=0.01,
114
+ interactive=True,
115
+ label="Repetition penalty",
116
+ info="1.0 is equivalent to no penalty",
117
+ )
118
+ temperature = gr.Slider(
119
+ minimum=0.0,
120
+ maximum=5.0,
121
+ value=0.4,
122
+ step=0.1,
123
+ interactive=True,
124
+ label="Sampling temperature",
125
+ info="Higher values will produce more diverse outputs.",
126
+ )
127
+ top_p = gr.Slider(
128
+ minimum=0.01,
129
+ maximum=0.99,
130
+ value=0.8,
131
+ step=0.01,
132
+ interactive=True,
133
+ label="Top P",
134
+ info="Higher values is equivalent to sampling more low-probability tokens.",
135
+ )
136
+ decoding_strategy = gr.Radio(
137
+ [
138
+ "Greedy",
139
+ "Top P Sampling",
140
+ ],
141
+ value="Greedy",
142
+ label="Decoding strategy",
143
+ interactive=True,
144
+ info="Higher values is equivalent to sampling more low-probability tokens.",
145
+ )
146
+ decoding_strategy.change(
147
+ fn=lambda selection: gr.Slider(
148
+ visible=(
149
+ selection in ["contrastive_sampling", "beam_sampling", "Top P Sampling", "sampling_top_k"]
150
+ )
151
+ ),
152
+ inputs=decoding_strategy,
153
+ outputs=temperature,
154
+ )
155
+
156
+ decoding_strategy.change(
157
+ fn=lambda selection: gr.Slider(
158
+ visible=(
159
+ selection in ["contrastive_sampling", "beam_sampling", "Top P Sampling", "sampling_top_k"]
160
+ )
161
+ ),
162
+ inputs=decoding_strategy,
163
+ outputs=repetition_penalty,
164
+ )
165
+ decoding_strategy.change(
166
+ fn=lambda selection: gr.Slider(visible=(selection in ["Top P Sampling"])),
167
+ inputs=decoding_strategy,
168
+ outputs=top_p,
169
+ )
170
+ gr.Examples(
171
+ examples = examples,
172
+ inputs=[image_input, query_input, assistant_prefix, decoding_strategy, temperature,
173
+ max_new_tokens, repetition_penalty, top_p],
174
+ outputs=output,
175
+ fn=model_inference
176
+ )
177
+
178
+ submit_btn.click(model_inference, inputs = [image_input, query_input, assistant_prefix, decoding_strategy, temperature,
179
+ max_new_tokens, repetition_penalty, top_p], outputs=output)
180
+
181
+ demo.launch(debug=True)
example_images/ad_analysis.png ADDED
example_images/brand_comparison.jpg ADDED
example_images/iphone_vs_samsung.jpg ADDED
example_images/market_trends.png ADDED
example_images/social_media_post.png ADDED
example_images/target_audience.jpg ADDED
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ torch
2
+ accelerate
3
+ huggingface_hub
4
+ gradio
5
+ git+https://github.com/andimarafioti/transformers.git@idefics3
6
+ spaces