Reyad-Ahmmed committed on
Commit
681469c
·
verified ·
1 Parent(s): 700c521

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -0
app.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""Answer one Natural Questions sample with a pretrained RAG model.

The script loads a single open-domain question, lets the RAG retriever fetch
supporting passages from its (dummy) wiki_dpr index, generates an answer with
beam search, and prints the question/answer pair.
"""

from datasets import load_dataset
from transformers import RagRetriever, RagSequenceForGeneration, RagTokenizer

# One checkpoint name for tokenizer, retriever and model so they cannot drift
# apart. The "sequence" checkpoint is the one fine-tuned for
# RagSequenceForGeneration (the "token" checkpoint pairs with
# RagTokenForGeneration).
MODEL_NAME = "facebook/rag-sequence-nq"

# Load a single question to ask the model. wiki_dpr only contains passages
# (id / title / text / embeddings) and has no question field, and the
# retriever below already loads its own copy of the index, so the questions
# come from the Natural Questions open-domain set instead.
dataset = load_dataset("nq_open", split="validation[:1]")

# RAG tokenizer: wraps the DPR question-encoder tokenizer and the BART
# generator tokenizer.
tokenizer = RagTokenizer.from_pretrained(MODEL_NAME)

# Retriever backed by the wiki_dpr index. use_dummy_dataset=True loads a small
# sample of the index instead of the full corpus, which keeps the download
# small but limits answer quality.
retriever = RagRetriever.from_pretrained(
    MODEL_NAME,
    index_name="compressed",
    use_dummy_dataset=True,
)

# The retriever must be attached to the model: generate() uses it to fetch the
# context documents for the question. Without it there is nothing to
# condition the generator on.
model = RagSequenceForGeneration.from_pretrained(MODEL_NAME, retriever=retriever)

# Pick the question to answer.
sample = dataset[0]
input_text = sample["question"]

# Tokenize the input question.
inputs = tokenizer(input_text, return_tensors="pt")

# Generate the answer with deterministic beam search. No
# decoder_start_token_id override is passed: the value stored with the
# checkpoint is used, as in the documented RAG examples.
outputs = model.generate(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    num_beams=3,
    num_return_sequences=1,
    do_sample=False,
)

# Decode the best generated sequence back to text.
generated_answer = tokenizer.decode(outputs[0], skip_special_tokens=True)

print(f"Question: {input_text}")
print(f"Answer: {generated_answer}")