Zen0 committed
Commit 36d97aa · verified · 1 Parent(s): 8ea6ec0

Delete climate_model.py

Files changed (1)
  1. climate_model.py +0 -272
climate_model.py DELETED
@@ -1,272 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer
from typing import Dict, List, Optional, Tuple
import numpy as np
import logging

logger = logging.getLogger(__name__)

class MetadataAttention(nn.Module):
    """Attention mechanism for combining text and metadata features"""
    def __init__(self, text_dim: int, metadata_dim: int):
        super().__init__()
        self.text_linear = nn.Linear(text_dim, 64)
        self.metadata_linear = nn.Linear(metadata_dim, 64)
        self.attention = nn.Sequential(
            nn.Linear(128, 64),
            nn.Tanh(),
            nn.Linear(64, 1),
            nn.Softmax(dim=1)
        )

    def forward(self, text_features: torch.Tensor, metadata_features: torch.Tensor) -> torch.Tensor:
        text_proj = self.text_linear(text_features)
        meta_proj = self.metadata_linear(metadata_features)
        meta_proj = meta_proj.unsqueeze(1).expand(-1, text_proj.size(1), -1)
        combined = torch.cat([text_proj, meta_proj], dim=-1)
        weights = self.attention(combined)
        weighted_sum = (text_features * weights).sum(dim=1)
        return weighted_sum

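# Shape walk-through for MetadataAttention.forward:
#   text_features:     (batch, seq_len, text_dim)  -- per-token embeddings
#   metadata_features: (batch, metadata_dim)       -- one vector per example
# The metadata projection is broadcast along seq_len, concatenated with the
# token projections, and scored per token; Softmax(dim=1) normalizes scores
# over seq_len, so the output is an attention-pooled (batch, text_dim)
# summary of the tokens conditioned on the metadata. Note the weights ignore
# the padding mask, so PAD positions can receive nonzero attention.
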
class FeatureEncoder(nn.Module):
    """Encodes numerical and categorical features"""
    def __init__(self, num_numerical_features: int, categorical_feature_dims: Dict[str, int]):
        super().__init__()

        # Numerical features
        self.numerical_bn = nn.BatchNorm1d(num_numerical_features)
        self.numerical_encoder = nn.Sequential(
            nn.Linear(num_numerical_features, 64),
            nn.LayerNorm(64),
            nn.ReLU(),
            nn.Dropout(0.2)
        )

        # Categorical features
        self.categorical_encoders = nn.ModuleDict()
        self.categorical_dims = {}
        for feature_name, dim in categorical_feature_dims.items():
            self.categorical_encoders[feature_name] = nn.Sequential(
                nn.Embedding(dim, 32),
                nn.Linear(32, 32),
                nn.ReLU()
            )
            self.categorical_dims[feature_name] = dim

        self.output_dim = 64 + 32 * len(categorical_feature_dims)

    def forward(self, numerical_features: torch.Tensor,
                categorical_features: Dict[str, torch.Tensor]) -> torch.Tensor:
        x_num = self.numerical_bn(numerical_features)
        x_num = self.numerical_encoder(x_num)

        x_cat_list = []
        for feature_name, encoder in self.categorical_encoders.items():
            if feature_name in categorical_features:
                x_cat = encoder(categorical_features[feature_name])
                x_cat_list.append(x_cat)

        if x_cat_list:
            x_cat = torch.cat(x_cat_list, dim=1)
            return torch.cat([x_num, x_cat], dim=1)
        return x_num

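# Caveat: output_dim is fixed at 64 + 32 * len(categorical_feature_dims), but
# forward() silently skips any categorical feature missing from the input
# dict. If a declared feature is absent at runtime, the concatenated width no
# longer matches output_dim and the downstream classifier fails with a shape
# error, so callers must supply every declared categorical feature.
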
class ClimateDisinformationModel(nn.Module):
    """Model for climate disinformation classification"""
    def __init__(self,
                 num_classes: int,
                 base_model_name: str = "google/mobilebert-uncased",
                 num_numerical_features: int = 10,
                 categorical_feature_dims: Optional[Dict[str, int]] = None,
                 device: Optional[torch.device] = None):
        super().__init__()

        self.device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        if categorical_feature_dims is None:
            categorical_feature_dims = {}

        try:
            # Text encoder
            self.text_encoder = AutoModel.from_pretrained(base_model_name)
            hidden_size = self.text_encoder.config.hidden_size

            # Feature processing
            self.feature_encoder = FeatureEncoder(
                num_numerical_features,
                categorical_feature_dims
            )

            # Metadata attention
            self.metadata_attention = MetadataAttention(
                text_dim=hidden_size,
                metadata_dim=self.feature_encoder.output_dim
            )

            # Classifier
            combined_dim = hidden_size + self.feature_encoder.output_dim
            self.classifier = nn.Sequential(
                nn.Linear(combined_dim, 256),
                nn.LayerNorm(256),
                nn.ReLU(),
                nn.Dropout(0.1),
                nn.Linear(256, num_classes)
            )

            # Loss function (will be set by set_class_weights)
            self.criterion = nn.CrossEntropyLoss()

            # Move to device
            self.to(self.device)

        except Exception as e:
            logger.error(f"Error initializing model: {str(e)}")
            raise

    def set_class_weights(self, class_weights: torch.Tensor):
        """Set class weights for loss function"""
        try:
            self.criterion = nn.CrossEntropyLoss(weight=class_weights.to(self.device))
        except Exception as e:
            logger.error(f"Error setting class weights: {str(e)}")
            raise

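    # Hypothetical usage (train_labels is an assumption, not part of this
    # file): derive standard inverse-frequency weights from label counts
    # before training, e.g.
    #   counts = torch.bincount(train_labels, minlength=num_classes).float()
    #   model.set_class_weights(counts.sum() / (num_classes * counts))
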
    def forward(self,
                input_ids: torch.Tensor,
                attention_mask: torch.Tensor,
                numerical_features: torch.Tensor,
                categorical_features: Dict[str, torch.Tensor],
                labels: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
        try:
            # Get text features
            text_outputs = self.text_encoder(
                input_ids=input_ids,
                attention_mask=attention_mask
            )
            text_features = text_outputs.last_hidden_state

            # Get enhanced features
            feature_embedding = self.feature_encoder(
                numerical_features,
                categorical_features
            )

            # Apply metadata attention
            text_features = self.metadata_attention(
                text_features,
                feature_embedding
            )

            # Combine features
            combined_embedding = torch.cat([text_features, feature_embedding], dim=1)

            # Get logits
            logits = self.classifier(combined_embedding)

            # Prepare output dict
            outputs = {"logits": logits}

            # Calculate loss if labels provided
            if labels is not None:
                outputs["loss"] = self.criterion(logits, labels)

            return outputs

        except Exception as e:
            logger.error(f"Error in forward pass: {str(e)}")
            raise

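# The returned dict always carries "logits" of shape (batch, num_classes) and
# adds "loss" only when labels are passed; ModelWrapper.train_step below
# relies on both keys, so training batches must include a "labels" tensor.
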
class ModelWrapper:
    """Wrapper for model management and inference"""
    def __init__(self,
                 num_classes: int,
                 base_model_name: str = "google/mobilebert-uncased",
                 num_numerical_features: int = 10,
                 categorical_feature_dims: Optional[Dict[str, int]] = None,
                 device: Optional[torch.device] = None):

        self.device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        logger.info(f"Using device: {self.device}")

        try:
            # Initialize tokenizer and model
            self.tokenizer = AutoTokenizer.from_pretrained(base_model_name)
            self.model = ClimateDisinformationModel(
                num_classes=num_classes,
                base_model_name=base_model_name,
                num_numerical_features=num_numerical_features,
                categorical_feature_dims=categorical_feature_dims,
                device=self.device
            )
        except Exception as e:
            logger.error(f"Error initializing ModelWrapper: {str(e)}")
            raise

    def train_step(self,
                   batch: Dict[str, torch.Tensor],
                   optimizer: torch.optim.Optimizer) -> Tuple[torch.Tensor, torch.Tensor]:
        """Single training step; expects batch tensors already on self.device"""
        try:
            # Set model to training mode
            self.model.train()

            # Zero gradients
            optimizer.zero_grad()

            # Forward pass
            outputs = self.model(**batch)
            loss = outputs["loss"]

            # Backward pass
            loss.backward()

            # Optimizer step
            optimizer.step()

            return loss, outputs["logits"]

        except Exception as e:
            logger.error(f"Error in training step: {str(e)}")
            raise

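    # Hypothetical training loop (wrapper, train_loader, and the learning
    # rate are assumptions, not part of this file); batch keys must match
    # forward(): input_ids, attention_mask, numerical_features,
    # categorical_features, labels.
    #   optimizer = torch.optim.AdamW(wrapper.model.parameters(), lr=2e-5)
    #   for batch in train_loader:
    #       loss, logits = wrapper.train_step(batch, optimizer)
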
    def predict(self,
                texts: List[str],
                numerical_features: np.ndarray,
                categorical_features: Dict[str, np.ndarray]) -> np.ndarray:
        """Batch prediction"""
        try:
            self.model.eval()

            # Prepare inputs
            inputs = self.tokenizer(
                texts,
                return_tensors="pt",
                max_length=128,
                truncation=True,
                padding="max_length"
            )

            # Move inputs to device
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            num_features = torch.FloatTensor(numerical_features).to(self.device)
            cat_features = {
                k: torch.LongTensor(v).to(self.device)
                for k, v in categorical_features.items()
            }

            # Get predictions
            with torch.no_grad():
                outputs = self.model(
                    input_ids=inputs["input_ids"],
                    attention_mask=inputs["attention_mask"],
                    numerical_features=num_features,
                    categorical_features=cat_features
                )

            return F.softmax(outputs["logits"], dim=1).cpu().numpy()

        except Exception as e:
            logger.error(f"Error in prediction: {str(e)}")
            raise
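
For reference, a minimal end-to-end sketch of how the deleted module could be driven (the class count, feature name, and vocabulary size below are assumptions, not values from this repository):

# Hypothetical usage of the deleted climate_model.py
from climate_model import ModelWrapper
import numpy as np

wrapper = ModelWrapper(
    num_classes=3,                            # assumed number of labels
    num_numerical_features=10,
    categorical_feature_dims={"source": 50},  # assumed categorical vocab size
)

probs = wrapper.predict(
    texts=["Sea levels are not rising."],
    numerical_features=np.zeros((1, 10), dtype=np.float32),
    categorical_features={"source": np.array([7])},
)
print(probs.shape)  # (1, 3): softmax probability per class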