YixuanWang commited on
Commit
37190a8
1 Parent(s): 596f852

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -55
app.py CHANGED
@@ -1,13 +1,13 @@
1
  import gradio as gr
2
  import pandas as pd
3
  import numpy as np
 
 
4
  from textblob import TextBlob
5
  from typing import List, Dict, Tuple
6
  from dataclasses import dataclass
7
  from pathlib import Path
8
  import logging
9
- import re
10
- from datetime import datetime
11
 
12
  logging.basicConfig(level=logging.INFO)
13
  logger = logging.getLogger(__name__)
@@ -20,12 +20,18 @@ class RecommendationWeights:
20
 
21
  class TweetPreprocessor:
22
  def __init__(self, data_path: Path):
23
- """Initialize the preprocessor with data path."""
24
  self.data = self._load_data(data_path)
 
 
 
 
 
 
 
 
25
 
26
  @staticmethod
27
  def _load_data(data_path: Path) -> pd.DataFrame:
28
- """Load and validate the dataset."""
29
  try:
30
  data = pd.read_csv(data_path)
31
  required_columns = {'Text', 'Retweets', 'Likes'}
@@ -36,39 +42,28 @@ class TweetPreprocessor:
36
  logger.error(f"Error loading data: {e}")
37
  raise
38
 
39
- def _clean_text(self, text: str) -> str:
40
- """Clean text content."""
41
- if pd.isna(text) or len(str(text).strip()) < 10:
42
- return ""
43
-
44
- text = re.sub(r'http\S+|www.\S+', '', str(text))
45
- text = re.sub(r'[^\w\s]', '', text)
46
- text = ' '.join(text.split())
47
- return text
48
-
49
  def calculate_metrics(self) -> pd.DataFrame:
50
- """Calculate all metrics for tweets."""
51
- self.data['Clean_Text'] = self.data['Text'].apply(self._clean_text)
52
- self.data = self.data[self.data['Clean_Text'].str.len() > 0]
53
 
54
- self.data['Sentiment'] = self.data['Clean_Text'].apply(self._get_sentiment)
55
- self.data['Popularity'] = self._normalize_popularity()
 
 
56
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  return self.data
58
-
59
- @staticmethod
60
- def _get_sentiment(text: str) -> float:
61
- """Calculate sentiment polarity for a text."""
62
- try:
63
- return TextBlob(str(text)).sentiment.polarity
64
- except Exception as e:
65
- logger.warning(f"Error calculating sentiment: {e}")
66
- return 0.0
67
-
68
- def _normalize_popularity(self) -> pd.Series:
69
- """Normalize popularity scores."""
70
- popularity = self.data['Retweets'] + self.data['Likes']
71
- return (popularity - popularity.min()) / (popularity.max() - popularity.min() + 1e-6)
72
 
73
  class RecommendationSystem:
74
  def __init__(self, data_path: Path):
@@ -77,36 +72,28 @@ class RecommendationSystem:
77
  self.setup_system()
78
 
79
  def setup_system(self):
80
- """Initialize the system with preprocessed data."""
81
  self.data = self.preprocessor.calculate_metrics()
82
 
83
- def recalculate_scores(self, weights: RecommendationWeights):
84
- """Recalculate scores based on new weights."""
 
 
85
  normalized_weights = self._normalize_weights(weights)
86
 
87
- self.data['Credibility'] = np.random.choice([0, 1], size=len(self.data), p=[0.3, 0.7])
88
-
89
  self.data['Final_Score'] = (
90
  self.data['Credibility'] * normalized_weights.visibility +
91
  self.data['Sentiment'] * normalized_weights.sentiment +
92
  self.data['Popularity'] * normalized_weights.popularity
93
  )
94
 
95
- def get_recommendations(self, weights: RecommendationWeights, num_recommendations: int = 10) -> Dict:
96
- """Get tweet recommendations based on weights."""
97
- if not self._validate_weights(weights):
98
- return {"error": "Invalid weights provided"}
99
-
100
- self.recalculate_scores(weights)
101
-
102
  top_recommendations = (
103
- self.data.nlargest(num_recommendations, 'Final_Score')
 
104
  )
105
 
106
  return self._format_recommendations(top_recommendations)
107
 
