Spaces:
Running
Running
import cv2 as cv | |
import numpy as np | |
import torch | |
from gensim import models | |
import xgboost as xgb | |
import XGBoost_utils | |
import sys | |
import joblib | |
from DL_models import CustomResNet | |
# Ad/Brand Gaze Prediction
# The model can currently only process magazine images, or images with full-page counterpages.
# Indicate where the ad is with the ad_location parameter: left -> ad_location=0, right -> ad_location=1; otherwise set it to None.
def _load_page_image(page):
    """Return ``page`` as an RGB image whose sides are multiples of 32.

    A ``str`` is treated as a file path: it is read with OpenCV, converted
    from BGR to RGB and resized so that both dimensions are rounded *up* to
    the next multiple of 32 (the text detector used later needs multiples of
    32 in both dimensions). Any other value is returned untouched; it is
    assumed to be an image the caller has already prepared.

    Raises:
        OSError: if ``page`` is a path that OpenCV cannot read or decode.
    """
    if not isinstance(page, str):
        return page
    img = cv.imread(page)
    if img is None:
        # cv.imread signals failure by returning None instead of raising.
        raise OSError('Could not read image file: ' + page)
    img = cv.cvtColor(img, cv.COLOR_BGR2RGB)
    height, width = img.shape[:2]
    target_h = 32 * int(np.ceil(height / 32))
    target_w = 32 * int(np.ceil(width / 32))
    # cv.resize takes the target size as (width, height).
    return cv.resize(img, (target_w, target_h))


def _cluster_salience(sal_map, threshold, num_clusters, enhance_rate):
    """Cluster one saliency map and return ``(scores, widths, sig_obj)``.

    ``scores`` and ``widths`` are returned as numpy arrays (not yet
    log-transformed); ``sig_obj`` is passed through exactly as
    ``XGBoost_utils.img_clusters`` returns it.
    """
    vecs, km = XGBoost_utils.salience_matrix_conv(
        sal_map, threshold, num_clusters, enhance_rate=enhance_rate)
    _, scores, widths, sig_obj = XGBoost_utils.img_clusters(
        num_clusters, sal_map, km.labels_, km.cluster_centers_, vecs)
    return np.array(scores), np.array(widths), sig_obj


