Demo750 commited on
Commit
497c61f
1 Parent(s): d4ce556

Update Webpage.py

Browse files
Files changed (1) hide show
  1. Webpage.py +69 -28
Webpage.py CHANGED
@@ -6,6 +6,9 @@ import numpy as np
6
  import cv2 as cv
7
  import torch
8
  from PIL import Image
 
 
 
9
 
10
  GENERAL_CATEGORY = {'Potatoes / Vegetables / Fruit': 0, 'Chemical products': 1, 'Photo / Film / Optical items': 2, 'Catering industry': 3, 'Industrial products other': 4, 'Media': 5, 'Real estate': 6, 'Government': 7, 'Personnel advertisements': 8, 'Cars / Commercial vehicles': 9, 'Cleaning products': 10, 'Retail': 11, 'Fragrances': 12, 'Footwear / Leather goods': 13, 'Software / Automation': 14, 'Telecommunication equipment': 15, 'Tourism': 16, 'Transport/Communication companies': 17, 'Transport services': 18, 'Insurances': 19, 'Meat / Fish / Poultry': 20, 'Detergents': 21, 'Foods General': 22, 'Other services': 23, 'Banks and Financial Services': 24, 'Office Products': 25, 'Household Items': 26, 'Non-alcoholic beverages': 27, 'Hair, Oral and Personal Care': 28, 'Fashion and Clothing': 29, 'Other products and Services': 30, 'Paper products': 31, 'Alcohol and Other Stimulants': 32, 'Medicines': 33, 'Recreation and Leisure': 34, 'Electronics': 35, 'Home Furnishings': 36, 'Products for Business Use': 37}
11
  CATEGORIES = list(GENERAL_CATEGORY.keys())
@@ -44,10 +47,13 @@ def calculate_areas(prompts, brand_num, pictorial_num, text_num):
44
  left_margin = x1; right_margin = w-x2
45
  if left_margin <=100 and right_margin <= 100:
46
  upper_margin = y1; lower_margin = h-y2
47
- if upper_margin >= lower_margin:
48
- context_image = image_entire[:int(y1), :, :]
49
  else:
50
- context_image = image_entire[int(y2):, :, :]
 
 
 
51
  else:
52
  if left_margin >= right_margin:
53
  context_image = image_entire[:, :int(x1), :]
@@ -59,8 +65,14 @@ def calculate_areas(prompts, brand_num, pictorial_num, text_num):
59
 
60
  return (brand_surf/whole_size*100, pictorial_surf/whole_size*100, text_surf/whole_size*100, ad_size/whole_size*100, ad_image, context_image)
61
 
 
 
 
 
 
62
 
63
- def attention(notes, download1, download2, whole_display_prompt,
 
64
  brand_num, pictorial_num, text_num,
65
  category, ad_location, gaze_type):
66
  text_detection_model_path = 'EAST-Text-Detection/frozen_east_text_detection.pb'
