tarrasyed19472007 commited on
Commit
08d9dca
·
verified ·
1 Parent(s): 88f1ffd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -7
app.py CHANGED
@@ -15,7 +15,10 @@ def load_pdf(uploaded_file):
15
  def initialize_rag_model():
16
  # Load the tokenizer and model
17
  tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
18
- retriever = RagRetriever.from_pretrained("facebook/rag-token-nq", index_name="legacy", use_dummy_dataset=True)
 
 
 
19
  model = RagTokenForGeneration.from_pretrained("facebook/rag-token-nq")
20
  return tokenizer, retriever, model
21
 
@@ -23,12 +26,9 @@ def initialize_rag_model():
23
  def generate_answer(query, context, tokenizer, retriever, model):
24
  # Tokenize the input question
25
  inputs = tokenizer(query, return_tensors="pt")
26
-
27
- # Generate context embeddings using retriever
28
- context_input_ids = retriever(context, return_tensors="pt")["input_ids"]
29
-
30
- # Prepare inputs for the model
31
- inputs["context_input_ids"] = context_input_ids
32
 
33
  # Generate the answer
34
  outputs = model.generate(**inputs)
 
15
  def initialize_rag_model():
16
  # Load the tokenizer and model
17
  tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
18
+
19
+ # Use a dummy retriever for testing purposes
20
+ retriever = RagRetriever.from_pretrained("facebook/rag-token-nq", use_dummy_dataset=True)
21
+
22
  model = RagTokenForGeneration.from_pretrained("facebook/rag-token-nq")
23
  return tokenizer, retriever, model
24
 
 
26
  def generate_answer(query, context, tokenizer, retriever, model):
27
  # Tokenize the input question
28
  inputs = tokenizer(query, return_tensors="pt")
29
+
30
+ # Prepare inputs for the model with a dummy context
31
+ inputs["context_input_ids"] = retriever(context, return_tensors="pt")["input_ids"]
 
 
 
32
 
33
  # Generate the answer
34
  outputs = model.generate(**inputs)