Spaces:
Sleeping
Sleeping
AmadouDiaV
committed on
Commit
·
a1ff351
1
Parent(s):
b4d2fbd
Premier commit
Browse files- Requirements.txt +9 -0
- app.py +26 -0
- app_streamlit.py +112 -0
- config.py +16 -0
Requirements.txt
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
transformers
|
2 |
+
datasets
|
3 |
+
huggingface_hub
|
4 |
+
flask==2.2.5
|
5 |
+
requests
|
6 |
+
tensorflow
|
7 |
+
streamlit
|
8 |
+
tf-keras
|
9 |
+
pymupdf
|
app.py
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from flask import Flask, request
from transformers import pipeline

# Minimal QA inference service exposing /answer and /load_model.
app = Flask(__name__)

# Question-answering pipeline; None until load_pipeline()/load_model() runs.
question_answerer = None
5 |
+
|
6 |
+
# NOTE(review): before_first_request was removed in Flask 2.3; the pinned
# flask==2.2.5 in Requirements.txt still supports it (with a deprecation
# warning). Confirm the pin stays if this hook is kept.
@app.before_first_request
def load_pipeline():
    """Load the default BERT question-answering model (TF backend) once,
    before the first request is served."""
    global question_answerer
    question_answerer = pipeline("question-answering", "cancerfarore/bert-base-uncased-CancerFarore-Model", framework="tf")
|
10 |
+
|
11 |
+
@app.route("/answer", methods=["POST"])
def answer():
    """Answer a question against a caller-supplied context.

    Expects JSON ``{"context": str, "prompt": str}`` and returns
    ``{"reponse": str, "score": float}``.  The misspelled key "reponse"
    is kept as-is: the Streamlit client reads ``res['reponse']``.
    """
    global question_answerer
    obj = request.get_json()
    context = obj['context']
    question = obj['prompt']
    # Bug fix: the original invoked the pipeline TWICE (once for 'answer',
    # once for 'score'), doubling inference cost per request. Run it once
    # and reuse the result.
    result = question_answerer(context=context, question=question)
    return {"reponse": result['answer'], "score": result['score']}
|
18 |
+
|
19 |
+
@app.route("/load_model", methods=["POST"])
def load_model():
    """Hot-swap the active QA pipeline for the model named in the request.

    Expects JSON ``{"model": <hub repo id>}``; replaces the module-level
    pipeline and confirms with a plain-text 200 response.
    """
    global question_answerer
    payload = request.get_json()
    model_name = payload['model']
    question_answerer = pipeline("question-answering", model_name, framework="tf")
    return f"Model {model_name}", 200
|
26 |
+
|
app_streamlit.py
ADDED
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import fitz
|
3 |
+
import requests
|
4 |
+
import shelve
|
5 |
+
from config import *
|
6 |
+
|
7 |
+
|
8 |
+
def load_chat_history():
    """Return the persisted chat transcript, or an empty list if none saved."""
    with shelve.open("Message_history/ARDB") as db:
        history = db.get("messages", [])
    return history
|
11 |
+
|
12 |
+
def load_current_model():
    """Return the persisted model display name, defaulting to "BERT"."""
    with shelve.open("Message_history/ARDB") as db:
        name = db.get("current_model", "BERT")
    return name
|
15 |
+
|
16 |
+
def save_current_model(model_name):
    """Persist *model_name* as the currently selected model."""
    with shelve.open("Message_history/ARDB") as db:
        db["current_model"] = model_name
|
19 |
+
|
20 |
+
|
21 |
+
def save_chat_history(messages):
    """Persist the full chat transcript under the "messages" key."""
    with shelve.open("Message_history/ARDB") as db:
        db["messages"] = messages
|
24 |
+
|
25 |
+
|
26 |
+
# NOTE(review): indentation was lost when this file was extracted from the
# diff view; the sidebar/main-area nesting below is reconstructed and should
# be confirmed against the original source.

# Initialize per-session state from the on-disk shelve store.
if "messages" not in st.session_state:
    st.session_state.messages = load_chat_history()

if "current_model" not in st.session_state:
    st.session_state.current_model = load_current_model()
if "current_context" not in st.session_state:
    st.session_state.current_context = ""

st.title(TITRE + " - " + st.session_state.current_model)

