tarrasyed19472007 commited on
Commit
cd7c688
·
verified ·
1 Parent(s): a0a804f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -2
app.py CHANGED
@@ -1,10 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  # Process user query
2
  def generate_answer(query, context, tokenizer, retriever, model):
3
  # Tokenize the input question
4
  inputs = tokenizer(query, return_tensors="pt")
5
- # Retrieve the relevant documents
 
6
  context_input_ids = retriever(context, return_tensors="pt")["input_ids"]
 
 
7
  inputs["context_input_ids"] = context_input_ids
 
 
8
  outputs = model.generate(**inputs)
9
  answer = tokenizer.batch_decode(outputs, skip_special_tokens=True)
10
  return answer[0]
@@ -23,4 +49,5 @@ if uploaded_file is not None:
23
  user_query = st.text_input("Ask a question about the PDF:")
24
  if user_query:
25
  answer = generate_answer(user_query, text, tokenizer, retriever, model)
26
- st.write(f"Answer: {answer}") # Corrected line with closing parenthesis
 
 
1
+ import streamlit as st
2
+ import PyPDF2
3
+ from transformers import RagTokenizer, RagRetriever, RagTokenForGeneration
4
+
5
+ # Load PDF and extract text
6
+ def load_pdf(uploaded_file):
7
+ reader = PyPDF2.PdfReader(uploaded_file)
8
+ text = ""
9
+ for page in reader.pages:
10
+ if page.extract_text(): # Ensure text extraction is valid
11
+ text += page.extract_text() + "\n"
12
+ return text
13
+
14
+ # Initialize RAG model
15
+ def initialize_rag_model():
16
+ # Load the tokenizer and model
17
+ tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
18
+ retriever = RagRetriever.from_pretrained("facebook/rag-token-nq", index_name="legacy", use_dummy_dataset=True)
19
+ model = RagTokenForGeneration.from_pretrained("facebook/rag-token-nq")
20
+ return tokenizer, retriever, model
21
+
22
  # Process user query
23
  def generate_answer(query, context, tokenizer, retriever, model):
24
  # Tokenize the input question
25
  inputs = tokenizer(query, return_tensors="pt")
26
+
27
+ # Generate context embeddings using retriever
28
  context_input_ids = retriever(context, return_tensors="pt")["input_ids"]
29
+
30
+ # Prepare inputs for the model
31
  inputs["context_input_ids"] = context_input_ids
32
+
33
+ # Generate the answer
34
  outputs = model.generate(**inputs)
35
  answer = tokenizer.batch_decode(outputs, skip_special_tokens=True)
36
  return answer[0]
 
49
  user_query = st.text_input("Ask a question about the PDF:")
50
  if user_query:
51
  answer = generate_answer(user_query, text, tokenizer, retriever, model)
52
+ st.write(f"Answer: {answer}") # Display the answer
53
+