hynky HF staff committed on
Commit
276d919
·
1 Parent(s): 638184c

⚡️ make it faster

Browse files
src/logic/data_fetching.py CHANGED
@@ -6,7 +6,7 @@ import tempfile
6
  from pathlib import Path
7
  from concurrent.futures import ThreadPoolExecutor
8
  from typing import List, Dict
9
- from datatrove.io import get_datafolder
10
  from datatrove.utils.stats import MetricStatsDict
11
  import gradio as gr
12
  import tenacity
@@ -17,11 +17,17 @@ def find_folders(base_folder: str, path: str) -> List[str]:
17
  base_folder_df = get_datafolder(base_folder)
18
  if not base_folder_df.exists(path):
19
  return []
20
- return [
 
 
 
 
 
 
21
  folder
22
- for folder,info in base_folder_df.find(path, maxdepth=1, withdirs=True, detail=True).items()
23
  if info["type"] == "directory" and not (folder.rstrip("/") == path.rstrip("/"))
24
- ]
25
 
26
  def fetch_datasets(base_folder: str, progress=gr.Progress()):
27
  datasets = sorted(progress.tqdm(find_folders(base_folder, "")))
@@ -111,7 +117,7 @@ def fetch_graph_data(
111
  progress=gr.Progress(),
112
  ):
113
  if len(datasets) <= 0 or not metric_name or not grouping:
114
- return None
115
 
116
  with ThreadPoolExecutor() as pool:
117
  data = list(
 
6
  from pathlib import Path
7
  from concurrent.futures import ThreadPoolExecutor
8
  from typing import List, Dict
9
+ from datatrove.io import get_datafolder, _get_true_fs
10
  from datatrove.utils.stats import MetricStatsDict
11
  import gradio as gr
12
  import tenacity
 
17
  base_folder_df = get_datafolder(base_folder)
18
  if not base_folder_df.exists(path):
19
  return []
20
+
21
+ from huggingface_hub import HfFileSystem
22
+ extra_options = {}
23
+ if isinstance(_get_true_fs(base_folder_df.fs), HfFileSystem):
24
+ extra_options["expand_info"] = False # speed up
25
+
26
+ return (
27
  folder
28
+ for folder,info in base_folder_df.find(path, maxdepth=1, withdirs=True, detail=True, **extra_options).items()
29
  if info["type"] == "directory" and not (folder.rstrip("/") == path.rstrip("/"))
30
+ )
31
 
32
  def fetch_datasets(base_folder: str, progress=gr.Progress()):
33
  datasets = sorted(progress.tqdm(find_folders(base_folder, "")))
 
117
  progress=gr.Progress(),
118
  ):
119
  if len(datasets) <= 0 or not metric_name or not grouping:
120
+ return None, None
121
 
122
  with ThreadPoolExecutor() as pool:
