import os
import uuid
from typing import List, Dict, Optional

import pandas as pd
from autorag.deploy import GradioRunner
from autorag.deploy.api import RetrievedPassage
from autorag.nodes.generator.base import BaseGenerator
from autorag.utils import fetch_contents

# Placeholder passage yielded alongside streamed text while real passage
# extraction is disabled in stream_run below.
empty_retrieved_passage = RetrievedPassage(
	content="", doc_id="", filepath=None, file_page=None, start_idx=None, end_idx=None
)

class GradioStreamRunner(GradioRunner):
	"""GradioRunner variant that streams the generator output instead of returning it at once."""

	def __init__(self, config: Dict, project_dir: Optional[str] = None):
		super().__init__(config, project_dir)

		# Fall back to the current working directory when project_dir is None,
		# so the Optional type hint doesn't lead to os.path.join(None, ...).
		data_dir = os.path.join(project_dir or os.getcwd(), "data")
		self.corpus_df = pd.read_parquet(
			os.path.join(data_dir, "corpus.parquet"), engine="pyarrow"
		)

	def stream_run(self, query: str):
		"""Run the pipeline for a single query, yielding (text_chunk, passages) tuples."""
		previous_result = pd.DataFrame(
			{
				"qid": str(uuid.uuid4()),
				"query": [query],
				"retrieval_gt": [[]],
				"generation_gt": [""],
			}
		)  # pseudo qa data for execution

		# Run every non-generator node eagerly, carrying its output forward;
		# the generator node at the end is streamed instead.
		for module_instance, module_param in zip(
				self.module_instances, self.module_params
		):
			if not isinstance(module_instance, BaseGenerator):
				new_result = module_instance.pure(
					previous_result=previous_result, **module_param
				)
				# Merge the node's output into the running result frame,
				# letting the new columns replace any duplicates.
				duplicated_columns = previous_result.columns.intersection(
					new_result.columns
				)
				drop_previous_result = previous_result.drop(
					columns=duplicated_columns
				)
				previous_result = pd.concat(
					[drop_previous_result, new_result], axis=1
				)
			else:
				# Passage extraction is disabled for now; re-enable the two
				# lines below to surface the retrieved sources in the UI.
				# retrieved_passages = self.extract_retrieve_passage(previous_result)
				# yield "", retrieved_passages

				# Start streaming the generated result.
				assert len(previous_result) == 1
				prompt: str = previous_result["prompts"].tolist()[0]
				for delta in module_instance.stream(prompt=prompt, **module_param):
					yield delta, [empty_retrieved_passage]

	def extract_retrieve_passage(self, df: pd.DataFrame) -> List[RetrievedPassage]:
		"""Resolve the retrieved doc ids against the corpus into RetrievedPassage objects."""
		retrieved_ids: List[str] = df["retrieved_ids"].tolist()[0]
		contents = fetch_contents(self.corpus_df, [retrieved_ids])[0]
		if "path" in self.corpus_df.columns:
			paths = fetch_contents(self.corpus_df, [retrieved_ids], column_name="path")[
				0
			]
		else:
			paths = [None] * len(retrieved_ids)
		metadatas = fetch_contents(
			self.corpus_df, [retrieved_ids], column_name="metadata"
		)[0]
		if "start_end_idx" in self.corpus_df.columns:
			start_end_indices = fetch_contents(
				self.corpus_df, [retrieved_ids], column_name="start_end_idx"
			)[0]
		else:
			start_end_indices = [None] * len(retrieved_ids)
		return list(
			map(
				lambda content, doc_id, path, metadata, start_end_idx: RetrievedPassage(
					content=content,
					doc_id=doc_id,
					filepath=path,
					file_page=metadata.get("page", None) if metadata else None,
					start_idx=start_end_idx[0] if start_end_idx else None,
					end_idx=start_end_idx[1] if start_end_idx else None,
				),
				contents,
				retrieved_ids,
				paths,
				metadatas,
				start_end_indices,
			)
		)
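

if __name__ == "__main__":
	# Minimal usage sketch, not part of the runner itself. It assumes a
	# standard AutoRAG project layout: a pipeline YAML (the "pipeline.yaml"
	# and "./my_project" names here are illustrative) plus
	# <project_dir>/data/corpus.parquet generated by AutoRAG.
	import yaml

	with open("pipeline.yaml") as f:
		config = yaml.safe_load(f)

	runner = GradioStreamRunner(config, project_dir="./my_project")
	# stream_run yields (text_chunk, retrieved_passages) tuples; chunk
	# granularity depends on the underlying generator module.
	for chunk, _passages in runner.stream_run("What does AutoRAG do?"):
		print(chunk, end="", flush=True)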