YixuanWang commited on
Commit
28fe915
·
verified ·
1 Parent(s): 7154a46

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +182 -66
app.py CHANGED
@@ -4,71 +4,187 @@ import numpy as np
4
  import torch
5
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
6
  from textblob import TextBlob
 
 
 
 
7
 
8
- # Load the dataset from the local file
9
- data = pd.read_csv('twitter_dataset.csv')
10
-
11
- # Calculate sentiment polarity and popularity
12
- data['Sentiment'] = data['Text'].apply(lambda x: TextBlob(x).sentiment.polarity)
13
- data['Popularity'] = data['Retweets'] + data['Likes']
14
- data['Popularity'] = (data['Popularity'] - data['Popularity'].mean()) / data['Popularity'].std()
15
- data['Popularity'] = data['Popularity'] / data['Popularity'].abs().max()
16
-
17
- # Load the fake news classification model
18
- model_name = "hamzab/roberta-fake-news-classification"
19
- tokenizer = AutoTokenizer.from_pretrained(model_name)
20
- model = AutoModelForSequenceClassification.from_pretrained(model_name)
21
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
22
- model = model.to(device)
23
-
24
- # Process tweets in batches to avoid memory issues
25
- batch_size = 100
26
- predictions = []
27
- for i in range(0, len(data), batch_size):
28
- batch = data['Text'][i:i + batch_size].tolist()
29
- inputs = tokenizer(batch, return_tensors="pt", padding=True, truncation=True, max_length=128)
30
- inputs = {key: val.to(device) for key, val in inputs.items()}
31
- with torch.no_grad():
32
- outputs = model(**inputs)
33
- predictions.extend(outputs.logits.argmax(dim=1).cpu().numpy())
34
-
35
- data['Fake_News_Prediction'] = predictions
36
- data['Credibility'] = data['Fake_News_Prediction'].apply(lambda x: 1 if x == 1 else -1)
37
-
38
- # Define the prediction and recommendation function
39
- def predict_and_recommend(visibility_weight, sentiment_weight, popularity_weight):
40
- # Adjust weights and calculate the final score
41
- total_weight = visibility_weight + sentiment_weight + popularity_weight
42
- visibility_weight /= total_weight
43
- sentiment_weight /= total_weight
44
- popularity_weight /= total_weight
45
-
46
- # Update final visibility score with user-defined weights
47
- data['User_Final_Visibility_Score'] = (
48
- data['Credibility'] * visibility_weight +
49
- data['Sentiment'] * sentiment_weight +
50
- data['Popularity'] * popularity_weight
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  )
52
- # Sort and randomly sample 10 recommendations
53
- top_100_data = data.nlargest(100, 'User_Final_Visibility_Score')
54
- recommended_data = top_100_data.sample(10)
55
-
56
- # Format output with empty lines between tweets
57
- output = "\n\n".join(f"**Tweet**: {row['Text']}\n**Score**: {row['User_Final_Visibility_Score']:.2f}"
58
- for _, row in recommended_data.iterrows())
59
- return output
60
-
61
- # Set up Gradio interface
62
- iface = gr.Interface(
63
- fn=predict_and_recommend,
64
- inputs=[
65
- gr.Slider(0, 1, 0.5, label="Visibility Weight"),
66
- gr.Slider(0, 1, 0.3, label="Sentiment Weight"),
67
- gr.Slider(0, 1, 0.2, label="Popularity Weight")
68
- ],
69
- outputs="markdown",
70
- title="Customizable Fake News Recommendation System",
71
- description="Adjust weights to receive customized tweet recommendations based on visibility, sentiment, and popularity."
72
- )
73
-
74
- iface.launch()
 
4
  import torch
5
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
6
  from textblob import TextBlob
7
+ from typing import List, Dict, Tuple
8
+ from dataclasses import dataclass
9
+ from pathlib import Path
10
+ import logging
11
 
