Delete climate_model.py
climate_model.py  (DELETED, +0 -272)
@@ -1,272 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer
from typing import Dict, List, Optional, Tuple
import numpy as np
import logging

logger = logging.getLogger(__name__)

class MetadataAttention(nn.Module):
    """Attention mechanism for combining text and metadata features"""
    def __init__(self, text_dim: int, metadata_dim: int):
        super().__init__()
        self.text_linear = nn.Linear(text_dim, 64)
        self.metadata_linear = nn.Linear(metadata_dim, 64)
        self.attention = nn.Sequential(
            nn.Linear(128, 64),
            nn.Tanh(),
            nn.Linear(64, 1),
            nn.Softmax(dim=1)
        )

    def forward(self, text_features: torch.Tensor, metadata_features: torch.Tensor) -> torch.Tensor:
        text_proj = self.text_linear(text_features)
        meta_proj = self.metadata_linear(metadata_features)
        meta_proj = meta_proj.unsqueeze(1).expand(-1, text_proj.size(1), -1)
        combined = torch.cat([text_proj, meta_proj], dim=-1)
        weights = self.attention(combined)
        weighted_sum = (text_features * weights).sum(dim=1)
        return weighted_sum

class FeatureEncoder(nn.Module):
    """Encodes numerical and categorical features"""
    def __init__(self, num_numerical_features: int, categorical_feature_dims: Dict[str, int]):
        super().__init__()

        # Numerical features
        self.numerical_bn = nn.BatchNorm1d(num_numerical_features)
        self.numerical_encoder = nn.Sequential(
            nn.Linear(num_numerical_features, 64),
            nn.LayerNorm(64),
            nn.ReLU(),
            nn.Dropout(0.2)
        )

        # Categorical features
        self.categorical_encoders = nn.ModuleDict()
        self.categorical_dims = {}
        for feature_name, dim in categorical_feature_dims.items():
            self.categorical_encoders[feature_name] = nn.Sequential(
                nn.Embedding(dim, 32),
                nn.Linear(32, 32),
                nn.ReLU()
            )
            self.categorical_dims[feature_name] = dim

        self.output_dim = 64 + 32 * len(categorical_feature_dims)

    def forward(self, numerical_features: torch.Tensor,
                categorical_features: Dict[str, torch.Tensor]) -> torch.Tensor:
        x_num = self.numerical_bn(numerical_features)
        x_num = self.numerical_encoder(x_num)

        x_cat_list = []
        for feature_name, encoder in self.categorical_encoders.items():
            if feature_name in categorical_features:
                x_cat = encoder(categorical_features[feature_name])
                x_cat_list.append(x_cat)

        if x_cat_list:
            x_cat = torch.cat(x_cat_list, dim=1)
            return torch.cat([x_num, x_cat], dim=1)
        return x_num

class ClimateDisinformationModel(nn.Module):
    """Model for climate disinformation classification"""
    def __init__(self,
                 num_classes: int,
                 base_model_name: str = "google/mobilebert-uncased",
                 num_numerical_features: int = 10,
                 categorical_feature_dims: Optional[Dict[str, int]] = None,
                 device: Optional[torch.device] = None):
        super().__init__()

        self.device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        if categorical_feature_dims is None:
            categorical_feature_dims = {}

        try:
            # Text encoder
            self.text_encoder = AutoModel.from_pretrained(base_model_name)
            hidden_size = self.text_encoder.config.hidden_size

            # Feature processing
            self.feature_encoder = FeatureEncoder(
                num_numerical_features,
                categorical_feature_dims
            )

            # Metadata attention
            self.metadata_attention = MetadataAttention(
                text_dim=hidden_size,
                metadata_dim=self.feature_encoder.output_dim
            )

            # Classifier
            combined_dim = hidden_size + self.feature_encoder.output_dim
            self.classifier = nn.Sequential(
                nn.Linear(combined_dim, 256),
                nn.LayerNorm(256),
                nn.ReLU(),
                nn.Dropout(0.1),
                nn.Linear(256, num_classes)
            )

            # Loss function (will be set by set_class_weights)
            self.criterion = nn.CrossEntropyLoss()

            # Move to device
            self.to(self.device)

        except Exception as e:
            logger.error(f"Error initializing model: {str(e)}")
            raise

    def set_class_weights(self, class_weights: torch.Tensor):
        """Set class weights for loss function"""
        try:
            self.criterion = nn.CrossEntropyLoss(weight=class_weights.to(self.device))
        except Exception as e:
            logger.error(f"Error setting class weights: {str(e)}")
            raise

    def forward(self,
                input_ids: torch.Tensor,
                attention_mask: torch.Tensor,
                numerical_features: torch.Tensor,
                categorical_features: Dict[str, torch.Tensor],
                labels: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
        try:
            # Get text features
            text_outputs = self.text_encoder(
                input_ids=input_ids,
                attention_mask=attention_mask
            )
            text_features = text_outputs.last_hidden_state

            # Get enhanced features
            feature_embedding = self.feature_encoder(
                numerical_features,
                categorical_features
            )

            # Apply metadata attention
            text_features = self.metadata_attention(
                text_features,
                feature_embedding
            )

            # Combine features
            combined_embedding = torch.cat([text_features, feature_embedding], dim=1)

            # Get logits
            logits = self.classifier(combined_embedding)

            # Prepare output dict
            outputs = {"logits": logits}

            # Calculate loss if labels provided
            if labels is not None:
                outputs["loss"] = self.criterion(logits, labels)

            return outputs

        except Exception as e:
            logger.error(f"Error in forward pass: {str(e)}")
            raise

class ModelWrapper:
    """Wrapper for model management and inference"""
    def __init__(self,
                 num_classes: int,
                 base_model_name: str = "google/mobilebert-uncased",
                 num_numerical_features: int = 10,
                 categorical_feature_dims: Optional[Dict[str, int]] = None,
                 device: Optional[torch.device] = None):

        self.device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        logger.info(f"Using device: {self.device}")

        try:
            # Initialize tokenizer and model
            self.tokenizer = AutoTokenizer.from_pretrained(base_model_name)
            self.model = ClimateDisinformationModel(
                num_classes=num_classes,
                base_model_name=base_model_name,
                num_numerical_features=num_numerical_features,
                categorical_feature_dims=categorical_feature_dims,
                device=self.device
            )
        except Exception as e:
            logger.error(f"Error initializing ModelWrapper: {str(e)}")
            raise

    def train_step(self,
                   batch: Dict[str, torch.Tensor],
                   optimizer: torch.optim.Optimizer) -> Tuple[torch.Tensor, torch.Tensor]:
        """Single training step"""
        try:
            # Set model to training mode
            self.model.train()

            # Zero gradients
            optimizer.zero_grad()

            # Forward pass
            outputs = self.model(**batch)
            loss = outputs["loss"]

            # Backward pass
            loss.backward()

            # Optimizer step
            optimizer.step()

            return loss, outputs["logits"]

        except Exception as e:
            logger.error(f"Error in training step: {str(e)}")
            raise

    def predict(self,
                texts: List[str],
                numerical_features: np.ndarray,
                categorical_features: Dict[str, np.ndarray]) -> np.ndarray:
        """Batch prediction"""
        try:
            self.model.eval()

            # Prepare inputs
            inputs = self.tokenizer(
                texts,
                return_tensors="pt",
                max_length=128,
                truncation=True,
                padding="max_length"
            )

            # Move inputs to device
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            num_features = torch.FloatTensor(numerical_features).to(self.device)
            cat_features = {
                k: torch.LongTensor(v).to(self.device)
                for k, v in categorical_features.items()
            }

            # Get predictions
            with torch.no_grad():
                outputs = self.model(
                    input_ids=inputs["input_ids"],
                    attention_mask=inputs["attention_mask"],
                    numerical_features=num_features,
                    categorical_features=cat_features
                )

            return F.softmax(outputs["logits"], dim=1).cpu().numpy()

        except Exception as e:
            logger.error(f"Error in prediction: {str(e)}")
            raise
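For context on what this deletion removes, here is a minimal usage sketch of the deleted `ModelWrapper` API. It is not code from this repository: the class count, feature shapes, the `source` categorical feature, and the `AdamW` hyperparameters are illustrative assumptions.

import numpy as np
import torch

from climate_model import ModelWrapper  # the module deleted above

# Hypothetical setup: 3 classes, 10 numerical features, and one
# categorical feature ("source") with 5 possible values.
wrapper = ModelWrapper(
    num_classes=3,
    num_numerical_features=10,
    categorical_feature_dims={"source": 5},
)

texts = ["Glaciers are growing worldwide.", "CO2 drives recent warming."]

# One training step on a toy batch; keys mirror the model's forward().
enc = wrapper.tokenizer(
    texts, return_tensors="pt", max_length=128,
    truncation=True, padding="max_length",
)
batch = {
    "input_ids": enc["input_ids"].to(wrapper.device),
    "attention_mask": enc["attention_mask"].to(wrapper.device),
    "numerical_features": torch.zeros(2, 10, device=wrapper.device),
    "categorical_features": {"source": torch.tensor([0, 3], device=wrapper.device)},
    "labels": torch.tensor([1, 0], device=wrapper.device),
}
optimizer = torch.optim.AdamW(wrapper.model.parameters(), lr=2e-5)
loss, logits = wrapper.train_step(batch, optimizer)

# Batch inference: returns softmax probabilities of shape (2, 3).
probs = wrapper.predict(
    texts,
    numerical_features=np.zeros((2, 10), dtype=np.float32),
    categorical_features={"source": np.array([0, 3], dtype=np.int64)},
)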