YixuanWang committed on
Commit
ab7af96
·
verified ·
1 Parent(s): 28fe915

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +118 -83
app.py CHANGED
@@ -8,6 +8,8 @@ from typing import List, Dict, Tuple
8
  from dataclasses import dataclass
9
  from pathlib import Path
10
  import logging
 
 
11
 
12
  # Configure logging
13
  logging.basicConfig(level=logging.INFO)
@@ -29,7 +31,7 @@ class TweetPreprocessor:
29
  """Load and validate the dataset."""
30
  try:
31
  data = pd.read_csv(data_path)
32
- required_columns = {'Text', 'Retweets', 'Likes'}
33
  if not required_columns.issubset(data.columns):
34
  raise ValueError(f"Missing required columns: {required_columns - set(data.columns)}")
35
  return data
@@ -37,12 +39,40 @@ class TweetPreprocessor:
37
  logger.error(f"Error loading data: {e}")
38
  raise
39
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  def calculate_metrics(self) -> pd.DataFrame:
41
  """Calculate sentiment and popularity metrics."""
42
- self.data['Sentiment'] = self.data['Text'].apply(self._get_sentiment)
 
 
 
 
 
43
  self.data['Popularity'] = self._normalize_popularity()
 
 
 
44
  return self.data
45
 
 
 
 
 
 
 
 
46
  @staticmethod
47
  def _get_sentiment(text: str) -> float:
48
  """Calculate sentiment polarity for a text."""
@@ -55,45 +85,7 @@ class TweetPreprocessor:
55
  def _normalize_popularity(self) -> pd.Series:
56
  """Normalize popularity scores using min-max scaling."""
57
  popularity = self.data['Retweets'] + self.data['Likes']
58
- return (popularity - popularity.mean()) / (popularity.std() or 1)
59
-
60
- class FakeNewsClassifier:
61
- def __init__(self, model_name: str):
62
- """Initialize the fake news classifier."""
63
- self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
64
- self.model_name = model_name
65
- self.model, self.tokenizer = self._load_model()
66
-
67
- def _load_model(self) -> Tuple[AutoModelForSequenceClassification, AutoTokenizer]:
68
- """Load the model and tokenizer."""
69
- try:
70
- tokenizer = AutoTokenizer.from_pretrained(self.model_name)
71
- model = AutoModelForSequenceClassification.from_pretrained(self.model_name).to(self.device)
72
- return model, tokenizer
73
- except Exception as e:
74
- logger.error(f"Error loading model: {e}")
75
- raise
76
-
77
- @torch.no_grad()
78
- def predict_batch(self, texts: List[str], batch_size: int = 100) -> np.ndarray:
79
- """Predict fake news probability for a batch of texts."""
80
- predictions = []
81
-
82
- for i in range(0, len(texts), batch_size):
83
- batch_texts = texts[i:i + batch_size]
84
- inputs = self.tokenizer(
85
- batch_texts,
86
- return_tensors="pt",
87
- padding=True,
88
- truncation=True,
89
- max_length=128
90
- ).to(self.device)
91
-
92
- outputs = self.model(**inputs)
93
- batch_predictions = outputs.logits.argmax(dim=1).cpu().numpy()
94
- predictions.extend(batch_predictions)
95
-
96
- return np.array(predictions)
97
 
98
  class RecommendationSystem:
99
  def __init__(self, data_path: Path, model_name: str):
@@ -106,13 +98,13 @@ class RecommendationSystem:
106
  def setup_system(self):
107
  """Set up the recommendation system."""
108
  self.data = self.preprocessor.calculate_metrics()
109
- predictions = self.classifier.predict_batch(self.data['Text'].tolist())
110
  self.data['Credibility'] = [1 if pred == 1 else -1 for pred in predictions]
111
 
112
- def get_recommendations(self, weights: RecommendationWeights, num_recommendations: int = 10) -> str:
113
  """Get tweet recommendations based on weights."""
