# Last commit: 366588b "Update app.py" by Vera-ZWY (verified)
import gradio as gr
from gradio_client import Client, handle_file
import seaborn as sns
import matplotlib.pyplot as plt
import os
import pandas as pd
from io import StringIO
# Hugging Face access token, read from the environment (None when unset).
# Keep secrets out of source control: set HF_TOKEN in the environment or the
# hosting platform's secret store instead of hard-coding it here.
HF_TOKEN = os.environ.get("HF_TOKEN")

# Client for the remote app that serves both endpoints used below
# (/process_query and /get_heatmap_pivot_table). It is created once, at
# import time, and shared by every request handler.
client = Client(
    "mangoesai/Elections_Comparison_Agent_V4",
    hf_token=HF_TOKEN,
)
# The `election_year` values sent to /process_query come from the UI's year
# selector: "2016 Election", "2024 Election", "Comparison two years".
def stream_chat_with_rag(
    message: str,
    client_name: str,
):
    """Ask the remote RAG agent a question and return its answer.

    Despite the name, nothing is streamed: this makes a single
    ``client.predict`` call to the ``/process_query`` endpoint and returns
    once it completes.

    Args:
        message: The user's question, sent as ``query``.
        client_name: The selected election option (e.g. ``"2016 Election"``),
            sent as ``election_year``.

    Returns:
        The first item of the endpoint's response, which the UI shows in the
        "Response" textbox. The second item is only printed for debugging.
    """
    response = client.predict(
        query=message,
        election_year=client_name,
        api_name="/process_query",
    )
    # The endpoint is expected to return exactly two items; unpacking
    # enforces that.
    answer, aux_output = response

    # Debug output of the raw API response.
    for heading, payload in (
        ("Raw answer from API:", answer),
        ("top works from API:", aux_output),
    ):
        print(heading)
        print(payload)

    # NOTE(review): aux_output is discarded, so the UI's "Topic Distribution"
    # plot is never filled in.
    return answer
def heatmap(top_n):
    """Fetch the emotion-by-topic pivot table and render it as a heatmap.

    Calls the remote ``/get_heatmap_pivot_table`` endpoint with ``top_n`` and
    draws the result with seaborn. The endpoint returns a dict like::

        {'headers': ['Index', 'economy', 'human rights', 'immigrant', 'politics'],
         'data': [['anger', 55880.0, 557679.0, 147766.0, 180094.0],
                  ['disgust', 26911.0, 123112.0, 64567.0, 46460.0],
                  ['fear', 51466.0, 188898.0, 113174.0, 150578.0],
                  ['neutral', 77005.0, 192945.0, 20549.0, 190793.0]],
         'metadata': None}

    The "Index" column holds the emotion labels (heatmap rows); every other
    header is a topic (heatmap columns).

    Args:
        top_n: Selected value of the "Top N" dropdown (one of 1-10 in the UI).
            It is forwarded unchanged to the endpoint and shown in the title.

    Returns:
        matplotlib.figure.Figure: The heatmap, for display in a ``gr.Plot``.
    """
    pivot_table = client.predict(
        top_n=top_n,
        api_name="/get_heatmap_pivot_table",
    )
    print(pivot_table)
    print(type(pivot_table))

    # Rows are emotions (keyed by the "Index" column), columns are topics.
    df = pd.DataFrame(
        pivot_table["data"], columns=pivot_table["headers"]
    ).set_index("Index")

    # Draw on an explicit Figure/Axes pair rather than pyplot's shared
    # "current figure". Gradio may run handlers concurrently, and the global
    # state would let requests draw on each other's plots.
    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(
        df,
        cmap="YlOrRd",
        cbar_kws={"label": "Weighted Frequency"},
        square=True,
        ax=ax,
    )
    ax.set_title(f"Top {top_n} Emotions vs Topics Weighted Frequency")
    ax.set_xlabel("Topics")
    ax.set_ylabel("Emotions")
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right")
    fig.tight_layout()

    # pyplot keeps every figure it creates alive until it is closed. This
    # handler runs on each button click in a long-lived server, so release
    # the figure from pyplot's registry here. The returned Figure object can
    # still be saved and rendered afterwards.
    plt.close(fig)
    return fig
# Create Gradio interface.

# Page-level CSS that gives both plot areas enough room. It is passed through
# the `css=` argument of gr.Blocks, Gradio's supported hook for custom
# styles, instead of a <style> tag inside a gr.HTML component (which also
# added an empty HTML element to the layout).
_PLOT_CSS = """
.topic-plot {
    min-height: 600px;
    width: 100%;
    margin: auto;
}
.heatmap-plot {
    min-height: 400px;
    width: 100%;
    margin: auto;
}
"""

with gr.Blocks(title="Reddit Election Analysis", css=_PLOT_CSS) as demo:
    # --- Emotion / topic heatmap section ---
    gr.Markdown("# Reddit Public sentiment & Social topic distribution ")
    with gr.Row():
        with gr.Column():
            with gr.Row():
                # Use an explicit default (the first choice, like the year
                # selector below). Without `value`, some Gradio versions
                # leave the dropdown unselected, so clicking "Refresh
                # Heatmap" first would send top_n=None and the title would
                # read "Top None ...".
                top_n = gr.Dropdown(
                    choices=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
                    value=1,
                    label="Top N",
                )
            with gr.Row():
                fresh_btn = gr.Button("Refresh Heatmap")
        with gr.Column():
            output_heatmap = gr.Plot(
                label="Top Public sentiment & Social topic Heatmap",
                container=True,  # wrap the plot in a padded container
                elem_classes="heatmap-plot",  # styled by _PLOT_CSS
            )

    # --- Question-answering section ---
    gr.Markdown("# Reddit Election Posts/Comments Analysis")
    gr.Markdown("Ask questions about election-related comments and posts")
    with gr.Row():
        with gr.Column():
            year_selector = gr.Radio(
                choices=["2016 Election", "2024 Election", "Comparison two years"],
                label="Select Election Year",
                value="2016 Election",
            )
            query_input = gr.Textbox(
                label="Your Question",
                placeholder="Ask about election comments or posts...",
            )
            submit_btn = gr.Button("Submit")
            gr.Markdown(
                "\n"
                "## Example Questions:\n"
                "- Is there any comments don't like the election results\n"
                "- Summarize the main discussions about voting process\n"
                "- What are the common opinions about candidates?\n"
            )
        with gr.Column():
            output_text = gr.Textbox(
                label="Response",
                lines=20,
            )
    with gr.Row():
        # NOTE(review): no event handler writes to this plot. submit_btn
        # below only updates output_text, because stream_chat_with_rag
        # returns just the answer. To show the figure, the handler must also
        # return it and the click must use outputs=[output_text, output_plot].
        output_plot = gr.Plot(
            label="Topic Distribution",
            container=True,  # wrap the plot in a padded container
            elem_classes="topic-plot",  # styled by _PLOT_CSS
        )

    # Event wiring.
    fresh_btn.click(
        fn=heatmap,
        inputs=top_n,
        outputs=output_heatmap,
    )
    submit_btn.click(
        fn=stream_chat_with_rag,
        inputs=[query_input, year_selector],
        outputs=output_text,
    )

if __name__ == "__main__":
    demo.launch(share=True)