Sciences-POC / app.py
ZakoST
debugging
f599167
import plotly.express as px
import pandas as pd
import json
import plotly.graph_objects as go
from datasets import load_dataset
import streamlit as st
REPO_ID = "libeIO/Sciences-POC"

# Per-group configuration (French text, so decode explicitly as UTF-8 rather
# than relying on the platform default encoding).
# `mapping` is keyed by group id; each entry provides at least 'save_path',
# 'path_prompt', 'client' and 'auteurs'.
# `mapping_noms` maps a display name (the group selector value) to a `mapping` key.
with open('config/mapping_prompts.txt', 'r', encoding='utf-8') as f:
    mapping = json.load(f)
with open('config/mapping_noms.txt', 'r', encoding='utf-8') as f:
    mapping_noms = json.load(f)

# Default group on the first run of a session; the group selectbox
# (key='name') overwrites this value on later reruns.
if 'name' not in st.session_state:
    st.session_state['name'] = 'Groupe 1'
def _count_by_category(df, column):
    """Count items per distinct value of *column* as a ('Catégorie', "Nombre d'articles") frame."""
    counts = df.groupby(column).item_id.count().reset_index()
    counts.columns = ['Catégorie', 'Nombre d\'articles']
    return counts


def _build_radar_figure(display_principale, display_secondaire):
    """Overlay primary (red) and secondary (translucent blue) category counts on one polar chart."""
    count_col = 'Nombre d\'articles'
    fig = go.Figure()
    fig.update_layout(template="ggplot2")
    fig.add_trace(go.Scatterpolar(
        r=display_principale[count_col],
        theta=display_principale['Catégorie'],
        fill='toself',
        name='Catégorie Principale',
        marker={'color': 'red'},
    ))
    fig.add_trace(go.Scatterpolar(
        r=display_secondaire[count_col],
        theta=display_secondaire['Catégorie'],
        fill='toself',
        name='Catégorie Secondaire',
        marker={'color': 'blue'},
        opacity=0.25,
    ))
    # One radial scale for both traces, sized to the largest count. Fall back
    # to 0 when there is no data: builtin max() on an empty Series raises.
    all_counts = pd.concat([display_principale[count_col], display_secondaire[count_col]])
    r_max = all_counts.max() if not all_counts.empty else 0
    fig.update_layout(
        polar=dict(radialaxis=dict(visible=True, range=[0, r_max])),
        showlegend=True,
        legend=dict(yanchor="top", y=0.0001, xanchor="left", x=0.395),
    )
    return fig


@st.cache_resource
def initialize(name):
    """Load one group's classification results and build the page's summary views.

    The result is cached per ``name`` through ``st.cache_resource``.

    Parameters
    ----------
    name : str
        Group display name; resolved through ``mapping_noms`` to a ``mapping`` entry.

    Returns
    -------
    tuple
        ``(fig, display_principale, articles, prompt, model)``:

        - ``fig``: the radar chart of primary vs. secondary category counts.
        - ``display_principale``: per-primary-category counts.
        - ``articles``: classified items left-joined with the article CSV.
        - ``prompt``: the group's prompt text.
        - ``model``: the entry's ``'client'`` value.
    """
    config = mapping[mapping_noms[name]]

    with open(config['save_path'], 'r', encoding='utf-8') as f:
        df = pd.DataFrame.from_dict(json.load(f))
    # Left join keeps every classified item even if it is missing from the CSV.
    # Done before the secondary category is truncated below, so `articles`
    # still carries the full comma-separated list.
    articles = pd.merge(df, pd.read_csv('data/extract_sciences_po.csv'), on='item_id', how='left')

    display_principale = _count_by_category(df, 'categorie_principale')
    # Keep only the first listed secondary category. Non-strings (None or NaN
    # from the JSON) become None instead of crashing on `.split`.
    df['categorie_secondaire'] = df['categorie_secondaire'].map(
        lambda cats: cats.split(',')[0] if isinstance(cats, str) else None
    )
    display_secondaire = _count_by_category(df, 'categorie_secondaire')

    fig = _build_radar_figure(display_principale, display_secondaire)

    with open(config['path_prompt'], 'r', encoding='utf-8') as f:
        prompt = f.read()
    model = config['client']
    return fig, display_principale, articles, prompt, model
