kusa04 committed on
Commit 188fc65 · verified · 1 Parent(s): 0a33dc9

Update function.py

Files changed (1)
  1. function.py +117 -2
function.py CHANGED
@@ -1,3 +1,103 @@
+from collections import Counter
+import matplotlib.pyplot as plt
+import pandas as pd
+import praw  # Reddit's API
+import re  # Regular expression module
+import streamlit as st
+import time
+import numpy as np
+from wordcloud import WordCloud
+from transformers import (
+    pipeline,
+    AutoTokenizer,
+    AutoModelForSequenceClassification,
+    AutoModelForTokenClassification,
+    TokenClassificationPipeline
+)
+from transformers.pipelines import AggregationStrategy
+
+
+# ---------- Cached function for loading the model pipelines ----------
+@st.cache_resource
+def load_sentiment_pipeline():  # sentiment pipeline
+    tokenizer = AutoTokenizer.from_pretrained("cardiffnlp/twitter-roberta-base-sentiment-latest")
+    model = AutoModelForSequenceClassification.from_pretrained(
+        "cardiffnlp/twitter-roberta-base-sentiment-latest",
+        use_auth_token=st.secrets["hugging_face_token"]
+    )
+    sentiment_pipeline = pipeline("sentiment-analysis", model=model, tokenizer=tokenizer, device=0)  # -1 to 0
+    max_tokens = tokenizer.model_max_length
+
+    if max_tokens > 10000:
+        max_tokens = 512
+    return sentiment_pipeline, tokenizer, max_tokens
+
+
+# class for keyword extraction
+@st.cache_resource
+class KeyphraseExtractionPipeline(TokenClassificationPipeline):
+    def __init__(self, model, *args, **kwargs):
+        super().__init__(
+            model=AutoModelForTokenClassification.from_pretrained(model),
+            tokenizer=AutoTokenizer.from_pretrained(model),
+            *args,
+            **kwargs
+        )
+
+    def postprocess(self, all_outputs):
+        results = super().postprocess(
+            all_outputs=all_outputs,
+            aggregation_strategy=AggregationStrategy.SIMPLE,
+        )
+        return np.unique([result.get("word").strip() for result in results])
+
+def keyword_extractor():
+    model_name = "ml6team/keyphrase-extraction-kbir-inspec"
+    extractor = KeyphraseExtractionPipeline(model=model_name)
+    return extractor
+
+
+# Function to normalize text by replacing multiple spaces/newlines with a single space
+def normalize_text(text):
+    if not isinstance(text, str):
+        return ""
+    return re.sub(r'\s+', ' ', text).strip()
+
+# ---------- Cached function for scraping Reddit data ----------
+@st.cache_data(show_spinner=False)
+def scrape_reddit_data(search_query, total_limit):
+    # Retrieve API credentials from st.secrets
+    reddit = praw.Reddit(
+        client_id=st.secrets["reddit_client_id"],
+        client_secret=st.secrets["reddit_client_secret"],
+        user_agent=st.secrets["reddit_user_agent"]
+    )
+    subreddit = reddit.subreddit("all")
+    posts_data = []
+    # Iterate over submissions based on the search query and limit
+    for i, submission in enumerate(subreddit.search(search_query, sort="relevance", limit=total_limit)):
+        # No UI updates here as caching does not allow live progress updates
+        if submission.title and submission.selftext:
+            posts_data.append([
+                submission.title,
+                submission.url,
+                submission.created_utc,
+                submission.selftext,
+            ])
+        time.sleep(0.25)
+
+    df = pd.DataFrame(posts_data, columns=["Title", "URL", "Date", "Detail"])
+
+    for col in ["Title", "Detail"]:
+        df[col] = df[col].apply(normalize_text)
+
+    # Filter out rows with empty Title or Detail
+    df = df[(df["Title"] != "") & (df["Detail"] != "")]
+    df['Date'] = pd.to_datetime(df['Date'], unit='s')
+    df = df.sort_values(by="Date", ascending=True).reset_index(drop=True)
+    return df
+
+
 # ------------------ Sentiment Analysis Functions ------------------------#
 def split_text_by_token_limit(text, tokenizer, max_tokens):
     tokens = tokenizer.encode(text, add_special_tokens=False)
@@ -8,13 +108,15 @@ def split_text_by_token_limit(text, tokenizer, max_tokens):
     chunks.append(chunk_text)
     return chunks
 
-def safe_sentiment(text):
+
+def safe_sentiment(sentiment_pipeline, text):
     try:
         result = sentiment_pipeline(text)[0]
     except Exception as e:
         result = None
     return result
 
+
 def analyze_detail(text, tokenizer, sentiment_pipeline, max_tokens):
     text = preprocess_text(text)
     chunks = split_text_by_token_limit(text, tokenizer, max_tokens)
@@ -32,8 +134,21 @@ def analyze_detail(text, tokenizer, sentiment_pipeline, max_tokens):
     final_score = scores[final_label]
     return {"label": final_label, "score": final_score}
 
+
+
 def preprocess_text(text):
     # Replace URLs and user mentions
     text = re.sub(r'http\S+', 'http', text)
     text = re.sub(r'@\w+', '@user', text)
-    return text
+    return text
+
+
+# def keyword_extraction(text):
+#     try:
+#         extractor = keyword_extractor()
+#         result = extractor(text)
+#     except Exception as e:
+#         # Optionally, log the error: print(f"Error processing text: {e}")
+#         result = None
+
+#     return result
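For context, here is a minimal usage sketch of how the helpers touched by this commit could be wired together. It is not part of the commit: the search query, the post limit, and the assumption that the code runs inside a Streamlit app with the hugging_face_token and Reddit credentials present in st.secrets are all illustrative.

# Hypothetical usage sketch (not part of this commit): calling the helpers
# defined in function.py from a Streamlit app with the required st.secrets set.
import streamlit as st
from function import (
    load_sentiment_pipeline,
    scrape_reddit_data,
    analyze_detail,
)

# Cached model load; returns the pipeline, its tokenizer, and a safe chunk size.
sentiment_pipeline, tokenizer, max_tokens = load_sentiment_pipeline()

# Cached Reddit scrape; "electric cars" and 20 are made-up example arguments.
df = scrape_reddit_data("electric cars", 20)

# analyze_detail() preprocesses each post, splits it by the token limit, and
# aggregates per-chunk sentiment into a single {"label": ..., "score": ...} dict.
df["Sentiment"] = df["Detail"].apply(
    lambda text: analyze_detail(text, tokenizer, sentiment_pipeline, max_tokens)
)

st.dataframe(df)

Note that after this change safe_sentiment() takes the pipeline as its first argument, so any direct caller would invoke it as safe_sentiment(sentiment_pipeline, text).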