108
  def _format_recommendations(self, recommendations: pd.DataFrame) -> Dict:
109
- """Format recommendations for display."""
110
  formatted_results = []
111
  for _, row in recommendations.iterrows():
112
  score_details = {
@@ -118,7 +105,7 @@ class RecommendationSystem:
118
  }
119
 
120
  formatted_results.append({
121
- "text": row['Clean_Text'],
122
  "scores": score_details
123
  })
124
 
@@ -129,7 +116,6 @@ class RecommendationSystem:
129
 
130
  @staticmethod
131
  def _get_sentiment_label(sentiment_score: float) -> str:
132
- """Convert sentiment score to label."""
133
  if sentiment_score > 0.3:
134
  return "Positive"
135
  elif sentiment_score < -0.3:
@@ -138,12 +124,10 @@ class RecommendationSystem:
138
 
139
  @staticmethod
140
  def _validate_weights(weights: RecommendationWeights) -> bool:
141
- """Validate that weights are non-negative."""
142
  return all(getattr(weights, field) >= 0 for field in weights.__dataclass_fields__)
143
 
144
  @staticmethod
145
  def _normalize_weights(weights: RecommendationWeights) -> RecommendationWeights:
146
- """Normalize weights to sum to 1."""
147
  total = weights.visibility + weights.sentiment + weights.popularity
148
  if total == 0:
149
  return RecommendationWeights(1/3, 1/3, 1/3)
@@ -155,7 +139,6 @@ class RecommendationSystem:
155
 
156
  @staticmethod
157
  def _get_score_explanation() -> Dict[str, str]:
158
- """Provide explanation for different score components."""
159
  return {
160
  "Credibility": "Content reliability assessment",
161
  "Sentiment": "Text emotional analysis result",
@@ -163,7 +146,6 @@ class RecommendationSystem:
163
  }
164
 
165
  def create_gradio_interface(recommendation_system: RecommendationSystem) -> gr.Interface:
166
- """Create and configure the Gradio interface."""
167
  with gr.Blocks(theme=gr.themes.Soft()) as interface:
