import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig
from torchvision import models, transforms
import torch.nn as nn
import os
import json
import cv2
from PIL import Image
import gradio as gr


class MultimodalRiskBehaviorModel(nn.Module):
    """Binary risky-health-behavior classifier fusing a caption with video frames.

    The caption is encoded by a Hugging Face sequence-classification model; the
    last layer's first-token hidden state ([CLS] for BERT) is the text embedding.
    Each frame is encoded by a ResNet50 whose classification head is replaced by
    ``nn.Identity``, and frame embeddings are mean-pooled over time. The
    concatenated embedding goes through a two-layer MLP ending in a sigmoid.
    """

    def __init__(self, text_model_name="bert-base-uncased", hidden_dim=512, dropout=0.3):
        super().__init__()
        # Text encoder. num_labels=2 is kept so existing checkpoints of this class
        # still load; the classifier head's logits are not used by forward().
        self.text_model_name = text_model_name
        self.text_model = AutoModelForSequenceClassification.from_pretrained(text_model_name, num_labels=2)

        # Visual encoder: ResNet50 with its final fc layer removed, so it returns
        # the pooled backbone features.
        self.visual_model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
        visual_feature_dim = self.visual_model.fc.in_features
        self.visual_model.fc = nn.Identity()

        # Fusion head over [text hidden state ; mean frame features].
        text_feature_dim = self.text_model.config.hidden_size
        self.fc1 = nn.Linear(text_feature_dim + visual_feature_dim, hidden_dim)
        self.dropout = nn.Dropout(dropout)
        self.fc2 = nn.Linear(hidden_dim, 1)

    def forward(self, encoding, frames):
        """Return the risk probability for each sample as a (batch, 1) tensor.

        Args:
            encoding: mapping with ``input_ids`` and ``attention_mask`` tensors;
                dim 1 is squeezed away when it has size 1.
            frames: tensor of shape (batch, num_frames, channels, height, width).
        """
        # Follow the model's own device rather than a module-level global.
        device = next(self.parameters()).device
        input_ids = encoding['input_ids'].squeeze(1).to(device)
        attention_mask = encoding['attention_mask'].squeeze(1).to(device)

        # fc1 is sized for hidden_size text features, so use the first-token
        # hidden state of the last layer, not the num_labels-wide logits.
        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )
        text_features = text_outputs.hidden_states[-1][:, 0, :]

        # Encode every frame independently, then average over the time axis.
        frames = frames.to(device)
        batch_size, num_frames, channels, height, width = frames.size()
        frames = frames.reshape(batch_size * num_frames, channels, height, width)
        visual_features = self.visual_model(frames)
        visual_features = visual_features.reshape(batch_size, num_frames, -1).mean(dim=1)

        combined_features = torch.cat((text_features, visual_features), dim=1)
        x = self.dropout(torch.relu(self.fc1(combined_features)))
        return torch.sigmoid(self.fc2(x))

    def save_pretrained(self, save_directory):
        """Write the state dict (``pytorch_model.bin``) and ``config.json`` to a directory."""
        os.makedirs(save_directory, exist_ok=True)
        torch.save(self.state_dict(), os.path.join(save_directory, 'pytorch_model.bin'))
        config = {
            "text_model_name": self.text_model_name,
            "hidden_dim": self.fc1.out_features,
            "dropout": self.dropout.p,
        }
        with open(os.path.join(save_directory, 'config.json'), 'w', encoding='utf-8') as f:
            json.dump(config, f)

    @classmethod
    def from_pretrained(cls, load_directory, map_location=None):
        """Build a model from a local save directory or a Hugging Face Hub id.

        Args:
            load_directory: a directory written by ``save_pretrained``; any path
                that does not exist locally is treated as a Hub id for the text
                encoder.
            map_location: forwarded to ``torch.load`` for local checkpoints.
        """
        if os.path.exists(load_directory):
            config_path = os.path.join(load_directory, 'config.json')
            state_dict_path = os.path.join(load_directory, 'pytorch_model.bin')
            with open(config_path, 'r', encoding='utf-8') as f:
                config_dict = json.load(f)
            model = cls(
                text_model_name=config_dict["text_model_name"],
                hidden_dim=config_dict["hidden_dim"],
                # Older configs did not store dropout; fall back to the default.
                dropout=config_dict.get("dropout", 0.3),
            )
            # NOTE(review): torch.load unpickles the file; only load trusted
            # checkpoints (or pass weights_only=True where torch supports it).
            state_dict = torch.load(state_dict_path, map_location=map_location)
            model.load_state_dict(state_dict)
        else:
            # Read only the config to size the head; __init__ then loads the text
            # weights once instead of downloading and building them twice.
            text_config = AutoConfig.from_pretrained(load_directory)
            model = cls(text_model_name=text_config.name_or_path, hidden_dim=text_config.hidden_size)
            # NOTE(review): only the text encoder comes from the Hub here; the
            # fusion head (fc1/fc2) keeps its random initialization.
        return model


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = AutoTokenizer.from_pretrained('Souha-BH/BERT_Resnet50')
# map_location lets a GPU-saved local checkpoint load on a CPU-only machine.
model = MultimodalRiskBehaviorModel.from_pretrained('Souha-BH/BERT_Resnet50', map_location=device)
model.to(device)


