|
import streamlit as st |
|
import pandas as pd |
|
import matplotlib.pyplot as plt |
|
from wordcloud import WordCloud |
|
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification |
|
|
|
|
|
# Page-level settings: browser-tab title and a wide layout. Kept ahead of every
# other st.* call in this script.
st.set_page_config(
    page_title="π° News Classifier & Q&A App",
    layout="wide",
)
|
|
|
|
|
@st.cache_resource
def load_text_classifier():
    """Build the text-classification pipeline used to label news articles.

    The tokenizer and the sequence-classification model are both loaded from
    the "MihanTilk/News_Classifier" checkpoint and wrapped in a Hugging Face
    ``text-classification`` pipeline. The function is decorated with
    ``st.cache_resource``.

    Returns:
        The pipeline object; calling it on a string yields a list of
        ``{"label": ..., "score": ...}`` dicts (as consumed elsewhere in
        this script).
    """
    checkpoint = "MihanTilk/News_Classifier"
    news_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    news_model = AutoModelForSequenceClassification.from_pretrained(checkpoint)
    return pipeline(
        "text-classification",
        model=news_model,
        tokenizer=news_tokenizer,
    )
|
|
|
|
|
@st.cache_resource
def load_qa_pipeline():
    """Build the extractive question-answering pipeline.

    Wrapped in ``st.cache_resource`` (like ``load_text_classifier``) so the
    "deepset/roberta-large-squad2" weights are loaded once and reused, instead
    of being re-instantiated each time Streamlit re-runs the script after a
    widget interaction.

    Returns:
        A Hugging Face ``question-answering`` pipeline.
    """
    qa_checkpoint = "deepset/roberta-large-squad2"
    return pipeline(
        "question-answering",
        model=qa_checkpoint,
        tokenizer=qa_checkpoint,
    )


# Module-level handles used by the rest of the script.
classifier = load_text_classifier()
qa_pipeline = load_qa_pipeline()
|
|
|
|
|
# Inject the app's custom stylesheet. The stray "|" characters that had leaked
# into the CSS text (and made its selectors/declarations invalid) are removed.
# unsafe_allow_html is required so the <style> element is emitted as HTML
# instead of being escaped.
st.markdown(
    """
    <style>
    /* Main background and text colors */
    .main {
        background-color: #f4f4f4;
    }

    /* Text input boxes - light blue theme */
    .stTextInput>div>div>input,
    .stTextArea>div>div>textarea {
        background-color: #e6f2ff;
        border: 1px solid #b3d1ff;
        border-radius: 8px;
        color: #003366;
    }

    /* File uploader - matching style */
    .stFileUploader>div>div {
        background-color: #e6f2ff;
        border: 1px solid #b3d1ff;
        border-radius: 8px;
    }

    /* Buttons - keeping your original style */
    .stButton>button {
        background-color: #ff4b4b;
        color: white;
        border-radius: 10px;
        border: none;
    }

    .stDownloadButton>button {
        background-color: #4CAF50;
        color: white;
        border-radius: 10px;
        border: none;
    }

    /* Text colors */
    h1, h2, h3, h4, h5, h6 {
        color: #003366;  /* Dark blue for headers */
    }

    p {
        color: #336699;  /* Medium blue for paragraphs */
    }

    /* Dataframe styling */
    .dataframe {
        background-color: #e6f2ff;
        border: 1px solid #b3d1ff;
    }
    </style>
    """,
    unsafe_allow_html=True,
)
|
|
|
|
|
# Page heading and tagline.
st.title("π° News Classification & Q&A App")
st.markdown(
    "<h4 style='color:#ff4b4b;'>Upload a CSV to classify news headlines and ask questions!</h4>",
    unsafe_allow_html=True,
)

# CSV upload widget; `uploaded_file` is falsy until the user picks a file.
st.subheader("π Upload a CSV File")
uploaded_file = st.file_uploader(
    "Choose a CSV file...",
    type=["csv"],
)
|
|
|
# NOTE(review): the "π…" / "β…" prefixes in the UI strings below look like
# emoji that were mis-decoded at some point. They are kept as found; restore
# them from the original UTF-8 source if it is available.

# Palette shared by the pie and bar charts.
CHART_COLORS = ['#ff9999', '#66b3ff', '#99ff99', '#ffcc99', '#c2c2f0']
# Number of columns the category buttons are laid out in.
CATEGORY_COLUMNS = 5


@st.cache_data(show_spinner=False)
def classify_texts(texts):
    """Classify every text and return ``[(label, score), ...]`` in input order.

    Cached on the list of texts so that widget interactions (which re-run the
    whole script) do not push unchanged data through the model again.

    Args:
        texts: list of strings to classify.

    Returns:
        A list of ``(label, score)`` tuples, one per input text.
    """
    if not texts:
        return []
    # truncation=True keeps over-long articles within the model's input limit
    # instead of failing inside the model.
    predictions = classifier(list(texts), truncation=True)
    return [(pred['label'], pred['score']) for pred in predictions]