168
  gr.Markdown("""
169
  # Tweet Recommendation System
@@ -224,7 +206,6 @@ def create_gradio_interface(recommendation_system: RecommendationSystem) -> gr.I
224
  return html
225
 
226
  def get_recommendations_with_weights(v, s, p):
227
- """Get recommendations with current weights."""
228
  weights = RecommendationWeights(v, s, p)
229
  return format_recommendations(recommendation_system.get_recommendations(weights))
230
 
@@ -237,7 +218,6 @@ def create_gradio_interface(recommendation_system: RecommendationSystem) -> gr.I
237
  return interface
238
 
239
  def main():
240
- """Main function to run the application."""
241
  try:
242
  recommendation_system = RecommendationSystem(
243
  data_path=Path('twitter_dataset.csv')
 
1
  import gradio as gr
2
  import pandas as pd
3
  import numpy as np
4
+ import torch
5
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
6
  from textblob import TextBlob
7
  from typing import List, Dict, Tuple
8
  from dataclasses import dataclass
9
  from pathlib import Path
10
  import logging
 
 
11
 
12
  logging.basicConfig(level=logging.INFO)
13
  logger = logging.getLogger(__name__)
 
20
 
21
  class TweetPreprocessor:
22
  def __init__(self, data_path: Path):
 
23
  self.data = self._load_data(data_path)
24
+ self.model_name = "hamzab/roberta-fake-news-classification"
25
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
26
+ self.model, self.tokenizer = self._load_model()
27
+
28
+ def _load_model(self):
29
+ tokenizer = AutoTokenizer.from_pretrained(self.model_name)
30
+ model = AutoModelForSequenceClassification.from_pretrained(self.model_name).to(self.device)
31
+ return model, tokenizer
32
 
33
  @staticmethod
34
  def _load_data(data_path: Path) -> pd.DataFrame:
 
35
  try:
36
  data = pd.read_csv(data_path)
37
  required_columns = {'Text', 'Retweets', 'Likes'}
 
42
  logger.error(f"Error loading data: {e}")
43
  raise
44
 
 
 
 
 
 
 
 
 
 
 
45
  def calculate_metrics(self) -> pd.DataFrame:
46
+ # Calculate sentiment
47
+ self.data['Sentiment'] = self.data['Text'].apply(lambda x: TextBlob(x).sentiment.polarity)
 
48
 
49
+ # Calculate popularity
50
+ self.data['Popularity'] = self.data['Retweets'] + self.data['Likes']
51
+ self.data['Popularity'] = (self.data['Popularity'] - self.data['Popularity'].mean()) / self.data['Popularity'].std()
52
+ self.data['Popularity'] = self.data['Popularity'] / self.data['Popularity'].abs().max()
53
 
54
+ # Calculate credibility using fake news model
55
+ batch_size = 100
56
+ predictions = []
57
+ for i in range(0, len(self.data), batch_size):
58
+ batch = self.data['Text'][i:i + batch_size].tolist()
59
+ inputs = self.tokenizer(batch, return_tensors="pt", padding=True, truncation=True, max_length=128)
60
+ inputs = {key: val.to(self.device) for key, val in inputs.items()}
61
+ with torch.no_grad():
62
+ outputs = self.model(**inputs)
63
+ predictions.extend(outputs.logits.argmax(dim=1).cpu().numpy())
64
+
65
+ self.data['Credibility'] = [1 if pred == 1 else -1 for pred in predictions]
66
  return self.data
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
 
68
  class RecommendationSystem:
69
  def __init__(self, data_path: Path):
 
72
  self.setup_system()
73
 
74
  def setup_system(self):
 
75
  self.data = self.preprocessor.calculate_metrics()
76
 
77
+ def get_recommendations(self, weights: RecommendationWeights, num_recommendations: int = 10) -> Dict:
78
+ if not self._validate_weights(weights):
79
+ return {"error": "Invalid weights provided"}
80
+
81
  normalized_weights = self._normalize_weights(weights)
82
 
 
 
83
  self.data['Final_Score'] = (
84
  self.data['Credibility'] * normalized_weights.visibility +
85
  self.data['Sentiment'] * normalized_weights.sentiment +
86
  self.data['Popularity'] * normalized_weights.popularity
87
  )
88
 
 
 
 
 
 
 
 
89
  top_recommendations = (
90
+ self.data.nlargest(100, 'Final_Score')
91
+ .sample(num_recommendations)
92
  )
93
 
94
  return self._format_recommendations(top_recommendations)
95
 
96
  def _format_recommendations(self, recommendations: pd.DataFrame) -> Dict:
 
97
  formatted_results = []
98
  for _, row in recommendations.iterrows():
99
  score_details = {
 
105
  }
106
 
107
  formatted_results.append({
108
+ "text": row['Text'],
109
  "scores": score_details
110
  })
111
 
 
116
 
117
  @staticmethod
118
  def _get_sentiment_label(sentiment_score: float) -> str:
 
119
  if sentiment_score > 0.3:
120
  return "Positive"
121
  elif sentiment_score < -0.3:
 
124
 
125
  @staticmethod
126
  def _validate_weights(weights: RecommendationWeights) -> bool:
 
127
  return all(getattr(weights, field) >= 0 for field in weights.__dataclass_fields__)
128
 
129
  @staticmethod
130
  def _normalize_weights(weights: RecommendationWeights) -> RecommendationWeights:
 
131
  total = weights.visibility + weights.sentiment + weights.popularity
132
  if total == 0:
133
  return RecommendationWeights(1/3, 1/3, 1/3)
 
139
 
140
  @staticmethod
141
  def _get_score_explanation() -> Dict[str, str]:
 
142
  return {
143
  "Credibility": "Content reliability assessment",
144
  "Sentiment": "Text emotional analysis result",
 
146
  }
147
 
148
  def create_gradio_interface(recommendation_system: RecommendationSystem) -> gr.Interface:
 
149
  with gr.Blocks(theme=gr.themes.Soft()) as interface:
150
  gr.Markdown("""
151
  # Tweet Recommendation System
 
206
  return html
207
 
208
  def get_recommendations_with_weights(v, s, p):
 
209
  weights = RecommendationWeights(v, s, p)
210
  return format_recommendations(recommendation_system.get_recommendations(weights))
211
 
 
218
  return interface
219
 
220
  def main():
 
221
  try:
222
  recommendation_system = RecommendationSystem(
223
  data_path=Path('twitter_dataset.csv')