def Ad_Gaze_Prediction(input_ad_path, input_ctpg_path, ad_location,
                       text_detection_model_path, LDA_model_pth, training_ad_text_dictionary_path, training_lang_preposition_path,
                       training_language, ad_embeddings, ctpg_embeddings,
                       surface_sizes=None, Product_Group=None, TextBoxes=None, Obj_and_Topics=None,
                       obj_detection_model_pth=None, num_topic=20, Gaze_Time_Type='Brand', Info_printing=True):
    """Predict the Ad or Brand gaze time for a magazine ad with an XGBoost ensemble.

    The function extracts file-size, saliency, text-box, object-count,
    layout, product-category and surface-size features for the ad (and its
    counterpage, if any), then averages the predictions of 10 stored XGBoost
    models and maps the average back with ``exp(x) - 1``.

    Args:
        input_ad_path: path of the ad image (``str``), or an already loaded
            image. NOTE(review): the value is also handed unchanged to
            ``XGBoost_utils.filesize_individual`` -- confirm that helper
            accepts non-path inputs.
        input_ctpg_path: path / image of the counterpage, or ``None``. With
            ``None`` the ad is treated as a full (double) page ad and all
            counterpage features are zero.
        ad_location: 0 if the ad is on the left, 1 if on the right; any other
            value (e.g. ``None``) yields the indicator ``[1, 1]``.
        text_detection_model_path: forwarded to
            ``XGBoost_utils.text_detection_east``.
        LDA_model_pth: path of the gensim LDA model (only loaded when
            ``Obj_and_Topics`` is ``None``).
        training_ad_text_dictionary_path: ``torch.load``-able dictionary file
            (only loaded when ``Obj_and_Topics`` is ``None``).
        training_lang_preposition_path: ``torch.load``-able preposition file
            (only loaded when ``Obj_and_Topics`` is ``None``).
        training_language: forwarded to
            ``XGBoost_utils.object_and_topic_variables``.
        ad_embeddings: input for the stored PCA models' ``transform``; the
            first 4 components of the first row become the topic features.
        ctpg_embeddings: same as ``ad_embeddings`` for the counterpage. It is
            transformed unconditionally, so it must be valid even when there
            is no counterpage.
        surface_sizes: optional sequence of 4 surface features laid out like
            the interactively measured ones (brand %, remaining %, text %,
            page-size term). When ``None`` the user is asked to select the
            regions interactively.
        Product_Group: optional product-category indicator; when ``None``,
            ``XGBoost_utils.product_category()`` is called.
        TextBoxes: optional ``(ad_num_textboxes, ctpg_num_textboxes)`` pair;
            when ``None`` the counts are detected.
        Obj_and_Topics: optional 4-tuple whose first two items are the ad and
            counterpage object counts. The last two items are accepted for
            backward compatibility but unused (the topic features always come
            from the embeddings).
        obj_detection_model_pth: optional path of an object detector; when
            ``None`` YOLOv5s is fetched through ``torch.hub``.
        num_topic: forwarded to ``XGBoost_utils.object_and_topic_variables``.
        Gaze_Time_Type: ``'Brand'`` or ``'Ad'``; selects the model directory.
        Info_printing: print progress messages when truthy.

    Returns:
        float: the predicted gaze time.

    Raises:
        ValueError: if ``Gaze_Time_Type`` is neither ``'Brand'`` nor ``'Ad'``.
        OSError: if an image path cannot be read.
    """
    # Fail fast: previously an unknown type only surfaced as a NameError in
    # the prediction loop, after all the expensive feature extraction.
    if Gaze_Time_Type not in ('Brand', 'Ad'):
        raise ValueError("Gaze_Time_Type must be 'Brand' or 'Ad', got " + repr(Gaze_Time_Type))
    model_dir = Gaze_Time_Type + '_Gaze_Model'

    def log(*args):
        if Info_printing:
            print(*args)

    ## Image loading
    log('Loading Image ......')
    ad_img = _load_page_image(input_ad_path)
    has_ctpg = input_ctpg_path is not None
    # A missing counterpage is the only situation treated as a full-page ad.
    flag_full_page_ad = not has_ctpg
    ctpg_img = _load_page_image(input_ctpg_path) if has_ctpg else None
    log()

    ## File size
    log('Calculating complexity (filsize) ......')
    filesize_ad = XGBoost_utils.filesize_individual(input_ad_path)
    if has_ctpg:
        filesize_ctpg = XGBoost_utils.filesize_individual(input_ctpg_path)
    else:
        filesize_ctpg = 0
    log()

    ## Salience
    log('Processing Salience Information ......')
    S_map_ad = XGBoost_utils.Itti_Saliency(ad_img, scale_final=3)
    if has_ctpg:
        S_map_ctpg = XGBoost_utils.Itti_Saliency(ctpg_img, scale_final=3)
    # K-means settings
    threshold = 0.001
    enhance_rate = 1
    num_clusters = 3
    if flag_full_page_ad:
        # Split the double page in the middle, cluster each half separately
        # and add the halves up; the counterpage features are all zero.
        half = S_map_ad.shape[1] // 2
        scores_left, widths_left, D_left = _cluster_salience(
            S_map_ad[:, :half], threshold, num_clusters, enhance_rate)
        scores_right, widths_right, D_right = _cluster_salience(
            S_map_ad[:, half:], threshold, num_clusters, enhance_rate)
        ad_sal = scores_left + scores_right
        ad_width = np.log(widths_left + widths_right + 1)
        ad_sig_obj = D_left + D_right
        ctpg_sal = np.zeros_like(ad_sal)
        ctpg_width = np.zeros_like(ad_width)
        ctpg_sig_obj = 0
    else:
        # Reaching this branch implies a counterpage was supplied.
        ad_sal, ad_widths_raw, ad_sig_obj = _cluster_salience(
            S_map_ad, threshold, num_clusters, enhance_rate)
        ad_width = np.log(ad_widths_raw + 1)
        ctpg_sal, ctpg_widths_raw, ctpg_sig_obj = _cluster_salience(
            S_map_ctpg, threshold, num_clusters, enhance_rate)
        ctpg_width = np.log(ctpg_widths_raw + 1)
    log()

    ## Number of textboxes
    log('Processing Textboxes ......')
    if TextBoxes is None:
        # Needs multiples of 32 in both dimensions (see _load_page_image).
        ad_num_textboxes = XGBoost_utils.text_detection_east(ad_img, text_detection_model_path)
        if has_ctpg:
            ctpg_num_textboxes = XGBoost_utils.text_detection_east(ctpg_img, text_detection_model_path)
        else:
            ctpg_num_textboxes = 0
    else:
        ad_num_textboxes, ctpg_num_textboxes = TextBoxes
    log()

    ## Object counts
    # Only the object counts are kept from this step: the topic weights and
    # the topic difference used by the models are derived from the embeddings
    # inside the prediction loop below.
    log('Processing Object and Topic Information ......')
    log('Loading Object Detection Model')
    if Obj_and_Topics is None:
        # NOTE(security): torch.load unpickles arbitrary objects -- only point
        # these paths at trusted files.
        if obj_detection_model_pth is None:
            model_obj = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True, trust_repo=True)
        else:
            model_obj = torch.load(obj_detection_model_pth)
        model_lda = models.LdaModel.load(LDA_model_pth)
        dictionary = torch.load(training_ad_text_dictionary_path)
        dutch_preposition = torch.load(training_lang_preposition_path)
        ad_num_objs, ctpg_num_objs, _, _ = XGBoost_utils.object_and_topic_variables(
            ad_img, ctpg_img, has_ctpg, dictionary,
            dutch_preposition, training_language, model_obj,
            model_lda, num_topic)
    else:
        ad_num_objs, ctpg_num_objs, _, _ = Obj_and_Topics
    log()

    ## Left and right indicator
    log('Getting Left/Right Indicator ......')
    if flag_full_page_ad:
        Left_right_indicator = [1, 1]
    elif ad_location == 0:
        Left_right_indicator = [1, 0]
    elif ad_location == 1:
        Left_right_indicator = [0, 1]
    else:
        Left_right_indicator = [1, 1]
    log()

    ## Product category
    log('Getting Product Category Indicator ......')
    if Product_Group is None:
        group_ind = XGBoost_utils.product_category()
    else:
        group_ind = Product_Group
    log()

    ## Surface sizes
    log('Getting Surface Sizes ......')
    if surface_sizes is None:
        # The prompts below are part of the interactive workflow, so they are
        # printed regardless of Info_printing.
        ad_img_bgr = cv.cvtColor(ad_img, cv.COLOR_RGB2BGR)
        print('Please select the bounding box for your ad (from top left to bottom right)')
        A = XGBoost_utils.Region_Selection(ad_img_bgr)
        print()
        print('Please select the bounding box for brands (from top left to bottom right)')
        B = XGBoost_utils.Region_Selection(ad_img_bgr)
        print()
        print('Please select the bounding box for texts (from top left to bottom right)')
        T = XGBoost_utils.Region_Selection(ad_img_bgr)
        surface_sizes = [B/A*100, (1-B/A-T/A)*100, T/A*100, sum(Left_right_indicator)*5]
    else:
        # Accept tuples / arrays too: the features are joined by list concatenation.
        surface_sizes = list(surface_sizes)

    ## Put everything together
    log('Predicting ......')
    # Features shared by the typicality input and the model input.
    base_features = (surface_sizes + [filesize_ad, filesize_ctpg]
                     + list(ad_sal) + list(ctpg_sal)
                     + list(ad_width) + list(ctpg_width)
                     + [ad_sig_obj, ctpg_sig_obj]
                     + [ad_num_textboxes, ctpg_num_textboxes, ad_num_objs, ctpg_num_objs])
    group_features = list(group_ind)
    # Columns of the feature row that enter the typicality measure.
    typ_columns = [0, 1, 2, 3, 4, 6, 7, 8, 12, 13, 14, 18, 20, 22]
    # The medoid does not depend on the ensemble member, so load it once.
    med = torch.load(model_dir + '/typicality_train_medoid')
    num_models = 10
    gaze = 0
    for i in range(num_models):
        # Topic features: first 4 PCA components of the embeddings.
        pca_topic_transform = joblib.load('Topic_Embedding_PCAs/pca_model_'+str(i)+'.pkl')
        ad_topics_curr = pca_topic_transform.transform(ad_embeddings)[:, :4][0]
        ctpg_topics_curr = pca_topic_transform.transform(ctpg_embeddings)[:, :4][0]
        topic_features = list(ad_topics_curr)
        topic_Diff = np.linalg.norm(ad_topics_curr - ctpg_topics_curr)

        X = base_features + group_features + topic_features
        X = np.array(X).reshape(1, len(X))
        X_for_typ = list(X[0, typ_columns]) + group_features + topic_features
        X_for_typ = np.array(X_for_typ).reshape(1, len(X_for_typ))
        typ = XGBoost_utils.typ_cat(med, X_for_typ, group_ind, np.abs)

        Var = (base_features + Left_right_indicator + topic_features
               + group_features + [topic_Diff.item(), typ.item()])
        Var = np.array(Var).reshape(1, len(Var))

        # Model files are numbered from 1, PCA files from 0.
        xgb_model = xgb.XGBRegressor()
        xgb_model.load_model(model_dir + '/10_models/Model_'+str(i+1)+'.json')
        gaze += xgb_model.predict(Var)
    gaze = gaze / num_models

    # Undo the log(x + 1) scale the models predict on.
    predicted_gaze = (np.exp(gaze) - 1).item()
    log('The predicted '+Gaze_Time_Type+' gaze time is: ', predicted_gaze)
    return predicted_gaze