@@ -82,7 +94,10 @@ def attention(notes, download1, download2, whole_display_prompt,
82
  surfaces = [brand_percent, visual_percent, text_percent, adv_size_percent*10/100]
83
 
84
  # caption_ad = XGBoost_utils.Caption_Generation(Image.fromarray(np.uint8(ad_image)))
85
- # caption_context = XGBoost_utils.Caption_Generation(Image.fromarray(np.uint8(context_image)))
 
 
 
86
  # ad_topic = XGBoost_utils.Topic_emb(caption_ad)
87
  # ctpg_topic = XGBoost_utils.Topic_emb(caption_context)
88
  np.random.seed(42)
@@ -91,10 +106,16 @@ def attention(notes, download1, download2, whole_display_prompt,
91
 
92
  ad = cv.resize(ad_image, (640, 832))
93
  print('ad shape: ', ad.shape)
94
- context = cv.resize(context_image, (640, 832))
 
 
 
95
 
96
  adv_imgs = torch.permute(torch.tensor(ad), (2,0,1)).unsqueeze(0)
97
- ctpg_imgs = torch.permute(torch.tensor(context), (2,0,1)).unsqueeze(0)
 
 
 
98
  ad_locations = torch.tensor([1,0]).unsqueeze(0)
99
  heatmap = Predict.HeatMap_CNN(adv_imgs, ctpg_imgs, ad_locations, Gaze_Type='AG')
100
 
@@ -104,27 +125,50 @@ def attention(notes, download1, download2, whole_display_prompt,
104
  ad_embeddings=ad_topic, ctpg_embeddings=ctpg_topic,
105
  surface_sizes=surfaces, Product_Group=prod_group,
106
  obj_detection_model_pth=None, num_topic=20, Gaze_Time_Type=gaze_type)
107
- return np.round(Gaze,2), Image.fromarray(np.flip(heatmap, axis=2)), "Hotter/Redder regions show more pixel contribution."
108
 
109
  with gr.Blocks() as demo:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
  gr.Interface(
111
  fn=attention,
112
  inputs=[gr.Markdown("""
113
- ### Instruction:
114
- 1. Click to upload or drag the entire image that contains BOTH ad and its context;
115
- 2. Draw bounding boxes in the order of: (each element can have more than 1 boxes; remember the number of boxes for each element you draw)
116
- &nbsp;&nbsp;&nbsp;(a) Brand element(s) (skip if N.A.)
117
- &nbsp;&nbsp;&nbsp;(b) Pictorial element(s), e.g. Objects, Person etc (skip if N.A.)
118
- &nbsp;&nbsp;&nbsp;(c) Text element(s) (skip if N.A.)
119
- &nbsp;&nbsp;&nbsp;(d) The advertisement.
120
- 3. Put in number of bounding boxes for each element, product category, ad location and attention type.
121
-
122
- ***NOTE:*** *ResNet50 Heatmap could take around 20-80 seconds under current CPU environment.*
123
-
124
- Two example ads are avialable for download: """),
 
125
  gr.DownloadButton(label="Download Example Image 1 of Ad and Context", value='Demo/Ad_Example1.jpg'),
126
  gr.DownloadButton(label="Download Example Image 2 of Ad and Context", value='Demo/Ad_Example2.jpg'),
127
- ImagePrompter(label="Upload Entire (Ad+Context) Image, and Draw Bounding Boxes", sources=['upload'], type="pil"),
128
  gr.Number(label="Number of brand bounding boxes drawn"),
129
  gr.Number(label="Number of pictorial bounding boxes drawn"),
130
  gr.Number(label="Number of text bounding boxes drawn"),
@@ -133,12 +177,9 @@ with gr.Blocks() as demo:
133
  gr.Dropdown(GAZE_TYPE, label='Gaze Type')
134
  ],
135
  outputs=[gr.Number(label="Predicted Gaze (sec)"),
136
- gr.Image(label="ResNet50 Heatmap"),
137
- gr.Textbox(label="Heatmap Info")],
138
- title="""Gazer 1.0: Ad Attention Prediction""",
139
- description="""This app accompanies: "Contextual Advertising with Theory-Informed Machine Learning", manuscript submitted to the Journal of Marketing.
140
- App Version: 1.0, Date: 10/24/2024.
141
- Note: Gazer 1.0 does not yet include LLM generated ad topics. Future updates will include this in a GPU environment.""",
142
  theme=gr.themes.Soft()
143
  )
144
  gr.Markdown(
@@ -149,4 +190,4 @@ with gr.Blocks() as demo:
149
  """
150
  )
151
 
152
- demo.launch(share=True)
 
6
  import cv2 as cv
7
  import torch
8
  from PIL import Image
9
+ from gradio_pdf import PDF
10
+ from pdf2image import convert_from_path
11
+ from pathlib import Path
12
 
13
  GENERAL_CATEGORY = {'Potatoes / Vegetables / Fruit': 0, 'Chemical products': 1, 'Photo / Film / Optical items': 2, 'Catering industry': 3, 'Industrial products other': 4, 'Media': 5, 'Real estate': 6, 'Government': 7, 'Personnel advertisements': 8, 'Cars / Commercial vehicles': 9, 'Cleaning products': 10, 'Retail': 11, 'Fragrances': 12, 'Footwear / Leather goods': 13, 'Software / Automation': 14, 'Telecommunication equipment': 15, 'Tourism': 16, 'Transport/Communication companies': 17, 'Transport services': 18, 'Insurances': 19, 'Meat / Fish / Poultry': 20, 'Detergents': 21, 'Foods General': 22, 'Other services': 23, 'Banks and Financial Services': 24, 'Office Products': 25, 'Household Items': 26, 'Non-alcoholic beverages': 27, 'Hair, Oral and Personal Care': 28, 'Fashion and Clothing': 29, 'Other products and Services': 30, 'Paper products': 31, 'Alcohol and Other Stimulants': 32, 'Medicines': 33, 'Recreation and Leisure': 34, 'Electronics': 35, 'Home Furnishings': 36, 'Products for Business Use': 37}
14
  CATEGORIES = list(GENERAL_CATEGORY.keys())
 
47
  left_margin = x1; right_margin = w-x2
48
  if left_margin <=100 and right_margin <= 100:
49
  upper_margin = y1; lower_margin = h-y2
50
+ if upper_margin <= 100 and lower_margin <= 100:
51
+ context_image = None
52
  else:
53
+ if upper_margin >= lower_margin:
54
+ context_image = image_entire[:int(y1), :, :]
55
+ else:
56
+ context_image = image_entire[int(y2):, :, :]
57
  else:
58
  if left_margin >= right_margin:
59
  context_image = image_entire[:, :int(x1), :]
 
65
 
66
  return (brand_surf/whole_size*100, pictorial_surf/whole_size*100, text_surf/whole_size*100, ad_size/whole_size*100, ad_image, context_image)
67
 
68
+ def convert(note, doc):
69
+ print(doc)
70
+ img = convert_from_path(doc)[0]
71
+ img.save(f'pdf_to_imgs/pdf_img.png', 'PNG')
72
+ return 'Done!', gr.DownloadButton(label='Download converted image', value='pdf_to_imgs/pdf_img.png')
73
 
74
+ def attention(note, button1, button2,
75
+ whole_display_prompt,
76
  brand_num, pictorial_num, text_num,
77
  category, ad_location, gaze_type):
78
  text_detection_model_path = 'EAST-Text-Detection/frozen_east_text_detection.pb'
 
94
  surfaces = [brand_percent, visual_percent, text_percent, adv_size_percent*10/100]
95
 
96
  # caption_ad = XGBoost_utils.Caption_Generation(Image.fromarray(np.uint8(ad_image)))
97
+ # if context_image is not None:
98
+ # caption_context = XGBoost_utils.Caption_Generation(Image.fromarray(np.uint8(context_image)))
99
+ # else:
100
+ # caption_context = ''
101
  # ad_topic = XGBoost_utils.Topic_emb(caption_ad)
102
  # ctpg_topic = XGBoost_utils.Topic_emb(caption_context)
103
  np.random.seed(42)
 
106
 
107
  ad = cv.resize(ad_image, (640, 832))
108
  print('ad shape: ', ad.shape)
109
+ if context_image is None:
110
+ context = None
111
+ else:
112
+ context = cv.resize(context_image, (640, 832))
113
 
114
  adv_imgs = torch.permute(torch.tensor(ad), (2,0,1)).unsqueeze(0)
115
+ if context is None:
116
+ ctpg_imgs = torch.zeros_like(adv_imgs)
117
+ else:
118
+ ctpg_imgs = torch.permute(torch.tensor(context), (2,0,1)).unsqueeze(0)
119
  ad_locations = torch.tensor([1,0]).unsqueeze(0)
120
  heatmap = Predict.HeatMap_CNN(adv_imgs, ctpg_imgs, ad_locations, Gaze_Type='AG')
121
 
 
125
  ad_embeddings=ad_topic, ctpg_embeddings=ctpg_topic,
126
  surface_sizes=surfaces, Product_Group=prod_group,
127
  obj_detection_model_pth=None, num_topic=20, Gaze_Time_Type=gaze_type)
128
+ return np.round(Gaze,2), Image.fromarray(np.flip(heatmap, axis=2))
129
 
130
  with gr.Blocks() as demo:
131
+ gr.Markdown("""
132
+ <div style='text-align: center; padding: 10px; font-size:40px'>
133
+ <p> <b>Gazer 1.0: Ad Attention Prediction</b> </p>
134
+ </div>
135
+ """)
136
+ gr.Markdown("""
137
+ This app accompanies: "Contextual Advertising with Theory-Informed Machine Learning", manuscript submitted to the Journal of Marketing.
138
+ App Version: 1.0, Date: 10/24/2024.
139
+ Note: Gazer 1.0 does not yet include LLM generated ad topics. Future updates will include this in a GPU environment.
140
+ """)
141
+ gr.Interface(
142
+ fn=convert,
143
+ inputs=[gr.Markdown("""
144
+ <div style='font-size:20px'>
145
+ <p> <b>If you only have a pdf image file, first convert it here to png file and download:</b> </p>
146
+ </div>
147
+
148
+ """),
149
+ PDF(label="PDF Converter")],
150
+ outputs=[gr.Text(label='Progress'), gr.DownloadButton(label='Wait to be downloadable', value=None)]
151
+ )
152
+
153
  gr.Interface(
154
  fn=attention,
155
  inputs=[gr.Markdown("""
156
+ ## Instructions:
157
+ 0. The screen size should remain the same during processing.
158
+ 1. Click to upload or drag the entire image (jpg/jpeg/png file) that contains BOTH ad and its context;
159
+ 2. Draw bounding boxes in the order of: (each element can have more than 1 boxes; remember the number of boxes for each element you draw)
160
+ &nbsp;&nbsp;&nbsp;(a) Brand element(s) (skip if N.A.)
161
+ &nbsp;&nbsp;&nbsp;(b) Pictorial element(s), e.g. Objects, Person etc (skip if N.A.)
162
+ &nbsp;&nbsp;&nbsp;(c) Text element(s) (skip if N.A.)
163
+ &nbsp;&nbsp;&nbsp;(d) The advertisement.
164
+ 3. Put in number of bounding boxes for each element, product category, ad location and attention type.
165
+
166
+ ***NOTE:*** *ResNet50 Heatmap could take around 20-80 seconds under current CPU environment.*
167
+
168
+ Two example ads are avialable for download: """),
169
  gr.DownloadButton(label="Download Example Image 1 of Ad and Context", value='Demo/Ad_Example1.jpg'),
170
  gr.DownloadButton(label="Download Example Image 2 of Ad and Context", value='Demo/Ad_Example2.jpg'),
171
+ ImagePrompter(label="Upload Entire (Ad+Context) Image in jpg/jpeg/png format, and Draw Bounding Boxes", sources=['upload'], type="pil"),
172
  gr.Number(label="Number of brand bounding boxes drawn"),
173
  gr.Number(label="Number of pictorial bounding boxes drawn"),
174
  gr.Number(label="Number of text bounding boxes drawn"),
 
177
  gr.Dropdown(GAZE_TYPE, label='Gaze Type')
178
  ],
179
  outputs=[gr.Number(label="Predicted Gaze (sec)"),
180
+ gr.Image(label="ResNet50 Heatmap (Hotter/Redder regions show more pixel contribution.)")],
181
+ title=None,
182
+ description=None,
 
 
 
183
  theme=gr.themes.Soft()
184
  )
185
  gr.Markdown(
 
190
  """
191
  )
192
 
193
+ demo.launch(share=False)