def clean_title(titre, suffix='Libération'):
    """Strip a trailing " - <suffix>" publisher tag (and surrounding blanks) from a title.

    Uses ``str.removesuffix`` for the exact suffix. ``str.rstrip`` treats its
    argument as a set of characters and would also eat legitimate trailing
    letters of titles that do not carry the tag.

    >>> clean_title("Le bilan - Libération")
    'Le bilan'
    """
    return str(titre).strip().removesuffix(suffix).rstrip().rstrip('-').strip()


def display_article(article):
    """Render one article as an image/text card laid out in two Streamlit columns.

    Parameters
    ----------
    article : pandas.Series
        One row of the merged articles frame. Reads ``url``, ``image_url``,
        ``titre``, ``description``, ``date_published``, ``premium``,
        ``categorie_secondaire`` and, if present, ``subhead``.
    """
    url = article['url']
    title_link = f"[{clean_title(article['titre'])}]({url})"
    subhead = article['subhead'] if 'subhead' in article.index else None
    # Empty CSV cells come back as float NaN, not the string 'nan', so a plain
    # `!= 'nan'` test would render ":red[nan]". Reject both forms.
    has_subhead = pd.notna(subhead) and str(subhead) != 'nan'
    # NaN is truthy, so a missing `premium` value must not count as premium.
    is_premium = pd.notna(article.premium) and bool(article.premium)

    col_image, col_text = st.columns(2)
    with col_image:
        # Skip missing image URLs rather than passing NaN to st.image.
        if pd.notna(article["image_url"]):
            st.image(article["image_url"])
    with col_text:
        st.subheader(f":red[{subhead}] {title_link}" if has_subhead else title_link)
        st.write(f"{article['description']}")
        date_html = f"<span style='color:grey'>{article['date_published']} </span>"
        if is_premium:
            date_html += " <span style='color:#eeb54e'> abonnés</span>"
        st.markdown(date_html, unsafe_allow_html=True)
        st.badge(
            f"Catégories secondaires : {article['categorie_secondaire']}",
            icon=":material/info:",
            color="blue",
        )
# Build (or fetch from cache) the views for the currently selected group.
fig, display_principale, articles, prompt, model = initialize(st.session_state['name'])

# Changing the selection updates st.session_state['name'] and triggers a rerun.
st.selectbox("Choisir groupe", [mapping[k]['auteurs'] for k in mapping.keys()], key='name')
with st.expander(f"Prompt for model : {model}"):
    st.markdown(prompt)

st.subheader('Répartition des articles par catégorie')
col1, col2 = st.columns([0.6, 0.4], vertical_alignment='center')
with col1:
    st.plotly_chart(fig)
with col2:
    st.dataframe(
        display_principale.set_index('Catégorie').sort_values(by='Nombre d\'articles', ascending=False)
    )

st.subheader('Exemples d\'articles')
N_EXAMPLES = 20  # maximum number of distinct articles shown per category tab
categories = display_principale['Catégorie'].values.tolist()
if categories:  # st.tabs needs at least one label
    for tab, cat in zip(st.tabs(categories), categories):
        with tab:
            in_cat = articles.loc[articles.categorie_principale == cat].drop_duplicates()
            # Sample without replacement so each tab shows up to N_EXAMPLES
            # distinct articles. The old sample(replace=True) + dedup showed a
            # random, smaller count.
            for _, article in in_cat.sample(min(N_EXAMPLES, len(in_cat))).iterrows():
                display_article(article)