tarrasyed19472007 commited on
Commit
4751c0e
·
verified ·
1 Parent(s): be2bd3a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -32
app.py CHANGED
@@ -1,53 +1,63 @@
1
  import streamlit as st
2
  import PyPDF2
3
- from transformers import RagTokenizer, RagRetriever, RagTokenForGeneration
 
 
 
4
 
5
- # Load PDF and extract text
6
  def load_pdf(uploaded_file):
7
  reader = PyPDF2.PdfReader(uploaded_file)
8
- text = ""
9
  for page in reader.pages:
10
- if page.extract_text(): # Ensure text extraction is valid
11
- text += page.extract_text() + "\n"
12
  return text
13
 
14
  # Initialize RAG model
15
  def initialize_rag_model():
16
- # Load the tokenizer and model
17
  tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
18
-
19
- # Use a dummy retriever for testing purposes
20
  retriever = RagRetriever.from_pretrained("facebook/rag-token-nq", use_dummy_dataset=True)
21
-
22
- model = RagTokenForGeneration.from_pretrained("facebook/rag-token-nq")
23
  return tokenizer, retriever, model
24
 
25
- # Process user query
26
- def generate_answer(query, context, tokenizer, retriever, model):
27
- # Tokenize the input question
28
- inputs = tokenizer(query, return_tensors="pt")
 
 
29
 
30
- # Prepare inputs for the model with a dummy context
31
- inputs["context_input_ids"] = retriever(context, return_tensors="pt")["input_ids"]
32
-
33
- # Generate the answer
34
- outputs = model.generate(**inputs)
35
- answer = tokenizer.batch_decode(outputs, skip_special_tokens=True)
36
- return answer[0]
37
 
38
- # Streamlit UI
39
- st.title("PDF Question-Answer Chatbot")
 
40
 
41
- uploaded_file = st.file_uploader("Upload a PDF file", type=["pdf"])
42
- if uploaded_file is not None:
 
 
43
  text = load_pdf(uploaded_file)
44
- st.write("PDF loaded successfully. You can now ask questions.")
45
-
46
- # Initialize the RAG model
47
  tokenizer, retriever, model = initialize_rag_model()
 
 
 
 
 
 
 
 
 
 
 
 
48
 
49
- user_query = st.text_input("Ask a question about the PDF:")
50
- if user_query:
51
- answer = generate_answer(user_query, text, tokenizer, retriever, model)
52
- st.write(f"Answer: {answer}") # Display the answer
53
 
 
1
  import streamlit as st
2
  import PyPDF2
3
+ import os
4
+ from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration
5
+ import faiss
6
+ import torch
7
 
8
+ # Function to load PDF and extract text
9
  def load_pdf(uploaded_file):
10
  reader = PyPDF2.PdfReader(uploaded_file)
11
+ text = ''
12
  for page in reader.pages:
13
+ text += page.extract_text()
 
14
  return text
15
 
16
  # Initialize RAG model
17
  def initialize_rag_model():
18
+ # Load tokenizer, retriever, and model
19
  tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
 
 
20
  retriever = RagRetriever.from_pretrained("facebook/rag-token-nq", use_dummy_dataset=True)
21
+ model = RagSequenceForGeneration.from_pretrained("facebook/rag-token-nq")
 
22
  return tokenizer, retriever, model
23
 
24
+ # Function to answer questions
25
+ def answer_question(question, context, tokenizer, model):
26
+ input_ids = tokenizer.encode(question, return_tensors='pt')
27
+ context_ids = tokenizer.encode(context, return_tensors='pt')
28
+ input_ids = input_ids.to(model.device)
29
+ context_ids = context_ids.to(model.device)
30
 
31
+ # Generate answer
32
+ with torch.no_grad():
33
+ outputs = model(input_ids=input_ids, context_input_ids=context_ids)
34
+ answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
35
+ return answer
 
 
36
 
37
+ # Main Streamlit application
38
+ st.title("PDF Q&A Chatbot")
39
+ st.write("Upload a PDF file and ask questions about its content.")
40
 
41
+ # Upload PDF file
42
+ uploaded_file = st.file_uploader("Choose a PDF file", type="pdf")
43
+
44
+ if uploaded_file:
45
  text = load_pdf(uploaded_file)
46
+ st.write("PDF content loaded successfully.")
47
+
48
+ # Initialize model
49
  tokenizer, retriever, model = initialize_rag_model()
50
+
51
+ # Get user question
52
+ question = st.text_input("Enter your question:")
53
+
54
+ if st.button("Get Answer"):
55
+ if text:
56
+ # Call the answer_question function
57
+ answer = answer_question(question, text, tokenizer, model)
58
+ st.write("Answer:", answer)
59
+ else:
60
+ st.error("No text found in the PDF.")
61
+
62
 
 
 
 
 
63