123
  data = list(
src/logic/data_processing.py CHANGED
@@ -1,4 +1,5 @@
1
  from datetime import datetime
 
2
  import json
3
  import re
4
  import heapq
@@ -13,30 +14,43 @@ from src.logic.graph_settings import Grouping
13
  PARTITION_OPTIONS = Literal["Top", "Bottom", "Most frequent (n_docs)"]
14
 
15
  def prepare_for_non_grouped_plotting(metric: Dict[str, MetricStatsDict], normalization: bool, rounding: int) -> Dict[float, float]:
16
- metrics_rounded = defaultdict(lambda: 0)
17
- for key, value in metric.items():
18
- metrics_rounded[round(float(key), rounding)] += value.total
 
 
 
 
 
19
  if normalization:
20
- normalizer = sum(metrics_rounded.values())
21
- metrics_rounded = {k: v / normalizer for k, v in metrics_rounded.items()}
22
- assert abs(sum(metrics_rounded.values()) - 1) < 0.01
23
- return metrics_rounded
24
 
25
  def prepare_for_group_plotting(metric: Dict[str, MetricStatsDict], top_k: int, direction: PARTITION_OPTIONS, regex: str | None, rounding: int) -> Tuple[List[str], List[float], List[float]]:
26
  regex_compiled = re.compile(regex) if regex else None
27
- metric = {key: value for key, value in metric.items() if not regex or regex_compiled.match(key)}
28
- means = {key: round(float(value.mean), rounding) for key, value in metric.items()}
 
 
 
 
 
 
29
  if direction == "Top":
30
- keys = heapq.nlargest(top_k, means, key=means.get)
31
  elif direction == "Most frequent (n_docs)":
32
- totals = {key: int(value.n) for key, value in metric.items()}
33
- keys = heapq.nlargest(top_k, totals, key=totals.get)
34
  else:
35
- keys = heapq.nsmallest(top_k, means, key=means.get)
36
-
37
- means = [means[key] for key in keys]
38
- stds = [metric[key].standard_deviation for key in keys]
39
- return keys, means, stds
 
 
40
 
41
  def export_data(exported_data: Dict[str, MetricStatsDict], metric_name: str, grouping: Grouping):
42
  if not exported_data:
 
1
  from datetime import datetime
2
+ import numpy as np
3
  import json
4
  import re
5
  import heapq
 
14
  PARTITION_OPTIONS = Literal["Top", "Bottom", "Most frequent (n_docs)"]
15
 
def prepare_for_non_grouped_plotting(metric: Dict[str, "MetricStatsDict"], normalization: bool, rounding: int) -> Dict[float, float]:
    """Bucket metric totals by their rounded numeric key.

    Each key of ``metric`` is parsed as a float and rounded to ``rounding``
    decimals; the ``total`` of every entry that lands on the same rounded key
    is summed.

    Args:
        metric: Mapping from a numeric key (as a string) to a stats object
            exposing a ``total`` attribute.
        normalization: When true, divide every bucket by the sum of all
            buckets so the values add up to 1.
        rounding: Number of decimals the keys are rounded to.

    Returns:
        A dict mapping each distinct rounded key (ascending) to its summed,
        optionally normalized, total. Keys and values are native floats.
        An empty ``metric`` yields an empty dict.
    """
    keys = np.array([float(key) for key in metric], dtype=float)
    values = np.array([value.total for value in metric.values()], dtype=float)

    rounded_keys = np.round(keys, rounding)
    # `indices[i]` is the bucket (position in `unique_keys`) of entry i.
    unique_keys, indices = np.unique(rounded_keys, return_inverse=True)
    metrics_rounded = np.zeros_like(unique_keys, dtype=float)
    # Unbuffered add: repeated indices accumulate instead of overwriting.
    np.add.at(metrics_rounded, indices, values)

    if normalization:
        normalizer = metrics_rounded.sum()
        # Skip the division when everything sums to zero; 0/0 would turn
        # every bucket into NaN and emit a RuntimeWarning.
        if normalizer != 0:
            metrics_rounded = metrics_rounded / normalizer

    # `.tolist()` converts NumPy scalars to native floats, matching the
    # declared Dict[float, float] return type.
    return dict(zip(unique_keys.tolist(), metrics_rounded.tolist()))
30
 
31
  def prepare_for_group_plotting(metric: Dict[str, MetricStatsDict], top_k: int, direction: PARTITION_OPTIONS, regex: str | None, rounding: int) -> Tuple[List[str], List[float], List[float]]:
32
  regex_compiled = re.compile(regex) if regex else None
33
+ filtered_metric = {key: value for key, value in metric.items() if not regex or regex_compiled.match(key)}
34
+
35
+ keys = np.array(list(filtered_metric.keys()))
36
+ means = np.array([float(value.mean) for value in filtered_metric.values()])
37
+ stds = np.array([value.standard_deviation for value in filtered_metric.values()])
38
+
39
+ rounded_means = np.round(means, rounding)
40
+
41
  if direction == "Top":
42
+ top_indices = np.argsort(rounded_means)[-top_k:][::-1]
43
  elif direction == "Most frequent (n_docs)":
44
+ totals = np.array([int(value.n) for value in filtered_metric.values()])
45
+ top_indices = np.argsort(totals)[-top_k:][::-1]
46
  else:
47
+ top_indices = np.argsort(rounded_means)[:top_k]
48
+
49
+ top_keys = keys[top_indices]
50
+ top_means = rounded_means[top_indices]
51
+ top_stds = stds[top_indices]
52
+
53
+ return top_keys.tolist(), top_means.tolist(), top_stds.tolist()
54
 
55
  def export_data(exported_data: Dict[str, MetricStatsDict], metric_name: str, grouping: Grouping):
56
  if not exported_data:
src/logic/plotting.py CHANGED
@@ -11,7 +11,7 @@ from src.logic.utils import set_alpha
11
  from datatrove.utils.stats import MetricStatsDict
12
 
13
  def plot_scatter(
14
- data: Dict[str, Dict[float, float]],
15
  metric_name: str,
16
  log_scale_x: bool,
17
  log_scale_y: bool,
 
11
  from datatrove.utils.stats import MetricStatsDict
12
 
13
  def plot_scatter(
14
+ data: Dict[str, MetricStatsDict],
15
  metric_name: str,
16
  log_scale_x: bool,
17
  log_scale_y: bool,