tarrasyed19472007 commited on
Commit
42feee4
·
verified ·
1 Parent(s): 9ed2bab

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +53 -39
app.py CHANGED
@@ -1,52 +1,66 @@
1
  import streamlit as st
2
- import PyPDF2
3
- import numpy as np
4
  from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration
 
5
  import faiss
 
 
 
 
 
 
6
 
7
- # Function to load the PDF file and extract text
8
- def load_pdf(file):
9
- reader = PyPDF2.PdfReader(file)
10
  text = ""
11
- for page in reader.pages:
12
- text += page.extract_text()
13
  return text
14
 
15
- # Initialize the RAG model
16
- def initialize_rag_model():
17
- # Load the tokenizer
18
- tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
19
- # Load the retriever
20
- retriever = RagRetriever.from_pretrained("facebook/rag-token-nq", index_name="legacy", use_dummy_dataset=True)
21
- # Load the model
22
- model = RagSequenceForGeneration.from_pretrained("facebook/rag-token-nq")
23
- return tokenizer, retriever, model
24
-
25
- # Function to generate answers
26
- def generate_answer(question, text, tokenizer, retriever, model):
27
- inputs = tokenizer([question], return_tensors="pt")
28
- # Here we are creating a dummy dataset from the PDF text
29
- input_ids = tokenizer(text, return_tensors="pt").input_ids
30
- # Get the retrieved documents
31
- doc_scores, retrieved_doc_indices = retriever(input_ids, return_tensors="pt").values()
32
- # Generate answer
33
- outputs = model.generate(inputs["input_ids"], doc_scores=doc_scores)
34
- answer = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
 
35
  return answer
36
 
37
- # Streamlit app layout
38
- st.title("PDF Question-Answering Chatbot")
39
- uploaded_file = st.file_uploader("Upload a PDF file", type="pdf")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
- if uploaded_file is not None:
42
- text = load_pdf(uploaded_file)
43
- st.text_area("Extracted Text", text, height=300)
44
 
45
- question = st.text_input("Ask a question about the content:")
46
- if st.button("Get Answer"):
47
- if question:
48
- try:
49
- tokenizer, retriever, model = initialize_rag_model()
50
- answer = ge
51
 
52
 
 
1
  import streamlit as st
2
+ import fitz # PyMuPDF
 
3
  from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration
4
+ import numpy as np
5
  import faiss
6
+ import torch
7
+
8
+ # Load the RAG model components
9
+ tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
10
+ retriever = RagRetriever.from_pretrained("facebook/rag-sequence-nq")
11
+ model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq")
12
 
13
+ # Function to extract text from PDF
14
+ def extract_text_from_pdf(pdf_file):
15
+ doc = fitz.open(pdf_file)
16
  text = ""
17
+ for page in doc:
18
+ text += page.get_text()
19
  return text
20
 
21
+ # Function to handle question answering
22
+ def answer_question(question, pdf_text):
23
+ # Tokenize the question
24
+ inputs = tokenizer(question, return_tensors="pt")
25
+
26
+ # Retrieve documents based on the PDF text
27
+ doc_embeds = retriever.get_document_embeddings(pdf_text)
28
+ retriever.set_retriever_doc_embeddings(doc_embeds)
29
+
30
+ # Get the top k documents for the question
31
+ k = 5
32
+ retrieved_docs = retriever(question, n_docs=k)
33
+
34
+ # Prepare the context for the model
35
+ context = retrieved_docs["document_texts"]
36
+ context = " ".join(context)
37
+
38
+ # Generate the answer
39
+ input_dict = tokenizer.prepare_seq2seq_batch(question, context, return_tensors="pt")
40
+ outputs = model.generate(**input_dict)
41
+ answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
42
  return answer
43
 
44
+ # Streamlit app
45
+ st.title("PDF Question-Answer Chatbot")
46
+ st.write("Upload a PDF file and ask questions based on its content.")
47
+
48
+ # File uploader
49
+ pdf_file = st.file_uploader("Upload PDF", type=["pdf"])
50
+ if pdf_file is not None:
51
+ # Extract text from the PDF
52
+ pdf_text = extract_text_from_pdf(pdf_file)
53
+ st.success("PDF loaded successfully!")
54
+
55
+ # Question input
56
+ question = st.text_input("Ask a question:")
57
+
58
+ if question:
59
+ with st.spinner("Finding answer..."):
60
+ answer = answer_question(question, pdf_text)
61
+ st.write("### Answer:")
62
+ st.write(answer)
63
 
 
 
 
64
 
 
 
 
 
 
 
65
 
66