import os
import uuid
from typing import List, Dict, Optional

import pandas as pd

from autorag.deploy import GradioRunner
from autorag.deploy.api import RetrievedPassage
from autorag.nodes.generator.base import BaseGenerator
from autorag.utils import fetch_contents

# Placeholder passage yielded with each streamed token, since retrieval
# metadata is not attached during streaming.
empty_retrieved_passage = RetrievedPassage(
    content="", doc_id="", filepath=None, file_page=None, start_idx=None, end_idx=None
)


class GradioStreamRunner(GradioRunner):
    def __init__(self, config: Dict, project_dir: Optional[str] = None):
        super().__init__(config, project_dir)
        # Load the corpus so retrieved doc ids can be resolved back to
        # passage contents and metadata.
        data_dir = os.path.join(project_dir, "data")
        self.corpus_df = pd.read_parquet(
            os.path.join(data_dir, "corpus.parquet"), engine="pyarrow"
        )

    def stream_run(self, query: str):
        # Pseudo QA data so the pipeline modules can execute on a single query.
        previous_result = pd.DataFrame(
            {
                "qid": str(uuid.uuid4()),
                "query": [query],
                "retrieval_gt": [[]],
                "generation_gt": [""],
            }
        )

        for module_instance, module_param in zip(
            self.module_instances, self.module_params
        ):
            if not isinstance(module_instance, BaseGenerator):
                # Run every non-generator node to completion and merge its
                # output columns into the accumulated result, dropping any
                # columns the new result overwrites.
                new_result = module_instance.pure(
                    previous_result=previous_result, **module_param
                )
                duplicated_columns = previous_result.columns.intersection(
                    new_result.columns
                )
                drop_previous_result = previous_result.drop(columns=duplicated_columns)
                previous_result = pd.concat([drop_previous_result, new_result], axis=1)
            else:
                # retrieved_passages = self.extract_retrieve_passage(
                #     previous_result
                # )
                # yield "", retrieved_passages
                # Stream the generator output token by token.
                assert len(previous_result) == 1
                prompt: str = previous_result["prompts"].tolist()[0]
                for delta in module_instance.stream(prompt=prompt, **module_param):
                    yield delta, [empty_retrieved_passage]

    def extract_retrieve_passage(self, df: pd.DataFrame) -> List[RetrievedPassage]:
        # Resolve the retrieved doc ids of the (single) result row back to
        # passage contents, file paths, and metadata from the corpus.
        retrieved_ids: List[str] = df["retrieved_ids"].tolist()[0]
        contents = fetch_contents(self.corpus_df, [retrieved_ids])[0]
        if "path" in self.corpus_df.columns:
            paths = fetch_contents(
                self.corpus_df, [retrieved_ids], column_name="path"
            )[0]
        else:
            paths = [None] * len(retrieved_ids)
        metadatas = fetch_contents(
            self.corpus_df, [retrieved_ids], column_name="metadata"
        )[0]
        if "start_end_idx" in self.corpus_df.columns:
            start_end_indices = fetch_contents(
                self.corpus_df, [retrieved_ids], column_name="start_end_idx"
            )[0]
        else:
            start_end_indices = [None] * len(retrieved_ids)
        return list(
            map(
                lambda content, doc_id, path, metadata, start_end_idx: RetrievedPassage(
                    content=content,
                    doc_id=doc_id,
                    filepath=path,
                    file_page=metadata.get("page", None),
                    start_idx=start_end_idx[0] if start_end_idx else None,
                    end_idx=start_end_idx[1] if start_end_idx else None,
                ),
                contents,
                retrieved_ids,
                paths,
                metadatas,
                start_end_indices,
            )
        )
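

# --- Usage sketch (illustrative, not part of the deployed Space) ---
# A minimal example of driving the streaming runner directly, assuming the
# pipeline config has already been resolved into a dict (e.g. loaded from an
# extracted pipeline YAML, matching the __init__ signature above). The file
# paths and the query below are hypothetical placeholders.
if __name__ == "__main__":
    import yaml

    with open("pipeline.yaml") as f:  # hypothetical extracted pipeline config
        config = yaml.safe_load(f)

    runner = GradioStreamRunner(config, project_dir="./autorag_project")
    for delta, passages in runner.stream_run("What is AutoRAG?"):
        # Each yield carries the next streamed chunk plus placeholder passages.
        print(delta, end="", flush=True)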