12
+ # Configure logging
13
+ logging.basicConfig(level=logging.INFO)
14
+ logger = logging.getLogger(__name__)
15
+
16
+ @dataclass
17
+ class RecommendationWeights:
18
+ visibility: float
19
+ sentiment: float
20
+ popularity: float
21
+
22
+ class TweetPreprocessor:
23
+ def __init__(self, data_path: Path):
24
+ """Initialize the preprocessor with data path."""
25
+ self.data = self._load_data(data_path)
26
+
27
+ @staticmethod
28
+ def _load_data(data_path: Path) -> pd.DataFrame:
29
+ """Load and validate the dataset."""
30
+ try:
31
+ data = pd.read_csv(data_path)
32
+ required_columns = {'Text', 'Retweets', 'Likes'}
33
+ if not required_columns.issubset(data.columns):
34
+ raise ValueError(f"Missing required columns: {required_columns - set(data.columns)}")
35
+ return data
36
+ except Exception as e:
37
+ logger.error(f"Error loading data: {e}")
38
+ raise
39
+
40
+ def calculate_metrics(self) -> pd.DataFrame:
41
+ """Calculate sentiment and popularity metrics."""
42
+ self.data['Sentiment'] = self.data['Text'].apply(self._get_sentiment)
43
+ self.data['Popularity'] = self._normalize_popularity()
44
+ return self.data
45
+
46
+ @staticmethod
47
+ def _get_sentiment(text: str) -> float:
48
+ """Calculate sentiment polarity for a text."""
49
+ try:
50
+ return TextBlob(str(text)).sentiment.polarity
51
+ except Exception as e:
52
+ logger.warning(f"Error calculating sentiment: {e}")
53
+ return 0.0
54
+
55
+ def _normalize_popularity(self) -> pd.Series:
56
+ """Normalize popularity scores using min-max scaling."""
57
+ popularity = self.data['Retweets'] + self.data['Likes']
58
+ return (popularity - popularity.mean()) / (popularity.std() or 1)
59
+
60
+ class FakeNewsClassifier:
61
+ def __init__(self, model_name: str):
62
+ """Initialize the fake news classifier."""
63
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
64
+ self.model_name = model_name
65
+ self.model, self.tokenizer = self._load_model()
66
+
67
+ def _load_model(self) -> Tuple[AutoModelForSequenceClassification, AutoTokenizer]:
68
+ """Load the model and tokenizer."""
69
+ try:
70
+ tokenizer = AutoTokenizer.from_pretrained(self.model_name)
71
+ model = AutoModelForSequenceClassification.from_pretrained(self.model_name).to(self.device)
72
+ return model, tokenizer
73
+ except Exception as e:
74
+ logger.error(f"Error loading model: {e}")
75
+ raise
76
+
77
+ @torch.no_grad()
78
+ def predict_batch(self, texts: List[str], batch_size: int = 100) -> np.ndarray:
79
+ """Predict fake news probability for a batch of texts."""
80
+ predictions = []
81
+
82
+ for i in range(0, len(texts), batch_size):
83
+ batch_texts = texts[i:i + batch_size]
84
+ inputs = self.tokenizer(
85
+ batch_texts,
86
+ return_tensors="pt",
87
+ padding=True,
88
+ truncation=True,
89
+ max_length=128
90
+ ).to(self.device)
91
+
92
+ outputs = self.model(**inputs)
93
+ batch_predictions = outputs.logits.argmax(dim=1).cpu().numpy()
94
+ predictions.extend(batch_predictions)
95
+
96
+ return np.array(predictions)
97
+
98
+ class RecommendationSystem:
99
+ def __init__(self, data_path: Path, model_name: str):
100
+ """Initialize the recommendation system."""
101
+ self.preprocessor = TweetPreprocessor(data_path)
102
+ self.classifier = FakeNewsClassifier(model_name)
103
+ self.data = None
104
+ self.setup_system()
105
+
106
+ def setup_system(self):
107
+ """Set up the recommendation system."""
108
+ self.data = self.preprocessor.calculate_metrics()
109
+ predictions = self.classifier.predict_batch(self.data['Text'].tolist())
110
+ self.data['Credibility'] = [1 if pred == 1 else -1 for pred in predictions]
111
+
112
+ def get_recommendations(self, weights: RecommendationWeights, num_recommendations: int = 10) -> str:
113
+ """Get tweet recommendations based on weights."""
114
+ if not self._validate_weights(weights):
115
+ return "Error: Invalid weights provided"
116
+
117
+ normalized_weights = self._normalize_weights(weights)
118
+
119
+ self.data['Final_Score'] = (
120
+ self.data['Credibility'] * normalized_weights.visibility +
121
+ self.data['Sentiment'] * normalized_weights.sentiment +
122
+ self.data['Popularity'] * normalized_weights.popularity
123
+ )
124
+
125
+ top_recommendations = (
126
+ self.data.nlargest(100, 'Final_Score')
127
+ .sample(num_recommendations)
128
+ )
129
+
130
+ return self._format_recommendations(top_recommendations)
131
+
132
+ @staticmethod
133
+ def _validate_weights(weights: RecommendationWeights) -> bool:
134
+ """Validate that weights are non-negative."""
135
+ return all(getattr(weights, field) >= 0 for field in weights.__dataclass_fields__)
136
+
137
+ @staticmethod
138
+ def _normalize_weights(weights: RecommendationWeights) -> RecommendationWeights:
139
+ """Normalize weights to sum to 1."""
140
+ total = weights.visibility + weights.sentiment + weights.popularity
141
+ if total == 0:
142
+ return RecommendationWeights(1/3, 1/3, 1/3)
143
+ return RecommendationWeights(
144
+ visibility=weights.visibility / total,
145
+ sentiment=weights.sentiment / total,
146
+ popularity=weights.popularity / total
147
+ )
148
+
149
+ @staticmethod
150
+ def _format_recommendations(recommendations: pd.DataFrame) -> str:
151
+ """Format recommendations for display."""
152
+ return "\n\n".join(
153
+ f"**Tweet**: {row['Text']}\n**Score**: {row['Final_Score']:.2f}"
154
+ for _, row in recommendations.iterrows()
155
+ )
156
+
157
+ def create_gradio_interface(recommendation_system: RecommendationSystem) -> gr.Interface:
158
+ """Create and configure the Gradio interface."""
159
+ def predict_and_recommend(visibility_weight, sentiment_weight, popularity_weight):
160
+ weights = RecommendationWeights(visibility_weight, sentiment_weight, popularity_weight)
161
+ return recommendation_system.get_recommendations(weights)
162
+
163
+ return gr.Interface(
164
+ fn=predict_and_recommend,
165
+ inputs=[
166
+ gr.Slider(0, 1, 0.5, label="Visibility Weight"),
167
+ gr.Slider(0, 1, 0.3, label="Sentiment Weight"),
168
+ gr.Slider(0, 1, 0.2, label="Popularity Weight")
169
+ ],
170
+ outputs="markdown",
171
+ title="Enhanced Fake News Recommendation System",
172
+ description="Adjust weights to receive customized tweet recommendations based on visibility, sentiment, and popularity.",
173
+ theme="default"
174
  )
175
+
176
+ def main():
177
+ """Main function to run the application."""
178
+ try:
179
+ recommendation_system = RecommendationSystem(
180
+ data_path=Path('twitter_dataset.csv'),
181
+ model_name="hamzab/roberta-fake-news-classification"
182
+ )
183
+ iface = create_gradio_interface(recommendation_system)
184
+ iface.launch()
185
+ except Exception as e:
186
+ logger.error(f"Application failed to start: {e}")
187
+ raise
188
+
189
+ if __name__ == "__main__":
190
+ main()