Ramon Meffert
Tweak reader parameters
112c207
raw
history blame
1.04 kB
from transformers import DPRReader, DPRReaderTokenizer
from typing import List, Dict, Tuple
class DprReader():
    """Extractive question-answering reader built on a pretrained DPR reader.

    Wraps a ``DPRReaderTokenizer`` / ``DPRReader`` pair loaded from the same
    checkpoint and exposes a single :meth:`read` call that turns a question
    plus retrieved passages into ranked answer spans.
    """

    def __init__(
        self,
        model_name: str = "facebook/dpr-reader-single-nq-base",
    ) -> None:
        """Load the tokenizer and the reader model.

        Args:
            model_name: Checkpoint identifier passed to ``from_pretrained``
                for both the tokenizer and the model, so the two always
                come from the same checkpoint.
        """
        self._tokenizer = DPRReaderTokenizer.from_pretrained(model_name)
        self._model = DPRReader.from_pretrained(model_name)

    def read(self,
             query: str,
             context: Dict[str, List[str]],
             num_answers=5,
             num_spans_per_passage=2,
             max_answer_length=512) -> List[Tuple]:
        """Extract the best answer spans for ``query`` from ``context``.

        Args:
            query: The question to answer.
            context: Mapping with the keys ``'titles'`` and ``'texts'``,
                holding the passage titles and passage bodies.
            num_answers: Total number of spans to request from the decoder.
            num_spans_per_passage: Maximum number of spans requested from
                any single passage.
            max_answer_length: Maximum length of a span accepted by the
                decoder.

        Returns:
            The spans produced by the tokenizer's ``decode_best_spans``;
            an empty list when ``context`` contains no passages.

        Raises:
            KeyError: If ``context`` lacks ``'titles'`` or ``'texts'``.
        """
        titles = context['titles']
        texts = context['texts']

        # Nothing was retrieved: there is nothing to read, so skip the
        # tokenizer and the model instead of feeding them empty batches.
        if not titles or not texts:
            return []

        # Imported here so the block does not rely on a module-level import;
        # torch is already required by return_tensors='pt' below.
        import torch

        encoded_inputs = self._tokenizer(
            questions=query,
            titles=titles,
            texts=texts,
            return_tensors='pt',
            truncation=True,
            padding=True
        )

        # Pure inference: no autograd graph is needed, which saves memory
        # and time on every call.
        with torch.no_grad():
            outputs = self._model(**encoded_inputs)

        return self._tokenizer.decode_best_spans(
            encoded_inputs,
            outputs,
            num_spans=num_answers,
            num_spans_per_passage=num_spans_per_passage,
            max_answer_length=max_answer_length)