# FrugalDisinfoHunter / climate_model.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer
from typing import Dict, List, Optional, Tuple
import numpy as np
import logging
logger = logging.getLogger(__name__)
class MetadataAttention(nn.Module):
"""Attention mechanism for combining text and metadata features"""
def __init__(self, text_dim: int, metadata_dim: int):
super().__init__()
self.text_linear = nn.Linear(text_dim, 64)
self.metadata_linear = nn.Linear(metadata_dim, 64)
self.attention = nn.Sequential(
nn.Linear(128, 64),
nn.Tanh(),
nn.Linear(64, 1),
nn.Softmax(dim=1)
)
    def forward(self, text_features: torch.Tensor, metadata_features: torch.Tensor,
                attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        text_proj = self.text_linear(text_features)
        meta_proj = self.metadata_linear(metadata_features)
        meta_proj = meta_proj.unsqueeze(1).expand(-1, text_proj.size(1), -1)
        combined = torch.cat([text_proj, meta_proj], dim=-1)
        if attention_mask is not None:
            # Mask padding before the softmax; self.attention[:-1] is the scoring MLP
            scores = self.attention[:-1](combined)
            scores = scores.masked_fill(attention_mask.unsqueeze(-1) == 0, float("-inf"))
            weights = torch.softmax(scores, dim=1)
        else:
            weights = self.attention(combined)
        # Attention-weighted pooling over the sequence -> (batch, text_dim)
        weighted_sum = (text_features * weights).sum(dim=1)
        return weighted_sum
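
# Minimal shape-check sketch (not part of the original file): runs
# MetadataAttention on random tensors. All dimensions below are arbitrary
# assumptions chosen for illustration.
def _demo_metadata_attention():
    batch, seq_len, text_dim, meta_dim = 2, 16, 512, 96
    attn = MetadataAttention(text_dim=text_dim, metadata_dim=meta_dim)
    text = torch.randn(batch, seq_len, text_dim)         # token-level features
    meta = torch.randn(batch, meta_dim)                  # one metadata vector per example
    mask = torch.ones(batch, seq_len, dtype=torch.long)  # no padding in this toy batch
    pooled = attn(text, meta, attention_mask=mask)
    assert pooled.shape == (batch, text_dim)             # sequence pooled to one vector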
class FeatureEncoder(nn.Module):
"""Encodes numerical and categorical features"""
def __init__(self, num_numerical_features: int, categorical_feature_dims: Dict[str, int]):
super().__init__()
# Numerical features
self.numerical_bn = nn.BatchNorm1d(num_numerical_features)
self.numerical_encoder = nn.Sequential(
nn.Linear(num_numerical_features, 64),
nn.LayerNorm(64),
nn.ReLU(),
nn.Dropout(0.2)
)
# Categorical features
self.categorical_encoders = nn.ModuleDict()
self.categorical_dims = {}
for feature_name, dim in categorical_feature_dims.items():
self.categorical_encoders[feature_name] = nn.Sequential(
nn.Embedding(dim, 32),
nn.Linear(32, 32),
nn.ReLU()
)
self.categorical_dims[feature_name] = dim
self.output_dim = 64 + 32 * len(categorical_feature_dims)
    def forward(self, numerical_features: torch.Tensor,
                categorical_features: Dict[str, torch.Tensor]) -> torch.Tensor:
        # Note: BatchNorm1d requires batch size > 1 while in training mode
        x_num = self.numerical_bn(numerical_features)
        x_num = self.numerical_encoder(x_num)
        x_cat_list = []
        for feature_name, encoder in self.categorical_encoders.items():
            if feature_name in categorical_features:
                x_cat = encoder(categorical_features[feature_name])
                x_cat_list.append(x_cat)
            else:
                # A registered feature that is absent shrinks the output below
                # self.output_dim and breaks downstream layers sized to it
                logger.warning(f"Missing categorical feature: {feature_name}")
        if x_cat_list:
            x_cat = torch.cat(x_cat_list, dim=1)
            return torch.cat([x_num, x_cat], dim=1)
        return x_num
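
# Minimal sketch (not part of the original file): shows the input layout
# FeatureEncoder expects. The feature names and dimensions are made up
# for illustration.
def _demo_feature_encoder():
    enc = FeatureEncoder(
        num_numerical_features=3,
        categorical_feature_dims={"source_type": 5, "platform": 4},
    )
    num = torch.randn(8, 3)                        # (batch, num_numerical_features)
    cat = {
        "source_type": torch.randint(0, 5, (8,)),  # integer category ids, shape (batch,)
        "platform": torch.randint(0, 4, (8,)),
    }
    out = enc(num, cat)
    assert out.shape == (8, enc.output_dim)        # 64 numerical + 32 per categorical = 128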
class ClimateDisinformationModel(nn.Module):
"""Model for climate disinformation classification"""
def __init__(self,
num_classes: int,
base_model_name: str = "google/mobilebert-uncased",
num_numerical_features: int = 10,
categorical_feature_dims: Optional[Dict[str, int]] = None,
device: Optional[torch.device] = None):
super().__init__()
self.device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if categorical_feature_dims is None:
categorical_feature_dims = {}
try:
# Text encoder
self.text_encoder = AutoModel.from_pretrained(base_model_name)
hidden_size = self.text_encoder.config.hidden_size
# Feature processing
self.feature_encoder = FeatureEncoder(
num_numerical_features,
categorical_feature_dims
)
# Metadata attention
self.metadata_attention = MetadataAttention(
text_dim=hidden_size,
metadata_dim=self.feature_encoder.output_dim
)
# Classifier
combined_dim = hidden_size + self.feature_encoder.output_dim
self.classifier = nn.Sequential(
nn.Linear(combined_dim, 256),
nn.LayerNorm(256),
nn.ReLU(),
nn.Dropout(0.1),
nn.Linear(256, num_classes)
)
# Loss function (will be set by set_class_weights)
self.criterion = nn.CrossEntropyLoss()
# Move to device
self.to(self.device)
except Exception as e:
logger.error(f"Error initializing model: {str(e)}")
raise
def set_class_weights(self, class_weights: torch.Tensor):
"""Set class weights for loss function"""
try:
self.criterion = nn.CrossEntropyLoss(weight=class_weights.to(self.device))
except Exception as e:
logger.error(f"Error setting class weights: {str(e)}")
raise
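    # Usage sketch (an assumption, not from this repo): class weights are
    # typically derived from training-label frequencies, e.g. with scikit-learn:
    #   from sklearn.utils.class_weight import compute_class_weight
    #   w = compute_class_weight(class_weight="balanced",
    #                            classes=np.unique(train_labels), y=train_labels)
    #   model.set_class_weights(torch.tensor(w, dtype=torch.float32))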
def forward(self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
numerical_features: torch.Tensor,
categorical_features: Dict[str, torch.Tensor],
labels: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
try:
# Get text features
text_outputs = self.text_encoder(
input_ids=input_ids,
attention_mask=attention_mask
)
text_features = text_outputs.last_hidden_state
# Get enhanced features
feature_embedding = self.feature_encoder(
numerical_features,
categorical_features
)
            # Pool token features with metadata-conditioned attention,
            # masking padding positions via attention_mask
            text_features = self.metadata_attention(
                text_features,
                feature_embedding,
                attention_mask=attention_mask
            )
# Combine features
combined_embedding = torch.cat([text_features, feature_embedding], dim=1)
# Get logits
logits = self.classifier(combined_embedding)
# Prepare output dict
outputs = {"logits": logits}
# Calculate loss if labels provided
if labels is not None:
outputs["loss"] = self.criterion(logits, labels)
return outputs
except Exception as e:
logger.error(f"Error in forward pass: {str(e)}")
raise
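
# Illustrative sketch (not part of the original file): a full forward pass on
# dummy inputs. Running it downloads the MobileBERT checkpoint; the class
# count, feature names, and sizes are assumptions for the demo.
def _demo_model_forward():
    model = ClimateDisinformationModel(
        num_classes=4,
        num_numerical_features=10,
        categorical_feature_dims={"source_type": 5},
        device=torch.device("cpu"),
    )
    input_ids = torch.randint(0, model.text_encoder.config.vocab_size, (2, 128))
    attention_mask = torch.ones_like(input_ids)
    outputs = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        numerical_features=torch.randn(2, 10),
        categorical_features={"source_type": torch.randint(0, 5, (2,))},
        labels=torch.tensor([0, 2]),
    )
    assert outputs["logits"].shape == (2, 4) and "loss" in outputs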
class ModelWrapper:
"""Wrapper for model management and inference"""
def __init__(self,
num_classes: int,
base_model_name: str = "google/mobilebert-uncased",
num_numerical_features: int = 10,
categorical_feature_dims: Optional[Dict[str, int]] = None,
device: Optional[torch.device] = None):
self.device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logger.info(f"Using device: {self.device}")
try:
# Initialize tokenizer and model
self.tokenizer = AutoTokenizer.from_pretrained(base_model_name)
self.model = ClimateDisinformationModel(
num_classes=num_classes,
base_model_name=base_model_name,
num_numerical_features=num_numerical_features,
categorical_feature_dims=categorical_feature_dims,
device=self.device
)
except Exception as e:
logger.error(f"Error initializing ModelWrapper: {str(e)}")
raise
def train_step(self,
batch: Dict[str, torch.Tensor],
optimizer: torch.optim.Optimizer) -> Tuple[float, torch.Tensor]:
"""Single training step"""
try:
# Set model to training mode
self.model.train()
# Zero gradients
optimizer.zero_grad()
# Forward pass
outputs = self.model(**batch)
loss = outputs["loss"]
# Backward pass
loss.backward()
# Optimizer step
optimizer.step()
return loss, outputs["logits"]
except Exception as e:
logger.error(f"Error in training step: {str(e)}")
raise
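    # Usage sketch (an assumption, not from this repo): a minimal epoch loop
    # over a DataLoader whose batches mirror forward()'s signature and are
    # already on self.device:
    #   optimizer = torch.optim.AdamW(wrapper.model.parameters(), lr=2e-5)
    #   for batch in train_loader:
    #       loss, logits = wrapper.train_step(batch, optimizer)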
def predict(self,
texts: List[str],
numerical_features: np.ndarray,
categorical_features: Dict[str, np.ndarray]) -> np.ndarray:
"""Batch prediction"""
try:
self.model.eval()
# Prepare inputs
inputs = self.tokenizer(
texts,
return_tensors="pt",
max_length=128,
truncation=True,
padding="max_length"
)
# Move inputs to device
inputs = {k: v.to(self.device) for k, v in inputs.items()}
            num_features = torch.as_tensor(numerical_features, dtype=torch.float32).to(self.device)
            cat_features = {
                k: torch.as_tensor(v, dtype=torch.long).to(self.device)
                for k, v in categorical_features.items()
            }
# Get predictions
with torch.no_grad():
outputs = self.model(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
numerical_features=num_features,
categorical_features=cat_features
)
return F.softmax(outputs["logits"], dim=1).cpu().numpy()
except Exception as e:
logger.error(f"Error in prediction: {str(e)}")
raise
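
# Illustrative end-to-end sketch (not part of the original file): runs
# predict() on a dummy example. The class count and example text are
# assumptions, and running this downloads the MobileBERT checkpoint.
if __name__ == "__main__":
    wrapper = ModelWrapper(num_classes=4, num_numerical_features=10)
    probs = wrapper.predict(
        texts=["Sea levels are not rising, scientists say."],
        numerical_features=np.zeros((1, 10), dtype=np.float32),
        categorical_features={},
    )
    print(probs)  # (1, num_classes) array of class probabilities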