114
  if not self._validate_weights(weights):
115
- return "Error: Invalid weights provided"
116
 
117
  normalized_weights = self._normalize_weights(weights)
118
 
@@ -120,7 +112,7 @@ class RecommendationSystem:
120
  self.data['Credibility'] * normalized_weights.visibility +
121
  self.data['Sentiment'] * normalized_weights.sentiment +
122
  self.data['Popularity'] * normalized_weights.popularity
123
- )
124
 
125
  top_recommendations = (
126
  self.data.nlargest(100, 'Final_Score')
@@ -129,49 +121,92 @@ class RecommendationSystem:
129
 
130
  return self._format_recommendations(top_recommendations)
131
 
132
- @staticmethod
133
- def _validate_weights(weights: RecommendationWeights) -> bool:
134
- """Validate that weights are non-negative."""
135
- return all(getattr(weights, field) >= 0 for field in weights.__dataclass_fields__)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
136
 
137
  @staticmethod
138
- def _normalize_weights(weights: RecommendationWeights) -> RecommendationWeights:
139
- """Normalize weights to sum to 1."""
140
- total = weights.visibility + weights.sentiment + weights.popularity
141
- if total == 0:
142
- return RecommendationWeights(1/3, 1/3, 1/3)
143
- return RecommendationWeights(
144
- visibility=weights.visibility / total,
145
- sentiment=weights.sentiment / total,
146
- popularity=weights.popularity / total
147
- )
148
 
149
  @staticmethod
150
- def _format_recommendations(recommendations: pd.DataFrame) -> str:
151
- """Format recommendations for display."""
152
- return "\n\n".join(
153
- f"**Tweet**: {row['Text']}\n**Score**: {row['Final_Score']:.2f}"
154
- for _, row in recommendations.iterrows()
155
- )
 
 
156
 
157
  def create_gradio_interface(recommendation_system: RecommendationSystem) -> gr.Interface:
158
  """Create and configure the Gradio interface."""
159
- def predict_and_recommend(visibility_weight, sentiment_weight, popularity_weight):
160
- weights = RecommendationWeights(visibility_weight, sentiment_weight, popularity_weight)
161
- return recommendation_system.get_recommendations(weights)
162
-
163
- return gr.Interface(
164
- fn=predict_and_recommend,
165
- inputs=[
166
- gr.Slider(0, 1, 0.5, label="Visibility Weight"),
167
- gr.Slider(0, 1, 0.3, label="Sentiment Weight"),
168
- gr.Slider(0, 1, 0.2, label="Popularity Weight")
169
- ],
170
- outputs="markdown",
171
- title="Enhanced Fake News Recommendation System",
172
- description="Adjust weights to receive customized tweet recommendations based on visibility, sentiment, and popularity.",
173
- theme="default"
174
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
175
 
176
  def main():
177
  """Main function to run the application."""
@@ -180,8 +215,8 @@ def main():
180
  data_path=Path('twitter_dataset.csv'),
181
  model_name="hamzab/roberta-fake-news-classification"
182
  )
183
- iface = create_gradio_interface(recommendation_system)
184
- iface.launch()
185
  except Exception as e:
186
  logger.error(f"Application failed to start: {e}")
187
  raise
 
8
  from dataclasses import dataclass
9
  from pathlib import Path
10
  import logging
11
+ import re
12
+ from datetime import datetime
13
 
14
  # Configure logging
15
  logging.basicConfig(level=logging.INFO)
 
31
  """Load and validate the dataset."""
32
  try:
33
  data = pd.read_csv(data_path)
34
+ required_columns = {'Text', 'Retweets', 'Likes', 'Timestamp'} # 添加时间戳列
35
  if not required_columns.issubset(data.columns):
36
  raise ValueError(f"Missing required columns: {required_columns - set(data.columns)}")
37
  return data
 
39
  logger.error(f"Error loading data: {e}")
40
  raise
41
 
42
+ def _clean_text(self, text: str) -> str:
43
+ """清理文本内容,移除无意义的内容"""
44
+ if pd.isna(text) or len(str(text).strip()) < 10: # 排除过短或空的文本
45
+ return ""
46
+
47
+ # 移除URL
48
+ text = re.sub(r'http\S+|www.\S+', '', str(text))
49
+ # 移除特殊字符
50
+ text = re.sub(r'[^\w\s]', '', text)
51
+ # 移除多余空格
52
+ text = ' '.join(text.split())
53
+ return text
54
+
55
  def calculate_metrics(self) -> pd.DataFrame:
56
  """Calculate sentiment and popularity metrics."""
57
+ # 清理文本
58
+ self.data['Clean_Text'] = self.data['Text'].apply(self._clean_text)
59
+ # 过滤掉无效的文本
60
+ self.data = self.data[self.data['Clean_Text'].str.len() > 0]
61
+
62
+ self.data['Sentiment'] = self.data['Clean_Text'].apply(self._get_sentiment)
63
  self.data['Popularity'] = self._normalize_popularity()
64
+
65
+ # 添加时间衰减因子
66
+ self.data['Time_Weight'] = self._calculate_time_weight()
67
  return self.data
68
 
69
+ def _calculate_time_weight(self) -> pd.Series:
70
+ """计算时间权重,越新的内容权重越高"""
71
+ current_time = datetime.now()
72
+ self.data['Timestamp'] = pd.to_datetime(self.data['Timestamp'])
73
+ time_diff = (current_time - self.data['Timestamp']).dt.total_seconds()
74
+ return np.exp(-time_diff / (7 * 24 * 3600)) # 7天的衰减周期
75
+
76
  @staticmethod
77
  def _get_sentiment(text: str) -> float:
78
  """Calculate sentiment polarity for a text."""
 
85
  def _normalize_popularity(self) -> pd.Series:
86
  """Normalize popularity scores using min-max scaling."""
87
  popularity = self.data['Retweets'] + self.data['Likes']
88
+ return (popularity - popularity.min()) / (popularity.max() - popularity.min() + 1e-6)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
 
90
  class RecommendationSystem:
91
  def __init__(self, data_path: Path, model_name: str):
 
98
  def setup_system(self):
99
  """Set up the recommendation system."""
100
  self.data = self.preprocessor.calculate_metrics()
101
+ predictions = self.classifier.predict_batch(self.data['Clean_Text'].tolist())
102
  self.data['Credibility'] = [1 if pred == 1 else -1 for pred in predictions]
103
 
104
+ def get_recommendations(self, weights: RecommendationWeights, num_recommendations: int = 10) -> Dict:
105
  """Get tweet recommendations based on weights."""
106
  if not self._validate_weights(weights):
107
+ return {"error": "Invalid weights provided"}
108
 
109
  normalized_weights = self._normalize_weights(weights)
110
 
 
112
  self.data['Credibility'] * normalized_weights.visibility +
113
  self.data['Sentiment'] * normalized_weights.sentiment +
114
  self.data['Popularity'] * normalized_weights.popularity
115
+ ) * self.data['Time_Weight'] # 考虑时间因素
116
 