def CNN_Prediction(adv_imgs, ctpg_imgs, ad_locations, Gaze_Type='AG'): #Gaze_Type='AG' or 'BG'
    """Predict gaze with the ensemble of 10 fine-tuned ``CustomResNet`` models.

    Each stored model is evaluated on the inputs, its raw output is mapped
    through ``exp(pred * a + b) - 1`` with the constants belonging to
    ``Gaze_Type``, and the 10 results are averaged.

    Args:
        adv_imgs: first argument of ``CustomResNet.forward``
            (presumably the batch of ad images -- confirm against DL_models).
        ctpg_imgs: second argument of ``CustomResNet.forward``
            (presumably the counterpage images -- confirm).
        ad_locations: third argument of ``CustomResNet.forward``.
        Gaze_Type: ``'AG'`` or ``'BG'``; selects both the model directory
            ``CNN_Gaze_Model/Fine-tune_<Gaze_Type>`` and the rescaling
            constants.

    Returns:
        The averaged prediction tensor, on the device the models were run on
        (CUDA if available, otherwise MPS if available, otherwise CPU).

    Raises:
        ValueError: if ``Gaze_Type`` is neither ``'AG'`` nor ``'BG'``.
    """
    # (a_temp, b_temp) used to rescale the raw network output per gaze type.
    calibration = {'AG': (0.2590, 1.1781), 'BG': (0.2100, 0.3541)}
    if not isinstance(Gaze_Type, str) or Gaze_Type not in calibration:
        raise ValueError("Gaze_Type must be 'AG' or 'BG', got " + repr(Gaze_Type))
    a_temp, b_temp = calibration[Gaze_Type]

    mps_backend = getattr(torch.backends, 'mps', None)  # absent on older torch builds
    if torch.cuda.is_available():
        device = 'cuda'
    elif mps_backend is not None and mps_backend.is_available():
        device = 'mps'
    else:
        device = 'cpu'

    def to_device(value):
        # Keep inputs and model on the same device; leave non-tensors alone.
        return value.to(device) if torch.is_tensor(value) else value

    adv_imgs = to_device(adv_imgs)
    ctpg_imgs = to_device(ctpg_imgs)
    ad_locations = to_device(ad_locations)

    num_models = 10
    gaze = 0
    for i in range(num_models):
        net = CustomResNet()
        # Load onto the CPU first so checkpoints saved on a GPU also work on
        # machines without one, then move to the selected device.
        net.load_state_dict(torch.load('CNN_Gaze_Model/Fine-tune_'+Gaze_Type+'/Model_'+str(i)+'.pth',map_location=torch.device('cpu')))
        net = net.to(device)
        # Inference mode: freezes batch-norm statistics and disables dropout.
        net.eval()
        with torch.no_grad():
            pred = net(adv_imgs, ctpg_imgs, ad_locations)
        pred = torch.exp(pred*a_temp+b_temp) - 1
        gaze += pred/num_models
    return gaze