Truong-Phuc Nguyen commited on
Commit
d0150b0
1 Parent(s): 08604d3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -16
app.py CHANGED
@@ -2,27 +2,26 @@ import streamlit as st
2
  import torch
3
  from transformers import pipeline
4
 
5
- st.set_page_config(page_title="Vietnamese Legal Question Answering", page_icon="🧊", layout="wide", initial_sidebar_state="collapsed")
6
 
7
- @st.cache_data
8
- def load_model(model_path):
9
- device = 0 if torch.cuda.is_available() else -1
10
- question_answerer = pipeline("question-answering", model=model_path, device=device)
11
- return question_answerer
12
-
13
- def get_answer(model, context, question):
14
- return model(context=context, question=question, max_answer_len=512)
15
 
 
16
  if 'model' not in st.session_state:
17
- st.session_state.model = load_model(model_path='./model')
 
 
 
 
18
 
19
- st.markdown("<h1 style='text-align: center;'>Vietnamese Legal Question Answering</h1>", unsafe_allow_html=True)
20
 
21
- context = st.text_area(label='Vietnamese Legal Documents/context:', placeholder='Enter your Vietnamese legal document here...', height=300)
22
- question = st.text_area(label='Question about this Vietnamese Legal Documents:', placeholder='Enter your question about this Vietnamese Legal Documents here...', height=100)
23
 
24
- btn_answer = st.button(label='Answer')
 
25
 
26
  if btn_answer:
27
- answer = get_answer(model=st.session_state.model, context=context, question=question)
28
- st.success(f"{answer['answer']}")
 
 
2
  import torch
3
  from transformers import pipeline
4
 
5
+ st.set_page_config(page_title="Vietnamese Legal Question Answering", page_icon="🧊", layout="centered", initial_sidebar_state="collapsed")
6
 
 
 
 
 
 
 
 
 
7
 
8
+ device = 0 if torch.cuda.is_available() else -1
9
  if 'model' not in st.session_state:
10
+ print('Some errors occured!')
11
+ st.session_state.model = pipeline("question-answering", model='./models/vi-mrc-large/archive/model', device=device)
12
+
13
+ def get_answer(context, question):
14
+ return st.session_state.model(context=context, question=question, max_answer_len=512)
15
 
16
+ st.markdown("<h1 style='text-align: center;'>Hệ thống hỏi đáp trực tuyến cho văn bản pháp luật Việt Nam</h1>", unsafe_allow_html=True)
17
 
18
+ context = st.text_area(label='Nội dung văn bản pháp luật Việt Nam:', placeholder='Vui lòng nhập nội dung văn bản pháp luật Việt Nam tại đây...', height=400)
19
+ question = st.text_area(label='Câu hỏi liên quan đến nội dung văn bản pháp luật Việt Nam ở trên:', placeholder='Vui lòng nhập câu hỏi liên quan đến nội dung văn bản pháp luật Việt Nam ở trên:...', height=100)
20
 
21
+ col_1, col_2, col_3, col_4, col_5 = st.columns([1, 1, 1.3, 1, 1])
22
+ btn_answer = col_3.button(label='Giải đáp thắc mắc', use_container_width=True)
23
 
24
  if btn_answer:
25
+ with st.spinner("Vui lòng chờ..."):
26
+ answer = get_answer(context=context, question=question)
27
+ st.success(f"Câu trả lời: {answer['answer']}")