with st.sidebar:
    # Model selection: posts to the Flask backend to hot-swap the pipeline.
    st.title(body="Veuillez choisir un modèle")
    option = st.selectbox(
        label='Veuillez choisir un modèle',
        index=MODEL_NAMES.index(st.session_state.current_model),
        options=MODEL_NAMES)

    st.write('Vous avez selectionné :', option)

    if st.button("Changer le modèle"):
        if option == st.session_state.current_model:
            st.warning("🤖 Ce modèle est déjà chargé", icon="⚠️")
        else:
            req = requests.post(API_URL + "/load_model", json={'model': AVAILABLE_MODELS[option]})
            if req.status_code == 200:
                # Switching models resets both the persisted model name
                # and the persisted chat transcript.
                save_current_model(option)
                st.session_state.current_model = load_current_model()
                st.session_state.messages = []
                save_chat_history([])
                st.rerun()
            else:
                st.warning("🤖 Je ne peux pas charger le modèle veuillez ressayer", icon="⚠️")

    # Context loading: plain text is decoded as UTF-8; PDFs are flattened
    # into one newline-free string via PyMuPDF (fitz).
    st.title(body="Veuillez donner un contexte")
    upload_file = st.file_uploader("Importer un fichier texte ou pdf pour charger le contexte", type=['txt','pdf'])
    if st.button("Changer le contexte"):
        if upload_file:
            if upload_file.type == 'text/plain':
                st.session_state.current_context = upload_file.getvalue().decode("utf-8")
            elif upload_file.type == 'application/pdf':
                # NOTE(review): fitz.open(stream=...) is given no filetype;
                # presumably PyMuPDF's default/auto-detection handles PDF
                # bytes here — confirm against the pinned pymupdf version.
                document = fitz.open(stream=upload_file.read())
                text = ''
                for page in document:
                    page_text = page.get_text()
                    page_text = page_text.replace('\n', ' ')
                    text += page_text
                st.session_state.current_context = text
            st.rerun()
        else:
            st.warning("🤖 Il n'y a aucun contexte à charger", icon="⚠️")

# Free-text context field (explicitly placed in the sidebar) plus reset buttons.
context = st.sidebar.text_area(value=st.session_state.current_context, label = "Ou écrire le contexte dans ce champs", placeholder = "Veuillez écrire le contexte de la question ici", height=300)
if st.button("Effacer le contexte"):
    st.session_state.current_context = ""
if st.button("Supprimer l'historique"):
    st.session_state.messages = []
    save_chat_history([])

# Display chat messages
for message in st.session_state.messages:
    avatar = USER_AVATAR if message["role"] == "user" else BOT_AVATAR
    with st.chat_message(message["role"], avatar=avatar):
        st.markdown(message["prompt"])

# Chat input: questions are only forwarded to the API when a context longer
# than 20 characters is available.
prompt = st.chat_input("How can I help?")
if prompt:
    if len(context) > 20:
        st.session_state.current_context = context
        promptBody = {"role": "user", "prompt": prompt, "context": context}
        req = requests.post(API_URL + "/answer", json = promptBody)
        if req.status_code == 200:
            st.session_state.messages.append(promptBody)
            with st.chat_message("user", avatar=USER_AVATAR):
                st.markdown(prompt)
            with st.chat_message("assistant", avatar=BOT_AVATAR):
                message_placeholder = st.empty()
                res = req.json()
                # The Flask API responds with {"reponse": answer, "score": ...};
                # the misspelled key is part of the API contract.
                response = res['reponse']
                message_placeholder.markdown(response)
                st.session_state.messages.append({"role": "assistant", "prompt": response})
        else:
            st.warning("🤖 Je ne peux pas répondre à la question car il y a eu un problème avec l'API veuillez reposer votre question", icon="⚠️")
    else:
        st.warning("🤖 Je ne peux pas répondre à la question car il n'y a pas contexte ou ce dernier est trop court. Veuillez importer un contexte ou le coller dans son champ", icon="⚠️")


# Persist the (possibly updated) transcript at the end of every rerun.
save_chat_history(st.session_state.messages)
|
config.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Maps UI display names to Hugging Face Hub model repo ids.
AVAILABLE_MODELS = {
    "BERT" : 'cancerfarore/bert-base-uncased-CancerFarore-Model',
    "ALBERT" : "cancerfarore/albert-base-v2-CancerFarore-Model",
    # NOTE(review): the trailing "a" in "...-Modela" looks like a typo
    # (PACIFISTA below uses "...-Model") — verify the repo id on the Hub
    # before changing it.
    "ROBERTA" : "cancerfarore/roberta-base-CancerFarore-Modela",
    "PACIFISTA" : "cancerfarore/roberta-base-CancerFarore-Model"
}

# Display names in declaration order. Idiomatic replacement for the original
# manual for/append loop (dicts preserve insertion order).
MODEL_NAMES = list(AVAILABLE_MODELS)

# UI title shown by the Streamlit front end.
TITRE = "Pacifista 🐻"

# Local Flask inference API and chat avatars.
API_URL = 'http://127.0.0.1:5000'
USER_AVATAR = "👤"
BOT_AVATAR = "🤖"
|