Spaces:
Sleeping
Sleeping
Update function.py
Browse files- function.py +117 -2
function.py
CHANGED
@@ -1,3 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
# ------------------ Sentiment Analysis Functions ------------------------#
|
2 |
def split_text_by_token_limit(text, tokenizer, max_tokens):
|
3 |
tokens = tokenizer.encode(text, add_special_tokens=False)
|
@@ -8,13 +108,15 @@ def split_text_by_token_limit(text, tokenizer, max_tokens):
|
|
8 |
chunks.append(chunk_text)
|
9 |
return chunks
|
10 |
|
11 |
-
|
|
|
12 |
try:
|
13 |
result = sentiment_pipeline(text)[0]
|
14 |
except Exception as e:
|
15 |
result = None
|
16 |
return result
|
17 |
|
|
|
18 |
def analyze_detail(text, tokenizer, sentiment_pipeline, max_tokens):
|
19 |
text = preprocess_text(text)
|
20 |
chunks = split_text_by_token_limit(text, tokenizer, max_tokens)
|
@@ -32,8 +134,21 @@ def analyze_detail(text, tokenizer, sentiment_pipeline, max_tokens):
|
|
32 |
final_score = scores[final_label]
|
33 |
return {"label": final_label, "score": final_score}
|
34 |
|
|
|
|
|
35 |
def preprocess_text(text):
|
36 |
# Replace URLs and user mentions
|
37 |
text = re.sub(r'http\S+', 'http', text)
|
38 |
text = re.sub(r'@\w+', '@user', text)
|
39 |
-
return text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from collections import Counter
|
2 |
+
import matplotlib.pyplot as plt
|
3 |
+
import pandas as pd
|
4 |
+
import praw # Reddit's API
|
5 |
+
import re # Regular expression module
|
6 |
+
import streamlit as st
|
7 |
+
import time
|
8 |
+
import numpy as np
|
9 |
+
from wordcloud import WordCloud
|
10 |
+
from transformers import (
|
11 |
+
pipeline,
|
12 |
+
AutoTokenizer,
|
13 |
+
AutoModelForSequenceClassification,
|
14 |
+
AutoModelForTokenClassification,
|
15 |
+
TokenClassificationPipeline
|
16 |
+
)
|
17 |
+
from transformers.pipelines import AggregationStrategy
|
18 |
+
|
19 |
+
|
20 |
+
# ---------- Cached function for loading the model pipelines ----------
|
21 |
+
@st.cache_resource
|
22 |
+
def load_sentiment_pipeline(): # sentiment pipeline
|
23 |
+
tokenizer = AutoTokenizer.from_pretrained("cardiffnlp/twitter-roberta-base-sentiment-latest")
|
24 |
+
model = AutoModelForSequenceClassification.from_pretrained(
|
25 |
+
"cardiffnlp/twitter-roberta-base-sentiment-latest",
|
26 |
+
use_auth_token=st.secrets["hugging_face_token"]
|
27 |
+
)
|
28 |
+
sentiment_pipeline = pipeline("sentiment-analysis", model=model, tokenizer=tokenizer, device=0) # -1 to 0
|
29 |
+
max_tokens = tokenizer.model_max_length
|
30 |
+
|
31 |
+
if max_tokens > 10000:
|
32 |
+
max_tokens = 512
|
33 |
+
return sentiment_pipeline, tokenizer, max_tokens
|
34 |
+
|
35 |
+
|
36 |
+
# class for keyword extraction
|
37 |
+
@st.cache_resource
|
38 |
+
class KeyphraseExtractionPipeline(TokenClassificationPipeline):
|
39 |
+
def __init__(self, model, *args, **kwargs):
|
40 |
+
super().__init__(
|
41 |
+
model=AutoModelForTokenClassification.from_pretrained(model),
|
42 |
+
tokenizer=AutoTokenizer.from_pretrained(model),
|
43 |
+
*args,
|
44 |
+
**kwargs
|
45 |
+
)
|
46 |
+
|
47 |
+
def postprocess(self, all_outputs):
|
48 |
+
results = super().postprocess(
|
49 |
+
all_outputs=all_outputs,
|
50 |
+
aggregation_strategy=AggregationStrategy.SIMPLE,
|
51 |
+
)
|
52 |
+
return np.unique([result.get("word").strip() for result in results])
|
53 |
+
|
54 |
+
def keyword_extractor():
|
55 |
+
model_name = "ml6team/keyphrase-extraction-kbir-inspec"
|
56 |
+
extractor = KeyphraseExtractionPipeline(model=model_name)
|
57 |
+
return extractor
|
58 |
+
|
59 |
+
|
60 |
+
# Function to normalize text by replacing multiple spaces/newlines with a single space
|
61 |
+
def normalize_text(text):
|
62 |
+
if not isinstance(text, str):
|
63 |
+
return ""
|
64 |
+
return re.sub(r'\s+', ' ', text).strip()
|
65 |
+
|
66 |
+
# ---------- Cached function for scraping Reddit data ----------
|
67 |
+
@st.cache_data(show_spinner=False)
|
68 |
+
def scrape_reddit_data(search_query, total_limit):
|
69 |
+
# Retrieve API credentials from st.secrets
|
70 |
+
reddit = praw.Reddit(
|
71 |
+
client_id=st.secrets["reddit_client_id"],
|
72 |
+
client_secret=st.secrets["reddit_client_secret"],
|
73 |
+
user_agent=st.secrets["reddit_user_agent"]
|
74 |
+
)
|
75 |
+
subreddit = reddit.subreddit("all")
|
76 |
+
posts_data = []
|
77 |
+
# Iterate over submissions based on the search query and limit
|
78 |
+
for i, submission in enumerate(subreddit.search(search_query, sort="relevance", limit=total_limit)):
|
79 |
+
# No UI updates here as caching does not allow live progress updates
|
80 |
+
if submission.title and submission.selftext:
|
81 |
+
posts_data.append([
|
82 |
+
submission.title,
|
83 |
+
submission.url,
|
84 |
+
submission.created_utc,
|
85 |
+
submission.selftext,
|
86 |
+
])
|
87 |
+
time.sleep(0.25)
|
88 |
+
|
89 |
+
df = pd.DataFrame(posts_data, columns=["Title", "URL", "Date", "Detail"])
|
90 |
+
|
91 |
+
for col in ["Title", "Detail"]:
|
92 |
+
df[col] = df[col].apply(normalize_text)
|
93 |
+
|
94 |
+
# Filter out rows with empty Title or Detail
|
95 |
+
df = df[(df["Title"] != "") & (df["Detail"] != "")]
|
96 |
+
df['Date'] = pd.to_datetime(df['Date'], unit='s')
|
97 |
+
df = df.sort_values(by="Date", ascending=True).reset_index(drop=True)
|
98 |
+
return df
|
99 |
+
|
100 |
+
|
101 |
# ------------------ Sentiment Analysis Functions ------------------------#
|
102 |
def split_text_by_token_limit(text, tokenizer, max_tokens):
|
103 |
tokens = tokenizer.encode(text, add_special_tokens=False)
|
|
|
108 |
chunks.append(chunk_text)
|
109 |
return chunks
|
110 |
|
111 |
+
|
112 |
+
def safe_sentiment(sentiment_pipeline, text):
|
113 |
try:
|
114 |
result = sentiment_pipeline(text)[0]
|
115 |
except Exception as e:
|
116 |
result = None
|
117 |
return result
|
118 |
|
119 |
+
|
120 |
def analyze_detail(text, tokenizer, sentiment_pipeline, max_tokens):
|
121 |
text = preprocess_text(text)
|
122 |
chunks = split_text_by_token_limit(text, tokenizer, max_tokens)
|
|
|
134 |
final_score = scores[final_label]
|
135 |
return {"label": final_label, "score": final_score}
|
136 |
|
137 |
+
|
138 |
+
|
139 |
def preprocess_text(text):
|
140 |
# Replace URLs and user mentions
|
141 |
text = re.sub(r'http\S+', 'http', text)
|
142 |
text = re.sub(r'@\w+', '@user', text)
|
143 |
+
return text
|
144 |
+
|
145 |
+
|
146 |
+
# def keyword_extraction(text):
|
147 |
+
# try:
|
148 |
+
# extractor = keyword_extractor()
|
149 |
+
# result = extractor(text)
|
150 |
+
# except Exception as e:
|
151 |
+
# # Optionally, log the error: print(f"Error processing text: {e}")
|
152 |
+
# result = None
|
153 |
+
|
154 |
+
# return result
|