# PACIFISTA / app_streamlit.py
# AmadouDiaV — "Premier commit" (initial commit), a1ff351
import os
import shelve

import fitz  # PyMuPDF
import requests
import streamlit as st

# Star import; presumably supplies TITRE, MODEL_NAMES, AVAILABLE_MODELS,
# API_URL, USER_AVATAR and BOT_AVATAR, which are used below.
from config import *
# Shelve database that persists the chat history and the selected model
# between sessions, and the model used when nothing has been saved yet.
_HISTORY_DB_PATH = "Message_history/ARDB"
_DEFAULT_MODEL = "BERT"


def _open_history_db(db_path):
    """Open the shelve database at *db_path*, creating its parent directory first.

    ``shelve.open`` (default flag ``"c"``) creates the database file but not
    missing directories, so without this a fresh checkout lacking
    ``Message_history/`` would crash on the very first page load.
    """
    parent = os.path.dirname(db_path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    return shelve.open(db_path)


def load_chat_history(db_path=_HISTORY_DB_PATH):
    """Return the persisted list of chat messages, or ``[]`` if none is stored."""
    with _open_history_db(db_path) as db:
        return db.get("messages", [])


def load_current_model(db_path=_HISTORY_DB_PATH):
    """Return the persisted model name, or ``"BERT"`` if none is stored."""
    with _open_history_db(db_path) as db:
        return db.get("current_model", _DEFAULT_MODEL)


def save_current_model(model_name, db_path=_HISTORY_DB_PATH):
    """Persist *model_name* as the currently selected model."""
    with _open_history_db(db_path) as db:
        db["current_model"] = model_name


def save_chat_history(messages, db_path=_HISTORY_DB_PATH):
    """Persist *messages*, replacing any previously stored chat history."""
    with _open_history_db(db_path) as db:
        db["messages"] = messages
# Seed the per-session state only for keys that are still missing. Each
# default is a zero-argument factory, so the shelve database is read only
# when the corresponding key actually has to be created.
_SESSION_DEFAULTS = {
    "messages": load_chat_history,
    "current_model": load_current_model,
    "current_context": lambda: "",
}
for _state_key, _make_default in _SESSION_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _make_default()

st.title(" - ".join((TITRE, st.session_state.current_model)))
def _read_uploaded_context(uploaded_file):
    """Return the plain text of an uploaded ``.txt`` or ``.pdf`` file.

    PDF pages are concatenated with their newlines flattened to spaces.
    Text files are decoded as UTF-8, and undecodable bytes are replaced
    instead of crashing the app.
    """
    # getvalue() returns the whole buffer regardless of the current read position.
    data = uploaded_file.getvalue()
    # The uploader only accepts txt/pdf, so dispatch on the extension rather
    # than on the browser-reported MIME type.
    if uploaded_file.name.lower().endswith(".pdf"):
        # The context manager closes the PyMuPDF document once the text is extracted.
        with fitz.open(stream=data, filetype="pdf") as document:
            return "".join(page.get_text().replace("\n", " ") for page in document)
    return data.decode("utf-8", errors="replace")


with st.sidebar:
    # --- Model selection ---------------------------------------------------
    st.title(body="Veuillez choisir un modèle")
    # A persisted model name may no longer be listed in MODEL_NAMES; select
    # the first entry instead of letting list.index raise ValueError.
    try:
        model_index = MODEL_NAMES.index(st.session_state.current_model)
    except ValueError:
        model_index = 0
    option = st.selectbox(
        label='Veuillez choisir un modèle',
        index=model_index,
        options=MODEL_NAMES)
    st.write('Vous avez selectionné :', option)
    if st.button("Changer le modèle"):
        if option == st.session_state.current_model:
            st.warning("🤖 Ce modèle est déjà chargé", icon="⚠️")
        else:
            try:
                req = requests.post(API_URL + "/load_model", json={'model': AVAILABLE_MODELS[option]})
                loaded = req.status_code == 200
            except requests.RequestException:
                # API unreachable: report it like any other load failure
                # instead of crashing the page with a traceback.
                loaded = False
            if loaded:
                save_current_model(option)
                st.session_state.current_model = option
                # A newly loaded model starts with an empty conversation.
                st.session_state.messages = []
                save_chat_history([])
                st.rerun()
            else:
                st.warning("🤖 Je ne peux pas charger le modèle veuillez ressayer", icon="⚠️")

    # --- Context selection -------------------------------------------------
    st.title(body="Veuillez donner un contexte")
    upload_file = st.file_uploader("Importer un fichier texte ou pdf pour charger le contexte", type=['txt','pdf'])
    if st.button("Changer le contexte"):
        if upload_file:
            st.session_state.current_context = _read_uploaded_context(upload_file)
            st.rerun()
        else:
            st.warning("🤖 Il n'y a aucun contexte à charger", icon="⚠️")
    # The text area is pre-filled from current_context; a rerun is needed for
    # a change to current_context to show up in it.
    context = st.sidebar.text_area(
        value=st.session_state.current_context,
        label="Ou écrire le contexte dans ce champs",
        placeholder="Veuillez écrire le contexte de la question ici",
        height=300)
    if st.button("Effacer le contexte"):
        st.session_state.current_context = ""
        # The text area above was already drawn with the old context; rerun
        # so both it and `context` reflect the cleared value immediately.
        st.rerun()
    if st.button("Supprimer l'historique"):
        st.session_state.messages = []
        save_chat_history([])
# Render every stored message, using the user avatar for "user" entries and
# the bot avatar for everything else.
for entry in st.session_state.messages:
    speaker = entry["role"]
    with st.chat_message(speaker, avatar=USER_AVATAR if speaker == "user" else BOT_AVATAR):
        st.markdown(entry["prompt"])
# A context must contain more than this many non-blank characters before a
# question is sent to the API.
_MIN_CONTEXT_LENGTH = 20


def _ask_api(payload):
    """POST *payload* to the ``/answer`` endpoint and return the answer text.

    Returns None when the API is unreachable, replies with a status other
    than 200, or sends a body that is not JSON or lacks the 'reponse' field.
    """
    try:
        req = requests.post(API_URL + "/answer", json=payload)
        if req.status_code != 200:
            return None
        return req.json()['reponse']
    except (requests.RequestException, ValueError, KeyError, TypeError):
        # ValueError covers a non-JSON body; KeyError/TypeError cover a JSON
        # body that is not a dict containing 'reponse'.
        return None


prompt = st.chat_input("How can I help?")
if prompt:
    if len(context.strip()) > _MIN_CONTEXT_LENGTH:
        st.session_state.current_context = context
        response = _ask_api({"role": "user", "prompt": prompt, "context": context})
        if response is not None:
            # Only role and prompt are needed to redraw the chat; persisting
            # the full context with every question would bloat the history DB.
            st.session_state.messages.append({"role": "user", "prompt": prompt})
            with st.chat_message("user", avatar=USER_AVATAR):
                st.markdown(prompt)
            with st.chat_message("assistant", avatar=BOT_AVATAR):
                st.markdown(response)
            st.session_state.messages.append({"role": "assistant", "prompt": response})
            # Persist only when the history actually changed, not on every rerun.
            save_chat_history(st.session_state.messages)
        else:
            st.warning("🤖 Je ne peux pas répondre à la question car il y a eu un problème avec l'API veuillez reposer votre question", icon="⚠️")
    else:
        st.warning("🤖 Je ne peux pas répondre à la question car il n'y a pas contexte ou ce dernier est trop court. Veuillez importer un contexte ou le coller dans son champ", icon="⚠️")