import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer
from typing import Dict, List, Optional, Tuple
import numpy as np
import logging

logger = logging.getLogger(__name__)


class MetadataAttention(nn.Module):
    """Attention mechanism for combining text and metadata features."""

    def __init__(self, text_dim: int, metadata_dim: int):
        super().__init__()
        self.text_linear = nn.Linear(text_dim, 64)
        self.metadata_linear = nn.Linear(metadata_dim, 64)
        self.attention = nn.Sequential(
            nn.Linear(128, 64),
            nn.Tanh(),
            nn.Linear(64, 1),
            nn.Softmax(dim=1),
        )

    def forward(self, text_features: torch.Tensor, metadata_features: torch.Tensor) -> torch.Tensor:
        # Project token-level text features and example-level metadata into a shared 64-d space.
        text_proj = self.text_linear(text_features)          # (batch, seq_len, 64)
        meta_proj = self.metadata_linear(metadata_features)  # (batch, 64)
        # Broadcast the metadata projection across the sequence dimension.
        meta_proj = meta_proj.unsqueeze(1).expand(-1, text_proj.size(1), -1)
        combined = torch.cat([text_proj, meta_proj], dim=-1)  # (batch, seq_len, 128)
        # Attention weights over the sequence (padding positions are not masked here).
        weights = self.attention(combined)                    # (batch, seq_len, 1)
        weighted_sum = (text_features * weights).sum(dim=1)   # (batch, text_dim)
        return weighted_sum


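# Illustrative sketch of how MetadataAttention is exercised (the shapes are assumptions that
# mirror the call site in ClimateDisinformationModel below, not a fixed contract):
#
#   attn = MetadataAttention(text_dim=512, metadata_dim=96)
#   text = torch.randn(8, 128, 512)   # (batch, seq_len, hidden_size)
#   meta = torch.randn(8, 96)         # (batch, metadata_dim)
#   pooled = attn(text, meta)         # (batch, 512)

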
class FeatureEncoder(nn.Module):
    """Encodes numerical and categorical features."""

    def __init__(self, num_numerical_features: int, categorical_feature_dims: Dict[str, int]):
        super().__init__()

        # Numerical features: batch-normalize, then project to a 64-d representation.
        self.numerical_bn = nn.BatchNorm1d(num_numerical_features)
        self.numerical_encoder = nn.Sequential(
            nn.Linear(num_numerical_features, 64),
            nn.LayerNorm(64),
            nn.ReLU(),
            nn.Dropout(0.2),
        )

        # Categorical features: one embedding-based encoder per feature, 32 dims each.
        self.categorical_encoders = nn.ModuleDict()
        self.categorical_dims = {}
        for feature_name, dim in categorical_feature_dims.items():
            self.categorical_encoders[feature_name] = nn.Sequential(
                nn.Embedding(dim, 32),
                nn.Linear(32, 32),
                nn.ReLU(),
            )
            self.categorical_dims[feature_name] = dim

        self.output_dim = 64 + 32 * len(categorical_feature_dims)

    def forward(self, numerical_features: torch.Tensor,
                categorical_features: Dict[str, torch.Tensor]) -> torch.Tensor:
        x_num = self.numerical_bn(numerical_features)
        x_num = self.numerical_encoder(x_num)

        # Note: output_dim assumes every configured categorical feature is present at
        # forward time; a missing feature shrinks the output and breaks downstream layers.
        x_cat_list = []
        for feature_name, encoder in self.categorical_encoders.items():
            if feature_name in categorical_features:
                x_cat = encoder(categorical_features[feature_name])
                x_cat_list.append(x_cat)

        if x_cat_list:
            x_cat = torch.cat(x_cat_list, dim=1)
            return torch.cat([x_num, x_cat], dim=1)
        return x_num


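# Illustrative construction (feature names and cardinalities are hypothetical):
#
#   encoder = FeatureEncoder(
#       num_numerical_features=10,
#       categorical_feature_dims={"source_type": 5, "platform": 8},
#   )
#   encoder.output_dim  # 64 + 32 * 2 = 128

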
class ClimateDisinformationModel(nn.Module):
    """Model for climate disinformation classification."""

    def __init__(self,
                 num_classes: int,
                 base_model_name: str = "google/mobilebert-uncased",
                 num_numerical_features: int = 10,
                 categorical_feature_dims: Optional[Dict[str, int]] = None,
                 device: Optional[torch.device] = None):
        super().__init__()

        self.device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        if categorical_feature_dims is None:
            categorical_feature_dims = {}

        try:
            # Pretrained transformer backbone for the text input.
            self.text_encoder = AutoModel.from_pretrained(base_model_name)
            hidden_size = self.text_encoder.config.hidden_size

            # Encoder for the numerical/categorical metadata.
            self.feature_encoder = FeatureEncoder(
                num_numerical_features,
                categorical_feature_dims
            )

            # Metadata-conditioned attention pooling over the token embeddings.
            self.metadata_attention = MetadataAttention(
                text_dim=hidden_size,
                metadata_dim=self.feature_encoder.output_dim
            )

            # Classification head over the concatenated text and metadata representations.
            combined_dim = hidden_size + self.feature_encoder.output_dim
            self.classifier = nn.Sequential(
                nn.Linear(combined_dim, 256),
                nn.LayerNorm(256),
                nn.ReLU(),
                nn.Dropout(0.1),
                nn.Linear(256, num_classes)
            )

            self.criterion = nn.CrossEntropyLoss()

            self.to(self.device)

        except Exception as e:
            logger.error(f"Error initializing model: {str(e)}")
            raise

    def set_class_weights(self, class_weights: torch.Tensor):
        """Set class weights for the loss function."""
        try:
            self.criterion = nn.CrossEntropyLoss(weight=class_weights.to(self.device))
        except Exception as e:
            logger.error(f"Error setting class weights: {str(e)}")
            raise

    def forward(self,
                input_ids: torch.Tensor,
                attention_mask: torch.Tensor,
                numerical_features: torch.Tensor,
                categorical_features: Dict[str, torch.Tensor],
                labels: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
        try:
            # Token-level text representations from the transformer backbone.
            text_outputs = self.text_encoder(
                input_ids=input_ids,
                attention_mask=attention_mask
            )
            text_features = text_outputs.last_hidden_state

            # Metadata embedding from the numerical and categorical inputs.
            feature_embedding = self.feature_encoder(
                numerical_features,
                categorical_features
            )

            # Pool the token features with metadata-conditioned attention.
            text_features = self.metadata_attention(
                text_features,
                feature_embedding
            )

            # Concatenate pooled text and metadata embeddings, then classify.
            combined_embedding = torch.cat([text_features, feature_embedding], dim=1)
            logits = self.classifier(combined_embedding)

            outputs = {"logits": logits}
            if labels is not None:
                outputs["loss"] = self.criterion(logits, labels)

            return outputs

        except Exception as e:
            logger.error(f"Error in forward pass: {str(e)}")
            raise


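# Illustrative note on class weighting (hedged sketch; `train_labels` is a hypothetical
# 1-D LongTensor of class indices, not something defined in this module; inverse-frequency
# weights are one common choice, not the project's prescribed scheme):
#
#   counts = torch.bincount(train_labels, minlength=num_classes).float()
#   weights = counts.sum() / (num_classes * counts.clamp(min=1))
#   model.set_class_weights(weights)

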
class ModelWrapper:
    """Wrapper for model management and inference."""

    def __init__(self,
                 num_classes: int,
                 base_model_name: str = "google/mobilebert-uncased",
                 num_numerical_features: int = 10,
                 categorical_feature_dims: Optional[Dict[str, int]] = None,
                 device: Optional[torch.device] = None):

        self.device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        logger.info(f"Using device: {self.device}")

        try:
            self.tokenizer = AutoTokenizer.from_pretrained(base_model_name)
            self.model = ClimateDisinformationModel(
                num_classes=num_classes,
                base_model_name=base_model_name,
                num_numerical_features=num_numerical_features,
                categorical_feature_dims=categorical_feature_dims,
                device=self.device
            )
        except Exception as e:
            logger.error(f"Error initializing ModelWrapper: {str(e)}")
            raise

    def train_step(self,
                   batch: Dict[str, torch.Tensor],
                   optimizer: torch.optim.Optimizer) -> Tuple[torch.Tensor, torch.Tensor]:
        """Single training step. The batch is expected to already be on self.device."""
        try:
            self.model.train()
            optimizer.zero_grad()

            outputs = self.model(**batch)
            loss = outputs["loss"]

            loss.backward()
            optimizer.step()

            return loss, outputs["logits"]

        except Exception as e:
            logger.error(f"Error in training step: {str(e)}")
            raise

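    # Illustrative training-loop sketch (hedged; `train_loader` is an assumption about the
    # caller and must yield batches already placed on wrapper.device):
    #
    #   optimizer = torch.optim.AdamW(wrapper.model.parameters(), lr=2e-5)
    #   for batch in train_loader:
    #       loss, logits = wrapper.train_step(batch, optimizer)
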
    def predict(self,
                texts: List[str],
                numerical_features: np.ndarray,
                categorical_features: Dict[str, np.ndarray]) -> np.ndarray:
        """Batch prediction. Returns class probabilities of shape (len(texts), num_classes)."""
        try:
            self.model.eval()

            inputs = self.tokenizer(
                texts,
                return_tensors="pt",
                max_length=128,
                truncation=True,
                padding="max_length"
            )

            # Move all inputs to the model device.
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            num_features = torch.FloatTensor(numerical_features).to(self.device)
            cat_features = {
                k: torch.LongTensor(v).to(self.device)
                for k, v in categorical_features.items()
            }

            with torch.no_grad():
                outputs = self.model(
                    input_ids=inputs["input_ids"],
                    attention_mask=inputs["attention_mask"],
                    numerical_features=num_features,
                    categorical_features=cat_features
                )

            return F.softmax(outputs["logits"], dim=1).cpu().numpy()

        except Exception as e:
            logger.error(f"Error in prediction: {str(e)}")
            raise
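

if __name__ == "__main__":
    # Minimal smoke-test sketch (illustrative only; the label count, feature sizes, and the
    # "source_type" categorical feature below are assumptions, not a schema defined elsewhere).
    wrapper = ModelWrapper(
        num_classes=3,
        num_numerical_features=10,
        categorical_feature_dims={"source_type": 5},
    )

    texts = [
        "Climate change is a hoax invented to raise taxes.",
        "Global average temperatures have risen over the last century.",
    ]
    numerical = np.random.rand(len(texts), 10).astype(np.float32)
    categorical = {"source_type": np.array([0, 3])}

    probs = wrapper.predict(texts, numerical, categorical)
    print(probs.shape)  # (2, 3): one probability distribution per input text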