Update app.py
app.py CHANGED
@@ -15,7 +15,10 @@ def load_pdf(uploaded_file):
 def initialize_rag_model():
     # Load the tokenizer and model
     tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
-
+
+    # Use a dummy retriever for testing purposes
+    retriever = RagRetriever.from_pretrained("facebook/rag-token-nq", use_dummy_dataset=True)
+
     model = RagTokenForGeneration.from_pretrained("facebook/rag-token-nq")
     return tokenizer, retriever, model
 
@@ -23,12 +26,9 @@ def initialize_rag_model():
 def generate_answer(query, context, tokenizer, retriever, model):
     # Tokenize the input question
     inputs = tokenizer(query, return_tensors="pt")
-
-    #
-    context_input_ids = retriever(context, return_tensors="pt")["input_ids"]
-
-    # Prepare inputs for the model
-    inputs["context_input_ids"] = context_input_ids
+
+    # Prepare inputs for the model with a dummy context
+    inputs["context_input_ids"] = retriever(context, return_tensors="pt")["input_ids"]
 
     # Generate the answer
     outputs = model.generate(**inputs)
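For comparison, here is a minimal sketch (not part of this commit) of the usual Transformers RAG pattern, in which the retriever is passed to the model so that generate() performs retrieval internally. The index_name="exact" and use_dummy_dataset=True settings mirror the library's documentation example, and the query string is only a placeholder.

from transformers import RagTokenizer, RagRetriever, RagTokenForGeneration

# Load the tokenizer and a dummy-dataset retriever (documentation-style settings)
tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
retriever = RagRetriever.from_pretrained(
    "facebook/rag-token-nq", index_name="exact", use_dummy_dataset=True
)

# Attach the retriever to the model so generate() retrieves passages itself
model = RagTokenForGeneration.from_pretrained("facebook/rag-token-nq", retriever=retriever)

# Placeholder query: tokenize it and generate an answer
inputs = tokenizer("who holds the record in 100m freestyle", return_tensors="pt")
generated_ids = model.generate(input_ids=inputs["input_ids"])
print(tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0])

Attaching the retriever also avoids building context_input_ids by hand; when a RAG model is not initialized with a retriever, generate() expects context_attention_mask and doc_scores alongside context_input_ids before the retrieved passages can be used.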