if uploaded_file:
    # ---- Load and validate the upload ------------------------------------
    try:
        df = pd.read_csv(uploaded_file, encoding='utf-8')
    except (UnicodeDecodeError, pd.errors.EmptyDataError, pd.errors.ParserError) as exc:
        st.error(f"Could not read the uploaded CSV: {exc}")
        st.stop()

    if "content" not in df.columns:
        st.error("β The uploaded CSV must contain a 'content' column.")
        st.stop()

    if df.empty:
        # Nothing to classify or plot; the charts and word cloud below would fail.
        st.warning("The uploaded CSV contains no rows to classify.")
        st.stop()

    df['cleaned_text'] = df['content'].astype(str).str.lower().str.strip()
    st.write("π Preview of Uploaded Data:", df.head())

    # ---- Classification --------------------------------------------------
    with st.spinner("π Classifying news articles..."):
        predictions = classify_texts(df['cleaned_text'].tolist())
        df['class'] = [label for label, _ in predictions]
        # Keep the score from the same prediction that produced the label so
        # the confidence shown later matches the displayed class.
        df['confidence'] = [score for _, score in predictions]

    # The original literal was split by a control character left over from the
    # mis-decoded check-mark emoji; the intended character is restored here.
    st.success("✅ Classification Complete!")
    st.write("π Classified Results:", df[['content', 'class']].head())

    # ---- Download --------------------------------------------------------
    st.subheader("π₯ Download Results")
    output_df = df[['content', 'class']]
    csv_output = output_df.to_csv(index=False, encoding='utf-8-sig').encode('utf-8-sig')
    st.download_button("Download Output CSV", data=csv_output, file_name="output.csv", mime="text/csv")

    # All article bodies as one string; missing cells are skipped and
    # non-string cells are stringified so the join cannot raise TypeError.
    all_text = " ".join(df['content'].dropna().astype(str))

    # ---- Question answering ----------------------------------------------
    st.subheader("π¬ Ask a Question")
    question = st.text_input("π What do you want to know about the content?")

    if st.button("Get Answer"):
        if not question.strip():
            st.warning("Please enter a question first.")
        elif not all_text.strip():
            st.warning("The uploaded content is empty, so there is nothing to search.")
        else:
            with st.spinner("Answering..."):
                result = qa_pipeline(question=question, context=all_text)
            st.success(f"π Answer: {result['answer']}")

    # ---- Visualizations --------------------------------------------------
    st.subheader("π Data Visualizations")

    class_counts = df['class'].value_counts()
    main_col1, main_col2 = st.columns([3, 2])

    with main_col1:
        st.markdown("*Topic Distribution*")
        chart_col1, chart_col2 = st.columns(2)

        with chart_col1:
            fig1, ax1 = plt.subplots(figsize=(4, 4))
            class_counts.plot.pie(
                autopct='%1.1f%%',
                startangle=90,
                ax=ax1,
                colors=CHART_COLORS,
                wedgeprops={'linewidth': 0.5, 'edgecolor': 'white'},
            )
            ax1.set_ylabel('')
            st.pyplot(fig1, use_container_width=True)
            # Release the figure; otherwise figures accumulate across reruns.
            plt.close(fig1)

        with chart_col2:
            fig2, ax2 = plt.subplots(figsize=(4, 4))
            class_counts.plot.bar(
                color=CHART_COLORS,
                ax=ax2,
                width=0.7,
            )
            ax2.set_xlabel('')
            ax2.set_ylabel('Count')
            # Rotate the labels on this axes explicitly rather than on
            # pyplot's implicit "current" axes.
            plt.setp(ax2.get_xticklabels(), rotation=45, ha='right')
            st.pyplot(fig2, use_container_width=True)
            plt.close(fig2)

    with main_col2:
        st.markdown("*Word Cloud*")
        try:
            wordcloud = WordCloud(
                width=300,
                height=200,
                background_color="white",
                collocations=False,
                max_words=100,
            ).generate(all_text)
        except ValueError:
            # Raised when the text yields no words to draw.
            st.info("Not enough text to build a word cloud.")
        else:
            fig3, ax3 = plt.subplots(figsize=(4, 3))
            ax3.imshow(wordcloud, interpolation="bilinear")
            ax3.axis("off")
            st.pyplot(fig3, use_container_width=True)
            plt.close(fig3)

    with st.expander("π Detailed Statistics", expanded=False):
        stats_col1, stats_col2 = st.columns(2)

        with stats_col1:
            st.write("*Category Breakdown:*")
            stats_df = class_counts.reset_index()
            stats_df.columns = ['Category', 'Count']
            stats_df['Percentage'] = (stats_df['Count'] / stats_df['Count'].sum() * 100).round(1)
            st.dataframe(stats_df, height=200)

        with stats_col2:
            if 'date' in df.columns:
                st.write("*Monthly Trends*")
                try:
                    df['date'] = pd.to_datetime(df['date'])
                    trends = df.groupby([df['date'].dt.to_period('M'), 'class']).size().unstack()
                    st.line_chart(trends)
                except Exception:
                    # Best-effort section: report instead of crashing, but do
                    # not swallow BaseException (KeyboardInterrupt, SystemExit)
                    # the way a bare `except:` would.
                    st.warning("Date parsing failed")

    # ---- Browse by category ----------------------------------------------
    st.subheader("π Explore News by Category")

    categories = df['class'].unique()
    cols = st.columns(CATEGORY_COLUMNS)

    for i, category in enumerate(categories):
        # Wrap around so more than CATEGORY_COLUMNS categories cannot index
        # past the end of `cols`.
        with cols[i % CATEGORY_COLUMNS]:
            if st.button(category, key=f"btn_{category}"):
                with st.popover(f"π° Articles in {category}", use_container_width=True):
                    st.markdown(f"### {category} Articles")
                    articles = df[df['class'] == category]

                    for idx, row in articles.iterrows():
                        # str() guards against NaN / non-string cells.
                        content = str(row['content'])
                        with st.expander(f"Article {idx + 1}: {content[:50]}...", expanded=False):
                            st.write(content)
                            st.caption(f"Classification confidence: {row['confidence']:.2f}")
|
|
|
|
|
# Footer: horizontal rule followed by a centered credit line.
st.markdown("---")
st.markdown(
    "<p style='text-align:center; color:#666;'>π Built with using Streamlit & Hugging Face</p>",
    unsafe_allow_html=True,
)