117
  top_recommendations = (
118
  self.data.nlargest(100, 'Final_Score')
 
121
 
122
  return self._format_recommendations(top_recommendations)
123
 
124
def _format_recommendations(self, recommendations: pd.DataFrame) -> Dict:
    """Turn ranked rows into the JSON-friendly payload shown in the UI.

    Args:
        recommendations: Rows to present. The columns read are
            ``Final_Score``, ``Credibility``, ``Sentiment``, ``Popularity``,
            ``Likes``, ``Retweets``, ``Clean_Text`` and ``Timestamp``.

    Returns:
        A dict with two keys: ``"recommendations"``, a list holding one
        entry per row (cleaned text, score breakdown, formatted
        timestamp), and ``"score_explanation"``, the static legend that
        describes each score component.
    """
    def describe(row) -> Dict:
        # Build the per-row score breakdown before assembling the entry.
        credibility_label = "可信" if row['Credibility'] > 0 else "存疑"
        scores = {
            "总分": f"{row['Final_Score']:.2f}",
            "可信度": credibility_label,
            "情感倾向": self._get_sentiment_label(row['Sentiment']),
            "热度": f"{row['Popularity']:.2f}",
            "互动数": f"点赞 {row['Likes']} · 转发 {row['Retweets']}",
        }
        return {
            "text": row['Clean_Text'],
            "scores": scores,
            "timestamp": row['Timestamp'].strftime("%Y-%m-%d %H:%M"),
        }

    entries = [describe(row) for _, row in recommendations.iterrows()]
    return {
        "recommendations": entries,
        "score_explanation": self._get_score_explanation(),
    }
146
 
147
  @staticmethod
148
+ def _get_sentiment_label(sentiment_score: float) -> str:
149
+ """Convert sentiment score to human-readable label."""
150
+ if sentiment_score > 0.3:
151
+ return "积极"
152
+ elif sentiment_score < -0.3:
153
+ return "消极"
154
+ return "中性"
 
 
 
155
 
156
  @staticmethod
157
+ def _get_score_explanation() -> Dict[str, str]:
158
+ """Provide explanation for different score components."""
159
+ return {
160
+ "可信度": "基于机器学习模型对内容可信度的评估",
161
+ "情感倾向": "文本的情感倾向分析结果",
162
+ "热度": "根据点赞和转发数量计算的归一化热度分数",
163
+ "时间权重": "考虑内容时效性的权重因子"
164
+ }
165
 
166
  def create_gradio_interface(recommendation_system: RecommendationSystem) -> gr.Interface:
167
  """Create and configure the Gradio interface."""
168
+ with gr.Blocks(theme=gr.themes.Soft()) as interface:
169
+ gr.Markdown("""
170
+ # 推文推荐系统
171
+
172
+ 这个系统通过多个维度来为您推荐高质量的推文:
173
+ - **可信度**: 评估内容的可靠性
174
+ - **情感倾向**: 分析文本的情感色彩
175
+ - **热度**: 考虑内容的受欢迎程度
176
+ - **时效性**: 优先推荐较新的内容
177
+ """)
178
+
179
+ with gr.Row():
180
+ with gr.Column(scale=1):
181
+ visibility_weight = gr.Slider(
182
+ 0, 1, 0.5,
183
+ label="可信度权重",
184
+ info="调整对内容可信度的重视程度"
185
+ )
186
+ sentiment_weight = gr.Slider(
187
+ 0, 1, 0.3,
188
+ label="情感倾向权重",
189
+ info="调整对情感倾向的重视程度"
190
+ )
191
+ popularity_weight = gr.Slider(
192
+ 0, 1, 0.2,
193
+ label="热度权重",
194
+ info="调整对内容热度的重视程度"
195
+ )
196
+ submit_btn = gr.Button("获取推荐", variant="primary")
197
+
198
+ with gr.Column(scale=2):
199
+ output = gr.JSON(label="推荐结果")
200
+
201
+ submit_btn.click(
202
+ fn=lambda v, s, p: recommendation_system.get_recommendations(
203
+ RecommendationWeights(v, s, p)
204
+ ),
205
+ inputs=[visibility_weight, sentiment_weight, popularity_weight],
206
+ outputs=output
207
+ )
208
+
209
+ return interface
210
 
211
  def main():
212
  """Main function to run the application."""
 
215
  data_path=Path('twitter_dataset.csv'),
216
  model_name="hamzab/roberta-fake-news-classification"
217
  )
218
+ interface = create_gradio_interface(recommendation_system)
219
+ interface.launch()
220
  except Exception as e:
221
  logger.error(f"Application failed to start: {e}")
222
  raise