def load_frames_from_video(video_path, transform, num_frames=10):
    """Read up to ``num_frames`` consecutive frames from the start of a video.

    Each frame is converted from OpenCV's BGR array to an RGB PIL image and passed
    through ``transform``.

    Returns:
        Tensor of shape (1, n, *frame_shape) with n <= num_frames, where
        frame_shape is the shape ``transform`` produces.

    Raises:
        ValueError: if the video cannot be opened or yields no frames.
    """
    cap = cv2.VideoCapture(video_path)
    frames = []
    try:
        if not cap.isOpened():
            raise ValueError(f"Could not open video: {video_path}")
        while len(frames) < num_frames:  # limit the frame count for efficiency
            success, frame = cap.read()
            if not success:
                break
            image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            frames.append(transform(image))
    finally:
        # Release the capture even if decoding or the transform raises.
        cap.release()

    if not frames:
        raise ValueError(f"No frames could be read from video: {video_path}")
    # Stack to (n, ...) and add the batch dimension.
    return torch.stack(frames).unsqueeze(0)


def predict_video(model, video_path, text_input, tokenizer, transform):
    """Classify one (video, caption) pair.

    Returns:
        1.0 for risky, 0.0 for not risky, or the string "Error during prediction"
        if any step fails (the exception message is printed).
    """
    try:
        model.eval()
        model_device = next(model.parameters()).device

        encoding = tokenizer(
            text_input,
            padding='max_length',
            truncation=True,
            max_length=128,
            return_tensors='pt'
        )
        encoding = {key: val.to(model_device) for key, val in encoding.items()}

        frames = load_frames_from_video(video_path, transform).to(model_device)

        print(f"Encoding device: {next(iter(encoding.values())).device}, Frames shape: {frames.shape}")

        with torch.no_grad():
            output = model(encoding, frames)

        # forward() already applies the sigmoid; threshold the probability at 0.5.
        prediction = (output.squeeze(-1) > 0.5).float()
        return prediction.item()

    except Exception as e:
        print(f"Prediction error: {e}")
        return "Error during prediction"


transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Demo videos and their captions (index-aligned).
video_paths = [
    'https://drive.google.com/uc?export=download&id=1iWq1q1LM-jmf4iZxOqZTw4FaIBekJowM',
    'https://drive.google.com/uc?export=download&id=1_egBaC1HD2kIZgRRKsnCtsWG94vg1c7n',
    'https://drive.google.com/uc?export=download&id=12cGxBEkfU5Q1Ezg2jRk6zGyn2hoR3JLj'
]
video_captions = [
    "Everytime i start a diet كل مرة أحاول أبدأ ريجيم 😓 #dietmemes #funnyvideos #animetiktok",
    "New sandwich from burger king 🍔👑 #mukbang #asmr #asmrmukbang #asmrsounds #eat #food #Foodie moe eats #yummy #cheese #chicken #burger #fries #burgerking @Burger King",
    # This caption contains a line break in the original text.
    (
        "all workout guides l!nked in bi0 // honestly huge moment 😂 I’ve been so focused on growing my upper body that this feels like it finally shows! \n"
        "shorts from @KEEPTHATPUMP #upperbody #upperbodyworkout #glutegains #glutegrowth #gluteexercise #workout #strengthtraining #gym #trending #fyp"
    ),
]


def predict_risk(video_index):
    """Return a human-readable risk label for the demo video at ``video_index``."""
    video_path = video_paths[video_index]
    text_input = video_captions[video_index]

    prediction = predict_video(model, video_path, text_input, tokenizer, transform)

    # predict_video returns an error string on failure; surface it rather than
    # letting it fall through to the "not risky" label.
    if isinstance(prediction, str):
        return prediction
    return "Risky Health Behavior" if prediction == 1 else "Not Risky Health Behavior"


def _choice_to_index(choice):
    """Map a radio label such as "Video 2" to a 0-based index, or None if unset."""
    if not choice:
        return None
    return int(choice.split()[-1]) - 1


with gr.Blocks() as interface:
    gr.Markdown("# Risk Behavior Prediction")
    gr.Markdown("Select a video to classify its behavior as risky or not.")

    video_selector = gr.Radio(["Video 1", "Video 2", "Video 3"], label="Choose a Video")

    def show_selected_video(choice):
        """Return the selected video's URL (for gr.Video) and its caption as Markdown."""
        idx = _choice_to_index(choice)
        if idx is None:
            return None, ""
        return video_paths[idx], f"**Caption:** {video_captions[idx]}"

    video_player = gr.Video(width=320, height=240)
    caption_box = gr.Markdown()

    video_selector.change(
        fn=show_selected_video,
        inputs=video_selector,
        outputs=[video_player, caption_box]
    )

    predict_button = gr.Button("Predict Risk")
    output_text = gr.Textbox(label="Prediction")

    def _predict_from_choice(choice):
        """Run the prediction for the selected radio option, if any."""
        idx = _choice_to_index(choice)
        if idx is None:
            return "Please select a video first."
        return predict_risk(idx)

    predict_button.click(
        fn=_predict_from_choice,
        inputs=video_selector,
        outputs=output_text
    )

interface.launch()