# g13_DL_project / app.py
from collections import Counter
from concurrent.futures import ThreadPoolExecutor  # parallel processing
import matplotlib.pyplot as plt
import pandas as pd
import praw # Reddit's API
import re # Regular expression module
import streamlit as st
import time
import numpy as np
from wordcloud import WordCloud
from transformers import (
pipeline,
AutoTokenizer,
AutoModelForSequenceClassification,
AutoModelForTokenClassification,
TokenClassificationPipeline
)
from transformers.pipelines import AggregationStrategy
from functions import (
scrape_reddit_data,
safe_sentiment,
analyze_detail,
preprocess_text
)
# ---------- Cached function for loading the model pipelines ----------
@st.cache_resource
def load_sentiment_pipeline(): # sentiment pipeline
tokenizer = AutoTokenizer.from_pretrained("cardiffnlp/twitter-roberta-base-sentiment-latest")
model = AutoModelForSequenceClassification.from_pretrained(
"cardiffnlp/twitter-roberta-base-sentiment-latest",
use_auth_token=st.secrets["hugging_face_token"]
)
    # device=0 runs on the first GPU; use device=-1 to fall back to CPU
    sentiment_pipeline = pipeline("sentiment-analysis", model=model, tokenizer=tokenizer, device=0)
    # some tokenizers report a huge sentinel value for model_max_length;
    # fall back to a conservative cap in that case
    max_tokens = tokenizer.model_max_length
    if max_tokens > 10000:
        max_tokens = 200
    return sentiment_pipeline, tokenizer, max_tokens
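# A minimal sketch of how the returned max_tokens can guard against
# over-length inputs (truncation/max_length are standard pipeline kwargs;
# `text` is a placeholder for any string to classify):
#     result = sentiment_pipeline(text, truncation=True, max_length=max_tokens)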
@st.cache_resource
def load_summarize_pipeline(): # summarize_pipeline
summarize_pipeline = pipeline("summarization", model="Falconsai/text_summarization", device=0)
return summarize_pipeline
# NOTE: deliberately not cached — st.cache_resource cannot hash the pipeline
# argument, and each call must update the progress counter below
def summarize_txt(summarize_pipeline, texts, length):
    if "count" not in st.session_state:
        st.session_state.count = 0
    summary = summarize_pipeline(texts, max_length=10, num_return_sequences=1)
    result = summary[0]["summary_text"]
    st.session_state.count += 1
    st.write(f"Progress: {st.session_state.count}/{length}")
    return result
# class for keyword extraction (instances are created via the cached
# keyword_extractor factory below, so the class itself is not cached)
class KeyphraseExtractionPipeline(TokenClassificationPipeline):
def __init__(self, model, *args, **kwargs):
super().__init__(
model=AutoModelForTokenClassification.from_pretrained(model),
tokenizer=AutoTokenizer.from_pretrained(model),
*args,
**kwargs
)
def postprocess(self, all_outputs):
results = super().postprocess(
all_outputs=all_outputs,
aggregation_strategy=AggregationStrategy.SIMPLE,
)
return np.unique([result.get("word").strip() for result in results])
@st.cache_resource
def keyword_extractor():
model_name = "ml6team/keyphrase-extraction-kbir-inspec"
extractor = KeyphraseExtractionPipeline(model=model_name)
return extractor
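# Example usage of the keyphrase extractor (defined here but not yet wired
# into the UI below; the sample sentence is illustrative only):
#     extractor = keyword_extractor()
#     keyphrases = extractor("Monster Hunter Wilds runs smoothly on PC.")
#     st.write(keyphrases)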
st.title("Scraping & Analysis of Reddit")
# --- User Input ---
user_query = st.text_input("Enter search keyword:", value="Monster Hunter Wilds")
if user_query:
search_query = f'"{user_query}" OR "{user_query.replace(" ", "")}"'
else:
search_query = ""
st.write("Search Query:", search_query)
# Button to trigger scraping and summarize
if st.button("Scrape & Summarize"):
with st.spinner("Scraping..."):
# progress_bar = st.progress(0)
progress_text = st.empty()
total_limit = 5000 # Maximum number of submissions to check
df = scrape_reddit_data(search_query, total_limit)
length = len(df)
progress_text.text(f"Collected {length} valid posts.")
with st.spinner("Loading Summarizing Pipeline"):
summarize_pipeline = load_summarize_pipeline()
with st.spinner("Summarizing txt data..."):
df["Detail_Summary"] = df["Detail"].apply(lambda x: summarize_txt(summarize_pipeline, x, length) if x else None)
st.session_state["df"] = df
# button to trigger sentiment analysis
if st.button("Sentiment Analysis"):
df = st.session_state.get("df")
with st.spinner("Loading Sentiment Pipeline..."):
sentiment_pipeline, tokenizer, max_tokens = load_sentiment_pipeline()
st.write("Sentiment pipeline loaded...")
with st.spinner("Doing Sentiment Analysis..."):
        # titles are short, so batch processing is unnecessary here
        df['Title_Sentiment'] = df['Title'].apply(lambda x: safe_sentiment(sentiment_pipeline, text=preprocess_text(x)) if x else None)
        df['Detail_Sentiment'] = df['Detail_Summary'].apply(lambda x: safe_sentiment(sentiment_pipeline, text=preprocess_text(x)) if x else None)
        # # parallel processing for each row of Detail (kept for reference)
        # with ThreadPoolExecutor() as executor:
        #     detail_sentiments = list(executor.map(
        #         lambda x: analyze_detail(x, tokenizer, sentiment_pipeline, max_tokens) if x else None,
        #         df['Detail']
        #     ))
        # df['detail_sentiment'] = detail_sentiments
df["Title_Sentiment_Label"] = df["Title_Sentiment"].apply(lambda x: x["label"] if x else None)
df["Title_Sentiment_Score"] = df["Title_Sentiment"].apply(lambda x: x["score"] if x else None)
df["Detail_Sentiment_Label"] = df["Detail_Sentiment"].apply(lambda x: x["label"] if x else None)
df["Detail_Sentiment_Score"] = df["Detail_Sentiment"].apply(lambda x: x["score"] if x else None)
df = df.drop(["Title_Sentiment", "Detail_Sentiment"], axis=1)
cols = ["Title", "Title_Sentiment_Label", "Title_Sentiment_Score",
"Detail", "Detail_Sentiment_Label", "Detail_Sentiment_Score", "Date"]
df = df[cols]
st.session_state["df"] = df
# Button to draw graphs
if st.button("Draw Graph"):
df = st.session_state.get("df")
if df is None or df.empty:
st.write("Please run 'Scrape and Sentiment Analysis' first.")
else:
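        # sort chronologically so the line plots read left to right
        # (assumes the Date column is a datetime or ISO-formatted string)
        df = df.sort_values("Date")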
# ------------------- Plot Title's Sentiment Score -------------------#
fig1, ax1 = plt.subplots(figsize=(10, 5))
# Filter and plot for each sentiment category
positive_title = df[df["Title_Sentiment_Label"].str.lower() == "positive"]
negative_title = df[df["Title_Sentiment_Label"].str.lower() == "negative"]
neutral_title = df[df["Title_Sentiment_Label"].str.lower() == "neutral"]
ax1.plot(positive_title["Date"], positive_title["Title_Sentiment_Score"],
marker="o", label="Title Positive", color="orange")
ax1.plot(negative_title["Date"], negative_title["Title_Sentiment_Score"],
marker="o", label="Title Negative", color="blue")
ax1.plot(neutral_title["Date"], neutral_title["Title_Sentiment_Score"],
marker="o", label="Title Neutral", color="yellowgreen")
ax1.set_title("Title Sentiment Scores Over Time")
ax1.set_xlabel("Time")
ax1.set_ylabel("Sentiment Score")
ax1.legend()
plt.xticks(rotation=45)
st.pyplot(fig1)
# ------------------- Plot Detail's Sentiment Score -------------------#
fig2, ax2 = plt.subplots(figsize=(10, 5))
positive_detail = df[df["Detail_Sentiment_Label"].str.lower() == "positive"]
negative_detail = df[df["Detail_Sentiment_Label"].str.lower() == "negative"]
neutral_detail = df[df["Detail_Sentiment_Label"].str.lower() == "neutral"]
ax2.plot(positive_detail["Date"], positive_detail["Detail_Sentiment_Score"],
marker="+", label="Detail Positive", color="darkorange")
ax2.plot(negative_detail["Date"], negative_detail["Detail_Sentiment_Score"],
marker="+", label="Detail Negative", color="navy")
ax2.plot(neutral_detail["Date"], neutral_detail["Detail_Sentiment_Score"],
marker="+", label="Detail Neutral", color="forestgreen")
ax2.set_title("Detail Sentiment Scores Over Time")
ax2.set_xlabel("Time")
ax2.set_ylabel("Sentiment Score")
ax2.legend()
plt.xticks(rotation=45)
st.pyplot(fig2)
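        # WordCloud is imported above but not yet used; a minimal sketch of
        # how it could visualize frequent title words (standard wordcloud /
        # matplotlib calls; sizes and colors are illustrative):
        #     wc = WordCloud(width=800, height=400, background_color="white")
        #     wc.generate(" ".join(df["Title"].dropna()))
        #     fig3, ax3 = plt.subplots(figsize=(10, 5))
        #     ax3.imshow(wc, interpolation="bilinear")
        #     ax3.axis("off")
        #     st.pyplot(fig3)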