YixuanWang committed
Commit f0cb7f7 · verified · Parent: fefd091

Update app.py

Files changed (1): app.py (+29 −108)
app.py CHANGED
@@ -1,17 +1,12 @@
 import gradio as gr
 import pandas as pd
 import numpy as np
-import torch
-from transformers import AutoTokenizer, AutoModelForSequenceClassification
 from textblob import TextBlob
 from typing import List, Dict, Tuple
 from dataclasses import dataclass
 from pathlib import Path
 import logging
-import re
-from datetime import datetime
 
-# Configure logging
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
@@ -23,15 +18,13 @@ class RecommendationWeights:
 
 class TweetPreprocessor:
     def __init__(self, data_path: Path):
-        """Initialize the preprocessor with data path."""
         self.data = self._load_data(data_path)
 
     @staticmethod
     def _load_data(data_path: Path) -> pd.DataFrame:
-        """Load and validate the dataset."""
         try:
             data = pd.read_csv(data_path)
-            required_columns = {'Text', 'Retweets', 'Likes', 'Timestamp'}  # requires the timestamp column
+            required_columns = {'Text', 'Retweets', 'Likes'}
             if not required_columns.issubset(data.columns):
                 raise ValueError(f"Missing required columns: {required_columns - set(data.columns)}")
             return data
@@ -39,43 +32,14 @@ class TweetPreprocessor:
             logger.error(f"Error loading data: {e}")
             raise
 
-    def _clean_text(self, text: str) -> str:
-        """Clean the text and drop meaningless content."""
-        if pd.isna(text) or len(str(text).strip()) < 10:  # skip empty or very short texts
-            return ""
-
-        # Remove URLs
-        text = re.sub(r'http\S+|www.\S+', '', str(text))
-        # Remove special characters
-        text = re.sub(r'[^\w\s]', '', text)
-        # Collapse extra whitespace
-        text = ' '.join(text.split())
-        return text
-
     def calculate_metrics(self) -> pd.DataFrame:
-        """Calculate sentiment and popularity metrics."""
-        # Clean the text
-        self.data['Clean_Text'] = self.data['Text'].apply(self._clean_text)
-        # Drop rows with no usable text
-        self.data = self.data[self.data['Clean_Text'].str.len() > 0]
-
-        self.data['Sentiment'] = self.data['Clean_Text'].apply(self._get_sentiment)
+        self.data['Sentiment'] = self.data['Text'].apply(self._get_sentiment)
         self.data['Popularity'] = self._normalize_popularity()
-
-        # Add a time-decay factor
-        self.data['Time_Weight'] = self._calculate_time_weight()
+        self.data['Credibility'] = np.random.choice([0, 1], size=len(self.data), p=[0.3, 0.7])
         return self.data
 
-    def _calculate_time_weight(self) -> pd.Series:
-        """Compute time weights; newer content gets a higher weight."""
-        current_time = datetime.now()
-        self.data['Timestamp'] = pd.to_datetime(self.data['Timestamp'])
-        time_diff = (current_time - self.data['Timestamp']).dt.total_seconds()
-        return np.exp(-time_diff / (7 * 24 * 3600))  # 7-day decay period
-
     @staticmethod
     def _get_sentiment(text: str) -> float:
-        """Calculate sentiment polarity for a text."""
         try:
             return TextBlob(str(text)).sentiment.polarity
         except Exception as e:
@@ -83,63 +47,19 @@ class TweetPreprocessor:
             return 0.0
 
     def _normalize_popularity(self) -> pd.Series:
-        """Normalize popularity scores using min-max scaling."""
         popularity = self.data['Retweets'] + self.data['Likes']
         return (popularity - popularity.min()) / (popularity.max() - popularity.min() + 1e-6)
-class FakeNewsClassifier:
-    def __init__(self, model_name: str):
-        """Initialize the fake news classifier."""
-        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        self.model_name = model_name
-        self.model, self.tokenizer = self._load_model()
-
-    def _load_model(self) -> Tuple[AutoModelForSequenceClassification, AutoTokenizer]:
-        """Load the model and tokenizer."""
-        try:
-            tokenizer = AutoTokenizer.from_pretrained(self.model_name)
-            model = AutoModelForSequenceClassification.from_pretrained(self.model_name).to(self.device)
-            return model, tokenizer
-        except Exception as e:
-            logger.error(f"Error loading model: {e}")
-            raise
-
-    @torch.no_grad()
-    def predict_batch(self, texts: List[str], batch_size: int = 32) -> np.ndarray:
-        """Predict fake news probability for a batch of texts."""
-        predictions = []
-
-        for i in range(0, len(texts), batch_size):
-            batch_texts = texts[i:i + batch_size]
-            inputs = self.tokenizer(
-                batch_texts,
-                return_tensors="pt",
-                padding=True,
-                truncation=True,
-                max_length=128
-            ).to(self.device)
-
-            outputs = self.model(**inputs)
-            batch_predictions = outputs.logits.argmax(dim=1).cpu().numpy()
-            predictions.extend(batch_predictions)
-
-        return np.array(predictions)
 
 class RecommendationSystem:
-    def __init__(self, data_path: Path, model_name: str):
-        """Initialize the recommendation system."""
+    def __init__(self, data_path: Path):
         self.preprocessor = TweetPreprocessor(data_path)
-        self.classifier = FakeNewsClassifier(model_name)
         self.data = None
         self.setup_system()
 
     def setup_system(self):
-        """Set up the recommendation system."""
         self.data = self.preprocessor.calculate_metrics()
-        predictions = self.classifier.predict_batch(self.data['Clean_Text'].tolist())
-        self.data['Credibility'] = [1 if pred == 1 else -1 for pred in predictions]
 
     def get_recommendations(self, weights: RecommendationWeights, num_recommendations: int = 10) -> Dict:
-        """Get tweet recommendations based on weights."""
         if not self._validate_weights(weights):
             return {"error": "Invalid weights provided"}
 
@@ -149,17 +69,15 @@ class RecommendationSystem:
             self.data['Credibility'] * normalized_weights.visibility +
             self.data['Sentiment'] * normalized_weights.sentiment +
             self.data['Popularity'] * normalized_weights.popularity
-        ) * self.data['Time_Weight']  # factor in recency
+        )
 
         top_recommendations = (
-            self.data.nlargest(100, 'Final_Score')
-            .sample(num_recommendations)
+            self.data.nlargest(num_recommendations, 'Final_Score')
         )
 
         return self._format_recommendations(top_recommendations)
 
     def _format_recommendations(self, recommendations: pd.DataFrame) -> Dict:
-        """Format recommendations for display."""
        formatted_results = []
         for _, row in recommendations.iterrows():
             score_details = {
@@ -171,9 +89,8 @@
             }
 
             formatted_results.append({
-                "text": row['Clean_Text'],
-                "scores": score_details,
-                "timestamp": row['Timestamp'].strftime("%Y-%m-%d %H:%M")
+                "text": row['Text'],
+                "scores": score_details
             })
 
         return {
@@ -183,34 +100,40 @@
 
     @staticmethod
     def _get_sentiment_label(sentiment_score: float) -> str:
-        """Convert sentiment score to a human-readable label."""
         if sentiment_score > 0.3:
             return "积极"
         elif sentiment_score < -0.3:
             return "消极"
         return "中性"
 
+    @staticmethod
+    def _validate_weights(weights: RecommendationWeights) -> bool:
+        return all(getattr(weights, field) >= 0 for field in weights.__dataclass_fields__)
+
+    @staticmethod
+    def _normalize_weights(weights: RecommendationWeights) -> RecommendationWeights:
+        total = weights.visibility + weights.sentiment + weights.popularity
+        if total == 0:
+            return RecommendationWeights(1/3, 1/3, 1/3)
+        return RecommendationWeights(
+            visibility=weights.visibility / total,
+            sentiment=weights.sentiment / total,
+            popularity=weights.popularity / total
+        )
+
     @staticmethod
     def _get_score_explanation() -> Dict[str, str]:
-        """Provide explanations for the different score components."""
         return {
-            "可信度": "基于机器学习模型对内容可信度的评估",
-            "情感倾向": "文本的情感倾向分析结果",
-            "热度": "根据点赞和转发数量计算的归一化热度分数",
-            "时间权重": "考虑内容时效性的权重因子"
+            "可信度": "内容可信度评估",
+            "情感倾向": "文本的情感分析结果",
+            "热度": "基于点赞和转发的热度分数"
         }
 
 def create_gradio_interface(recommendation_system: RecommendationSystem) -> gr.Interface:
-    """Create and configure the Gradio interface."""
     with gr.Blocks(theme=gr.themes.Soft()) as interface:
         gr.Markdown("""
         # 推文推荐系统
-
-        这个系统通过多个维度来为您推荐高质量的推文:
-        - **可信度**: 评估内容的可靠性
-        - **情感倾向**: 分析文本的情感色彩
-        - **热度**: 考虑内容的受欢迎程度
-        - **时效性**: 优先推荐较新的内容
+        调整下方的权重来获取个性化推荐:
        """)
 
         with gr.Row():
@@ -246,11 +169,9 @@ def create_gradio_interface(recommendation_system: RecommendationSystem) -> gr.Interface:
     return interface
 
 def main():
-    """Main function to run the application."""
     try:
         recommendation_system = RecommendationSystem(
-            data_path=Path('twitter_dataset.csv'),
-            model_name="hamzab/roberta-fake-news-classification"
+            data_path=Path('twitter_dataset.csv')
        )
         interface = create_gradio_interface(recommendation_system)
         interface.launch()
@@ -259,4 +180,4 @@ def main():
         raise
 
 if __name__ == "__main__":
-    main()
+    main()
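A few notes on the behavioral changes in this commit. The largest removal is the recency weighting: the deleted `_calculate_time_weight` multiplied every score by `exp(-Δt / 7 days)`, so a tweet's weight halved roughly every 4.9 days. A minimal sketch that replays that decay on synthetic timestamps (the timestamps are illustrative, not from `twitter_dataset.csv`):

```python
import numpy as np
import pandas as pd
from datetime import datetime, timedelta

# Replay the removed _calculate_time_weight logic on synthetic timestamps.
now = datetime.now()
timestamps = pd.Series([now - timedelta(days=d) for d in (0, 1, 7, 14, 30)])
time_diff = (now - timestamps).dt.total_seconds()
time_weight = np.exp(-time_diff / (7 * 24 * 3600))  # 7-day decay constant
print(time_weight.round(3).tolist())  # ~[1.0, 0.867, 0.368, 0.135, 0.014]
```

After this commit, a ten-minute-old tweet and a ten-year-old tweet score identically, all else being equal.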
 
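`_normalize_popularity` is untouched, but the `1e-6` in it deserves a note: it keeps the min-max denominator nonzero when every tweet has identical engagement, at the cost of mapping a constant column to all zeros rather than raising an error. A self-contained check with toy series (my own values, not dataset rows):

```python
import pandas as pd

def normalize_popularity(retweets: pd.Series, likes: pd.Series) -> pd.Series:
    # Same min-max scaling as TweetPreprocessor._normalize_popularity.
    popularity = retweets + likes
    return (popularity - popularity.min()) / (popularity.max() - popularity.min() + 1e-6)

print(normalize_popularity(pd.Series([0, 10]), pd.Series([0, 90])).tolist())  # [0.0, ~1.0]
print(normalize_popularity(pd.Series([5, 5]), pd.Series([10, 10])).tolist())  # [0.0, 0.0], no division by zero
```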
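The newly added `_normalize_weights` turns the three slider values into a convex combination, with a uniform fallback when all sliders sit at zero (which `_validate_weights` permits, since it only rejects negative values). A self-contained usage sketch; the dataclass mirrors `RecommendationWeights` from app.py, with the field order assumed from the `RecommendationWeights(1/3, 1/3, 1/3)` call:

```python
from dataclasses import dataclass

@dataclass
class RecommendationWeights:
    visibility: float
    sentiment: float
    popularity: float

def normalize_weights(w: RecommendationWeights) -> RecommendationWeights:
    # Standalone mirror of RecommendationSystem._normalize_weights.
    total = w.visibility + w.sentiment + w.popularity
    if total == 0:
        return RecommendationWeights(1/3, 1/3, 1/3)  # uniform fallback
    return RecommendationWeights(
        visibility=w.visibility / total,
        sentiment=w.sentiment / total,
        popularity=w.popularity / total,
    )

print(normalize_weights(RecommendationWeights(50, 30, 20)))
# RecommendationWeights(visibility=0.5, sentiment=0.3, popularity=0.2)
```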
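Finally, two semantic shifts hidden in the smaller hunks. First, `Credibility` is now a per-row random placeholder (1 with probability 0.7) rather than a `FakeNewsClassifier` prediction, and its range moves from the old `{-1, 1}` to `{0, 1}`, so the visibility term can no longer push `Final_Score` negative. Second, `nlargest(num_recommendations, 'Final_Score')` is deterministic, whereas the old `nlargest(100, ...).sample(...)` drew a fresh random subset of the top 100 on every call. A sketch of the placeholder draw (the seed is my addition, purely for reproducibility):

```python
import numpy as np

np.random.seed(42)  # not in app.py; fixes the draw so the sketch is repeatable
credibility = np.random.choice([0, 1], size=10, p=[0.3, 0.7])
print(credibility)  # e.g. [1 1 1 1 0 0 0 1 1 1] -- independent of tweet content
```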