AmadouDiaV commited on
Commit
a1ff351
·
1 Parent(s): b4d2fbd

Premier commit

Browse files
Files changed (4) hide show
  1. Requirements.txt +9 -0
  2. app.py +26 -0
  3. app_streamlit.py +112 -0
  4. config.py +16 -0
Requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ transformers
2
+ datasets
3
+ huggingface_hub
4
+ flask==2.2.5
5
+ requests
6
+ tensorflow
7
+ streamlit
8
+ tf-keras
9
+ pymupdf
app.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from flask import Flask, request
2
+ from transformers import pipeline
3
+ app = Flask(__name__)
4
+ question_answerer = None
5
+
6
+ @app.before_first_request
7
+ def load_pipeline():
8
+ global question_answerer
9
+ question_answerer = pipeline("question-answering", "cancerfarore/bert-base-uncased-CancerFarore-Model", framework="tf")
10
+
11
+ @app.route("/answer", methods=["POST"])
12
+ def answer():
13
+ global question_answerer
14
+ obj = request.get_json()
15
+ context = obj['context']
16
+ question = obj['prompt']
17
+ return {"reponse" : question_answerer(context=context, question=question)['answer'], "score" : question_answerer(context=context, question=question)['score']}
18
+
19
+ @app.route("/load_model", methods=["POST"])
20
+ def load_model():
21
+ global question_answerer
22
+ obj = request.get_json()
23
+ model_name = obj['model']
24
+ question_answerer = pipeline("question-answering", model_name, framework="tf")
25
+ return f"Model {model_name}", 200
26
+
app_streamlit.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import fitz
3
+ import requests
4
+ import shelve
5
+ from config import *
6
+
7
+
8
+ def load_chat_history():
9
+ with shelve.open("Message_history/ARDB") as db:
10
+ return db.get("messages", [])
11
+
12
+ def load_current_model():
13
+ with shelve.open("Message_history/ARDB") as db:
14
+ return db.get("current_model", "BERT")
15
+
16
+ def save_current_model(model_name):
17
+ with shelve.open("Message_history/ARDB") as db:
18
+ db["current_model"] = model_name
19
+
20
+
21
+ def save_chat_history(messages):
22
+ with shelve.open("Message_history/ARDB") as db:
23
+ db["messages"] = messages
24
+
25
+
26
+ if "messages" not in st.session_state:
27
+ st.session_state.messages = load_chat_history()
28
+
29
+ if "current_model" not in st.session_state:
30
+ st.session_state.current_model = load_current_model()
31
+ if "current_context" not in st.session_state:
32
+ st.session_state.current_context = ""
33
+
34
+ st.title(TITRE + " - " + st.session_state.current_model)
35
+
36
+ with st.sidebar:
37
+ st.title(body="Veuillez choisir un modèle")
38
+ option = st.selectbox(
39
+ label='Veuillez choisir un modèle',
40
+ index=MODEL_NAMES.index(st.session_state.current_model),
41
+ options=MODEL_NAMES)
42
+
43
+ st.write('Vous avez selectionné :', option)
44
+
45
+ if st.button("Changer le modèle"):
46
+ if option == st.session_state.current_model:
47
+ st.warning("🤖 Ce modèle est déjà chargé", icon="⚠️")
48
+ else:
49
+ req = requests.post(API_URL + "/load_model", json={'model': AVAILABLE_MODELS[option]})
50
+ if req.status_code == 200:
51
+ save_current_model(option)
52
+ st.session_state.current_model = load_current_model()
53
+ st.session_state.messages = []
54
+ save_chat_history([])
55
+ st.rerun()
56
+ else:
57
+ st.warning("🤖 Je ne peux pas charger le modèle veuillez ressayer", icon="⚠️")
58
+
59
+ st.title(body="Veuillez donner un contexte")
60
+ upload_file = st.file_uploader("Importer un fichier texte ou pdf pour charger le contexte", type=['txt','pdf'])
61
+ if st.button("Changer le contexte"):
62
+ if upload_file:
63
+ if upload_file.type == 'text/plain':
64
+ st.session_state.current_context = upload_file.getvalue().decode("utf-8")
65
+ elif upload_file.type == 'application/pdf':
66
+ document = fitz.open(stream=upload_file.read())
67
+ text = ''
68
+ for page in document:
69
+ page_text = page.get_text()
70
+ page_text = page_text.replace('\n', ' ')
71
+ text += page_text
72
+ st.session_state.current_context = text
73
+ st.rerun()
74
+ else:
75
+ st.warning("🤖 Il n'y a aucun contexte à charger", icon="⚠️")
76
+
77
+ context = st.sidebar.text_area(value=st.session_state.current_context, label = "Ou écrire le contexte dans ce champs", placeholder = "Veuillez écrire le contexte de la question ici", height=300)
78
+ if st.button("Effacer le contexte"):
79
+ st.session_state.current_context = ""
80
+ if st.button("Supprimer l'historique"):
81
+ st.session_state.messages = []
82
+ save_chat_history([])
83
+
84
+ # Display chat messages
85
+ for message in st.session_state.messages:
86
+ avatar = USER_AVATAR if message["role"] == "user" else BOT_AVATAR
87
+ with st.chat_message(message["role"], avatar=avatar):
88
+ st.markdown(message["prompt"])
89
+
90
+ prompt = st.chat_input("How can I help?")
91
+ if prompt:
92
+ if len(context) > 20:
93
+ st.session_state.current_context = context
94
+ promptBody = {"role": "user", "prompt": prompt, "context": context}
95
+ req = requests.post(API_URL + "/answer", json = promptBody)
96
+ if req.status_code == 200:
97
+ st.session_state.messages.append(promptBody)
98
+ with st.chat_message("user", avatar=USER_AVATAR):
99
+ st.markdown(prompt)
100
+ with st.chat_message("assistant", avatar=BOT_AVATAR):
101
+ message_placeholder = st.empty()
102
+ res = req.json()
103
+ response = res['reponse']
104
+ message_placeholder.markdown(response)
105
+ st.session_state.messages.append({"role": "assistant", "prompt": response})
106
+ else:
107
+ st.warning("🤖 Je ne peux pas répondre à la question car il y a eu un problème avec l'API veuillez reposer votre question", icon="⚠️")
108
+ else:
109
+ st.warning("🤖 Je ne peux pas répondre à la question car il n'y a pas contexte ou ce dernier est trop court. Veuillez importer un contexte ou le coller dans son champ", icon="⚠️")
110
+
111
+
112
+ save_chat_history(st.session_state.messages)
config.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ AVAILABLE_MODELS = {
2
+ "BERT" : 'cancerfarore/bert-base-uncased-CancerFarore-Model',
3
+ "ALBERT" : "cancerfarore/albert-base-v2-CancerFarore-Model",
4
+ "ROBERTA" : "cancerfarore/roberta-base-CancerFarore-Modela",
5
+ "PACIFISTA" : "cancerfarore/roberta-base-CancerFarore-Model"
6
+ }
7
+
8
+ MODEL_NAMES = []
9
+ for NAME in AVAILABLE_MODELS:
10
+ MODEL_NAMES.append(NAME)
11
+
12
+ TITRE = "Pacifista 🐻"
13
+
14
+ API_URL = 'http://127.0.0.1:5000'
15
+ USER_AVATAR = "👤"
16
+ BOT_AVATAR = "🤖"