import ast
import glob
import time
from functools import partial
from itertools import islice
from textwrap import dedent
from threading import Thread
from typing import Optional, Type

import gradio as gr
import nltk
import pandas as pd
from datatrove.data import Document
from datatrove.executor.local import LocalPipelineExecutor
from datatrove.pipeline.extractors import Trafilatura
from datatrove.pipeline.filters import (
    C4QualityFilter,
    FineWebQualityFilter,
    GopherQualityFilter,
    GopherRepetitionFilter,
    LanguageFilter,
    URLFilter,
)
from datatrove.pipeline.filters.base_filter import BaseFilter
from datatrove.pipeline.formatters import PIIFormatter
from datatrove.pipeline.readers import JsonlReader, WarcReader
from datatrove.utils.typeshelper import Languages

# the Gopher and C4 filters rely on NLTK's punkt models for word/sentence tokenization
nltk.download("punkt_tab")

DUMP_TO_PROCESS = "CC-MAIN-2023-50"
TIMEOUT = 600  # seconds before a running pipeline is interrupted

# language choices shared by all the "language" dropdowns below
LANGUAGE_CHOICES = sorted(v for k, v in vars(Languages).items() if not k.startswith("__"))
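
# Pipeline steps in execution order. Classes (not instances) are stored here so
# the UI can toggle each step on or off and instantiate it later with the
# user-chosen parameters.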
steps = [
    URLFilter,
    Trafilatura,
    LanguageFilter,
    GopherRepetitionFilter,
    GopherQualityFilter,
    C4QualityFilter,
    FineWebQualityFilter,
    PIIFormatter,
]

DEFAULT_CODE = dedent(
    """
    ```python
    from datatrove.executor.local import LocalPipelineExecutor
    from datatrove.pipeline.extractors import Trafilatura
    from datatrove.pipeline.filters import (
        C4QualityFilter,
        FineWebQualityFilter,
        GopherQualityFilter,
        GopherRepetitionFilter,
        LanguageFilter,
        URLFilter,
    )
    from datatrove.pipeline.formatters import PIIFormatter
    from datatrove.pipeline.readers import WarcReader
    """
).strip() + (
    "\n\n"
    "pipeline_executor = LocalPipelineExecutor(\n"
    "    pipeline=[\n"
    f'        WarcReader("s3://commoncrawl/crawl-data/{DUMP_TO_PROCESS}/segments", glob_pattern="*/warc/*"),\n'
) + ",\n".join([
    "        " + step.__name__ + "()" for step in steps
]) + (
    "\n"
    "    ]\n"
    ")"
) + dedent(
    """
    pipeline_executor.run()
    ```
    """
)
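
# Workaround: the gallery exposes no per-image click event here, so `load()`
# wires each thumbnail to one of the hidden "block-button" buttons created below.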
make_gallery_image_buttons_js = """
function load() {
    let buttons = document.getElementsByClassName("block-button");
    Array.from(document.getElementById("pipeline-gallery").getElementsByClassName("thumbnail-item")).map(
        (b, i) => b.addEventListener("click", () => buttons[i].click())
    )
}
"""

css = """
tr td {
    border-top: 1px solid black;
}
.grid-container {
    gap: 0;
    grid-template-rows: auto;
    grid-auto-rows: auto;
}
.thumbnail-item {
    aspect-ratio: auto;
    height: min-content;
}
.grid-wrap {
    min-height: 0;
}
.table-wrap {
    min-height: 600px;
    max-height: 600px;
}
.scrollable_tabs .tab-wrapper .tab-container {
    overflow: scroll;
}
"""

blocks = sorted(glob.glob("images/*.png"))
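
# Components may carry a `prepare_parameter` hook that converts the raw UI value
# (e.g. a comma-separated string) into the argument actually passed to the
# pipeline step's constructor. These helpers are the two hooks used below.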
def prepare_as_list_or_none(text: str) -> Optional[list[str]]:
    return ([x.strip() for x in text.split(",") if x.strip()] or None) if text else None


def non_empty_list_or_none(input_list: list[str]) -> Optional[list[str]]:
    return input_list or None


with gr.Blocks(css=css, js=make_gallery_image_buttons_js) as demo:
    state = gr.State({"selected_block": None})
    gr.Markdown("# Common Crawl Pipeline Creator")
    with gr.Row():
        with gr.Column(min_width=640):
            gallery = gr.Gallery(
                blocks,
                columns=4,
                rows=2,
                label="Select step to edit",
                object_fit="scale-down",
                show_share_button=False,
                show_download_button=False,
                show_fullscreen_button=False,
                elem_id="pipeline-gallery",
                allow_preview=False,
            )
            # hack to make each gallery image act as a button, see `make_gallery_image_buttons_js`
            gallery_image_buttons = [gr.Button(visible=False, elem_classes="block-button") for _ in blocks]
            view_pipeline_results_button = gr.Button("Run Pipeline & Stream Results", variant="primary", scale=4)
            blocks_uis = []

            with gr.Column(visible=False) as col:
                blocks_uis.append(col)
                gr.Markdown("## 1. URL Filtering\n\nFilters samples based on their URLs.")
                with gr.Group():
                    url_filtering_checkbox = gr.Checkbox(True, label="Enable")
                    with gr.Accordion("Parameters", open=True) as acc:
                        use_integrated_lists_checkbox = gr.Checkbox(True, label="use_integrated_lists", info="use the datatrove integrated lists of banned URLs and words")
                        with gr.Row():
                            with gr.Column():
                                extra_domain_textbox = gr.Textbox("", label="extra_domains", info="remove if the domain is present in `extra_domains`")
                                extra_domain_textbox.prepare_parameter = prepare_as_list_or_none
                                extra_urls_textbox = gr.Textbox("", label="extra_urls", info="remove if the full URL is present in `extra_urls`")
                                extra_urls_textbox.prepare_parameter = prepare_as_list_or_none
                            with gr.Column():
                                banned_words_textbox = gr.Textbox("", label="banned_words", info="remove if any word from `banned_words` is in the URL")
                                banned_words_textbox.prepare_parameter = prepare_as_list_or_none
                                banned_subwords_textbox = gr.Textbox("", label="banned_subwords", info="remove if any word from `banned_subwords` is a substring of the URL")
                                banned_subwords_textbox.prepare_parameter = prepare_as_list_or_none
                            with gr.Column():
                                soft_banned_words_textbox = gr.Textbox("", label="soft_banned_words", info="remove if there are at least `soft_word_threshold` words from `soft_banned_words` in the URL")
                                soft_banned_words_textbox.prepare_parameter = prepare_as_list_or_none
                                soft_word_threshold_slider = gr.Slider(0, 5, value=2, step=1, label="soft_word_threshold", info="minimum number of `soft_banned_words` occurrences needed to remove a document")
                url_filtering_checkbox.change(lambda visible: gr.Accordion(visible=visible), inputs=url_filtering_checkbox, outputs=acc)
                url_filtering_parameters_components = [use_integrated_lists_checkbox, extra_domain_textbox, extra_urls_textbox, banned_words_textbox, banned_subwords_textbox, soft_banned_words_textbox, soft_word_threshold_slider]

            with gr.Column(visible=False) as col:
                blocks_uis.append(col)
                gr.Markdown("## 2. Text Extraction\n\nUses the [Trafilatura](https://trafilatura.readthedocs.io) extractor.")
                with gr.Group():
                    text_extraction_checkbox = gr.Checkbox(True, label="Enable")
                    with gr.Accordion("Parameters", open=True) as acc:
                        with gr.Row():
                            favour_precision_checkbox = gr.Checkbox(True, label="favour_precision", info="prefer less text but correct extraction")
                            timeout_slider = gr.Slider(0.05, 0.5, value=0.1, step=0.05, label="timeout", info="the timeout for extraction, per document, in seconds")
                            deduplicate_checkbox = gr.Checkbox(True, label="deduplicate", info="trafilatura's deduplicate option")
                text_extraction_checkbox.change(lambda visible: gr.Accordion(visible=visible), inputs=text_extraction_checkbox, outputs=acc)
                text_extraction_parameters_components = [favour_precision_checkbox, timeout_slider, deduplicate_checkbox]

            with gr.Column(visible=False) as col:
                blocks_uis.append(col)
                gr.Markdown("## 3. Language Filtering\n\nUses the [fastText](https://fasttext.cc/docs/en/language-identification.html) language identification models.")
                with gr.Group():
                    language_filtering_checkbox = gr.Checkbox(True, label="Enable")
                    with gr.Accordion("Parameters", open=True) as acc:
                        with gr.Row():
                            languages_textbox = gr.Dropdown(LANGUAGE_CHOICES, multiselect=True, label="languages", info="list of languages to keep; leave empty to keep all")
                            languages_textbox.prepare_parameter = non_empty_list_or_none
                            language_threshold_slider = gr.Slider(0, 1, value=0.65, step=0.05, label="language_threshold", info="minimum score to accept a document")
                language_filtering_checkbox.change(lambda visible: gr.Accordion(visible=visible), inputs=language_filtering_checkbox, outputs=acc)
                language_filtering_parameters_components = [languages_textbox, language_threshold_slider]

            with gr.Column(visible=False) as col:
                blocks_uis.append(col)
                gr.Markdown("## 4. Gopher Filtering (repetitions)\n\nUses the [Gopher](https://huggingface.co/papers/2112.11446) text repetition filters.")
                with gr.Group():
                    gopher_filtering_repetitions_checkbox = gr.Checkbox(True, label="Enable")
                    with gr.Accordion("Parameters", open=True) as acc:
                        with gr.Group():
                            with gr.Row():
                                language_dropdown1 = gr.Dropdown(LANGUAGE_CHOICES, value=Languages.english, label="language", info="tokenizer language")
                                # `ast.literal_eval` parses e.g. "(2, 0.2), (3, 0.18)" into a tuple of (n-gram size, max fraction) pairs
                                top_n_grams_textbox = gr.Textbox("(2, 0.2), (3, 0.18), (4, 0.16)", label="top_n_grams")
                                top_n_grams_textbox.prepare_parameter = ast.literal_eval
                                dup_n_grams_textbox = gr.Textbox("(5, 0.15), (6, 0.14), (7, 0.13), (8, 0.12), (9, 0.11), (10, 0.10)", label="dup_n_grams")
                                dup_n_grams_textbox.prepare_parameter = ast.literal_eval
                            with gr.Row():
                                dup_line_frac_slider = gr.Slider(0, 1, value=0.3, step=0.05, label="dup_line_frac")
                                dup_para_frac_slider = gr.Slider(0, 1, value=0.3, step=0.05, label="dup_para_frac")
                                dup_line_char_frac_slider = gr.Slider(0, 1, value=0.2, step=0.05, label="dup_line_char_frac")
                                dup_para_char_frac_slider = gr.Slider(0, 1, value=0.2, step=0.05, label="dup_para_char_frac")
                gopher_filtering_repetitions_checkbox.change(lambda visible: gr.Accordion(visible=visible), inputs=gopher_filtering_repetitions_checkbox, outputs=acc)
                gopher_filtering_repetitions_parameters_components = [language_dropdown1, top_n_grams_textbox, dup_n_grams_textbox, dup_line_frac_slider, dup_para_frac_slider, dup_line_char_frac_slider, dup_para_char_frac_slider]

            with gr.Column(visible=False) as col:
                blocks_uis.append(col)
                gr.Markdown("## 8. PII Removal\n\nReplaces email addresses and IP addresses in the document text.")
                with gr.Group():
                    pii_removal_checkbox = gr.Checkbox(True, label="Enable")
                    with gr.Accordion("Parameters", open=True) as acc:
                        with gr.Row():
                            remove_emails_checkbox = gr.Checkbox(True, label="remove_emails", info="Replace email addresses")
                            remove_ips_checkbox = gr.Checkbox(True, label="remove_ips", info="Replace IP addresses")
                            only_remove_public_ips_checkbox = gr.Checkbox(True, label="only_remove_public_ips", info="by default we only replace public (and thus PII) IPs")
                        with gr.Row():
                            email_replacement_textbox = gr.Textbox("email@example.com, firstname.lastname@example.org", label="email_replacement", info="strings to use as replacement. They will be used in a circular way")
                            email_replacement_textbox.prepare_parameter = prepare_as_list_or_none
                            ip_replacement_textbox = gr.Textbox("22.214.171.124, 126.96.36.199, 188.8.131.52, 184.108.40.206, 220.127.116.11, 18.104.22.168", label="ip_replacement", info="same as email_replacement but for IP addresses")
                            ip_replacement_textbox.prepare_parameter = prepare_as_list_or_none
                pii_removal_checkbox.change(lambda visible: gr.Accordion(visible=visible), inputs=pii_removal_checkbox, outputs=acc)
                pii_removal_parameters_components = [remove_emails_checkbox, remove_ips_checkbox, only_remove_public_ips_checkbox, email_replacement_textbox, ip_replacement_textbox]

            with gr.Column(visible=False) as col:
                blocks_uis.append(col)
                gr.Markdown("## 7. Custom Filters\n\nUses the [FineWeb](https://huggingface.co/datasets/HuggingFaceFW/fineweb) custom text filters.")
                with gr.Group():
                    custom_filters_checkbox = gr.Checkbox(True, label="Enable")
                    with gr.Accordion("Parameters", open=True) as acc:
                        with gr.Row():
                            line_punct_thr_slider = gr.Slider(0, 1, value=0.12, step=0.01, label="line_punct_thr")
                            line_punct_exclude_zero = gr.Checkbox(False, label="line_punct_exclude_zero")
                            short_line_thr_slider = gr.Slider(0, 1, value=0.67, step=0.01, label="short_line_thr")
                            short_line_length_slider = gr.Slider(0, 100, value=30, step=1, label="short_line_length")
                            char_duplicates_ratio_slider = gr.Slider(0, 1, value=0.01, step=0.01, label="char_duplicates_ratio")
                            new_line_ratio_slider = gr.Slider(0, 1, value=0.3, step=0.01, label="new_line_ratio")
                custom_filters_checkbox.change(lambda visible: gr.Accordion(visible=visible), inputs=custom_filters_checkbox, outputs=acc)
                custom_filters_parameters_components = [line_punct_thr_slider, line_punct_exclude_zero, short_line_thr_slider, short_line_length_slider, char_duplicates_ratio_slider, new_line_ratio_slider]

            with gr.Column(visible=False) as col:
                blocks_uis.append(col)
                gr.Markdown("## 6. C4 Filters\n\nUses the [C4](https://huggingface.co/datasets/allenai/c4) text size and content filters.")
                with gr.Group():
                    c4_filters_checkbox = gr.Checkbox(True, label="Enable")
                    with gr.Accordion("Parameters", open=True) as acc:
                        with gr.Group():
                            with gr.Row():
                                split_paragraph_checkbox = gr.Checkbox(True, label="split_paragraph", info="disable to apply the filters to each sentence instead of to each line")
                            with gr.Row():
                                language_dropdown2 = gr.Dropdown(LANGUAGE_CHOICES, value=Languages.english, label="language", info="tokenizer language")
                                min_num_sentences_slider = gr.Slider(0, 10, value=5, step=1, label="min_num_sentences", info="remove documents that do not have at least this number of sentences (after line filtering)")
                                min_words_per_line_slider = gr.Slider(0, 10, value=3, step=1, label="min_words_per_line", info="drop lines without this minimum number of words")
                                max_word_length_slider = gr.Slider(0, 2000, value=1000, step=10, label="max_word_length", info="drop lines where at least one word has more than this number of characters")
                            with gr.Row():
                                remove_citations_checkbox = gr.Checkbox(True, label="remove_citations", info="remove Wikipedia-style citations from the text")
                                filter_no_terminal_punct_checkbox = gr.Checkbox(True, label="filter_no_terminal_punct", info="remove lines without terminal punctuation marks")
                                filter_lorem_ipsum_checkbox = gr.Checkbox(True, label="filter_lorem_ipsum", info="drop documents that contain 'lorem ipsum'")
                                filter_javascript_checkbox = gr.Checkbox(True, label="filter_javascript", info="drop lines mentioning 'javascript'")
                                filter_curly_bracket = gr.Checkbox(True, label="filter_curly_bracket", info="drop documents containing {")
                                filter_policy = gr.Checkbox(True, label="filter_policy", info="drop lines containing any of the policy phrases (e.g. 'terms of use', 'use cookies')")
                c4_filters_checkbox.change(lambda visible: gr.Accordion(visible=visible), inputs=c4_filters_checkbox, outputs=acc)
                c4_filters_parameters_components = [split_paragraph_checkbox, language_dropdown2, min_num_sentences_slider, min_words_per_line_slider, max_word_length_slider, remove_citations_checkbox, filter_no_terminal_punct_checkbox, filter_lorem_ipsum_checkbox, filter_javascript_checkbox, filter_curly_bracket, filter_policy]

            with gr.Column(visible=False) as col:
                blocks_uis.append(col)
                gr.Markdown("## 5. Gopher Filtering (quality)\n\nUses the [Gopher](https://huggingface.co/papers/2112.11446) text quality filters.")
                with gr.Group():
                    gopher_filtering_quality_checkbox = gr.Checkbox(True, label="Enable")
                    with gr.Accordion("Parameters", open=True) as acc:
                        with gr.Group():
                            with gr.Row():
                                language_dropdown3 = gr.Dropdown(LANGUAGE_CHOICES, value=Languages.english, label="language", info="tokenizer language")
                                min_doc_words_slider = gr.Slider(0, 1000, value=50, step=10, label="min_doc_words")
                                max_doc_words_slider = gr.Slider(0, 200_000, value=100_000, step=10_000, label="max_doc_words")
                            with gr.Row():
                                min_avg_word_length_slider = gr.Slider(0, 20, value=3, step=1, label="min_avg_word_length")
                                max_avg_word_length_slider = gr.Slider(0, 20, value=10, step=1, label="max_avg_word_length")
                            with gr.Row():
                                max_symbol_word_ratio_slider = gr.Slider(0, 1, value=0.1, step=0.05, label="max_symbol_word_ratio")
                                max_bullet_lines_ratio_slider = gr.Slider(0, 1, value=0.9, step=0.05, label="max_bullet_lines_ratio")
                                max_ellipsis_lines_ratio_slider = gr.Slider(0, 1, value=0.3, step=0.05, label="max_ellipsis_lines_ratio")
                                max_non_alpha_words_ratio_slider = gr.Slider(0, 1, value=0.8, step=0.05, label="max_non_alpha_words_ratio")
                            with gr.Row():
                                min_stop_words_slider = gr.Slider(0, 10, value=2, step=1, label="min_stop_words")
                                stop_words_textbox = gr.Textbox("the, be, to, of, and, that, have, with", label="stop_words")
                                stop_words_textbox.prepare_parameter = prepare_as_list_or_none
                gopher_filtering_quality_checkbox.change(lambda visible: gr.Accordion(visible=visible), inputs=gopher_filtering_quality_checkbox, outputs=acc)
                gopher_filtering_quality_parameters_components = [language_dropdown3, min_doc_words_slider, max_doc_words_slider, min_avg_word_length_slider, max_avg_word_length_slider, max_symbol_word_ratio_slider, max_bullet_lines_ratio_slider, max_ellipsis_lines_ratio_slider, max_non_alpha_words_ratio_slider, min_stop_words_slider, stop_words_textbox]
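
            # Parameter component groups, kept in the same order as `steps` so
            # that UI values can be zipped back onto the right pipeline step.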
            steps_parameters_components = [
                url_filtering_parameters_components,
                text_extraction_parameters_components,
                language_filtering_parameters_components,
                gopher_filtering_repetitions_parameters_components,
                gopher_filtering_quality_parameters_components,
                c4_filters_parameters_components,
                custom_filters_parameters_components,
                pii_removal_parameters_components,
            ]

        with gr.Column():
            with gr.Tabs(elem_classes="scrollable_tabs"):
                with gr.Tab("Output (and % of data)") as output_tab:
                    output_dataframe = gr.DataFrame(datatype="markdown")
                with gr.Tab("Excluded (and % of data)") as excluded_tab:
                    with gr.Tabs(elem_classes="scrollable_tabs"):
                        excluded_dataframes: dict[Type, gr.DataFrame] = {}
                        excluded_tabs: dict[Type, gr.Tab] = {}
                        # one tab per filter; URLFilter is skipped since it rejects
                        # raw WARC records before any text has been extracted
                        for step in steps:
                            if issubclass(step, BaseFilter) and step is not URLFilter:
                                with gr.Tab(step.__name__ + " (and % of data)") as t:
                                    excluded_dataframes[step] = gr.DataFrame(datatype="markdown")
                                    excluded_tabs[step] = t
                with gr.Tab("Python code") as code_tab:
                    python_code_markdown = gr.Markdown(DEFAULT_CODE)

    gr.Markdown("_powered by [datatrove](https://github.com/huggingface/datatrove)_")
    def show_block_ui(i, current_state: dict):
        if i == current_state.get("selected_block"):
            i = None
        return {**{block_ui: gr.Column(visible=(j == i)) for j, block_ui in enumerate(blocks_uis)}, state: {"selected_block": i}}

    for i, button in enumerate(gallery_image_buttons):
        button.click(partial(show_block_ui, i), inputs=[state], outputs=blocks_uis + [state])

    inputs = [
        url_filtering_checkbox,
        text_extraction_checkbox,
        language_filtering_checkbox,
        gopher_filtering_repetitions_checkbox,
        gopher_filtering_quality_checkbox,
        c4_filters_checkbox,
        custom_filters_checkbox,
        pii_removal_checkbox,
    ] + sum(steps_parameters_components, [])
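
    # The first len(steps) arguments are the "Enable" checkboxes; the rest are
    # the flattened step parameters, consumed group by group via the iterator.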
    def view_pipeline_results(*args):
        enable_steps, steps_parameters = args[:len(steps)], args[len(steps):]
        steps_parameters_iter = iter(steps_parameters)
        steps_parameters = [
            {
                parameters_component.label: parameters_component.prepare_parameter(parameter) if hasattr(parameters_component, "prepare_parameter") else parameter
                for parameters_component, parameter in zip(step_parameters_components, steps_parameters_iter)
            }
            for step_parameters_components in steps_parameters_components
        ]
        default_steps_parameters = [
            {
                parameters_component.label: parameters_component.prepare_parameter(parameters_component.value) if hasattr(parameters_component, "prepare_parameter") else parameters_component.value
                for parameters_component in step_parameters_components
            }
            for step_parameters_components in steps_parameters_components
        ]

        # First, show the generated Python code (only non-default arguments are written out)
        yield {
            python_code_markdown: dedent(
                """
                ```python
                from datatrove.executor.local import LocalPipelineExecutor
                from datatrove.pipeline.extractors import Trafilatura
                from datatrove.pipeline.filters import (
                    C4QualityFilter,
                    FineWebQualityFilter,
                    GopherQualityFilter,
                    GopherRepetitionFilter,
                    LanguageFilter,
                    URLFilter,
                )
                from datatrove.pipeline.formatters import PIIFormatter
                from datatrove.pipeline.readers import WarcReader
                """
            ).strip() + (
                "\n\n"
                "pipeline_executor = LocalPipelineExecutor(\n"
                "    pipeline=[\n"
                f'        WarcReader("s3://commoncrawl/crawl-data/{DUMP_TO_PROCESS}/segments", glob_pattern="*/warc/*"),\n'
            ) + ",\n".join([
                "        " + step.__name__ + "(" + ", ".join(arg + "=" + repr(value) for arg, value in step_parameters.items() if value != default_step_parameters[arg] and arg != "exclusion_writer") + ")"
                for step, step_parameters, default_step_parameters, enable_step in zip(steps, steps_parameters, default_steps_parameters, enable_steps)
                if enable_step
            ]) + (
                "\n"
                "    ]\n"
                ")"
            ) + dedent(
                """
                pipeline_executor.run()
                ```
                """
            )
        }
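
        # In-memory stand-in for datatrove's exclusion writers: filters call
        # `write()` for each document they drop, so the rejects can be shown per step.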
        class ExclusionWriter:
            def __init__(self) -> None:
                self.docs: list[Document] = []

            def __enter__(self):
                return self

            def __exit__(self, exc_type, exc_val, exc_tb):
                return

            def write(self, doc, rank):
                self.docs.append(doc)

        steps_to_run = [
            step(**step_parameters, **({"exclusion_writer": ExclusionWriter()} if step in excluded_dataframes else {}))
            for step, step_parameters, enable_step in zip(steps, steps_parameters, enable_steps)
            if enable_step
        ]

        output_docs: list[Document] = []
        num_warc_samples = 0
        timeout_time = time.time() + TIMEOUT
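
        # Two passthrough pipeline steps: one counts how many WARC samples entered
        # the pipeline (to compute the displayed percentages), the other stops the
        # run once the time budget is exhausted.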
        def increment_num_warc_samples(data, rank, world_size, num_warc_samples_per_doc=1):
            nonlocal num_warc_samples
            for x in data:
                num_warc_samples += num_warc_samples_per_doc
                yield x

        def check_timeout(data, rank, world_size):
            for x in data:
                if time.time() > timeout_time:
                    gr.Info("Pipeline timed out")
                    break
                yield x
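
        # Fast path: when URL filtering and text extraction are both enabled with
        # their default parameters, read documents that were already extracted for
        # this dump instead of re-running Trafilatura on the raw WARC files.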
        if steps_parameters[:2] == default_steps_parameters[:2] and all(enable_steps[:2]):
            pipeline_executor = LocalPipelineExecutor(
                pipeline=[
                    JsonlReader(data_folder=f"output_text_extraction-full/base_processing/output/{DUMP_TO_PROCESS}", glob_pattern="*.jsonl.gz"),
                    # each cached document counts as 2000/1687 WARC records (presumably the
                    # record-to-document ratio of the cached sample), keeping percentages comparable
                    partial(increment_num_warc_samples, num_warc_samples_per_doc=2000 / 1687),
                    check_timeout,
                ] + steps_to_run[2:] + [
                    lambda data, rank, world_size: islice(data, 100),  # keep at most 100 output documents
                    lambda data, rank, world_size: map(output_docs.append, data),
                ],
                logging_dir="logs",
                skip_completed=False,
            )
        else:
            pipeline_executor = LocalPipelineExecutor(
                pipeline=[
                    WarcReader(data_folder="data", glob_pattern="*.warc.gz"),
                    increment_num_warc_samples,
                    check_timeout,
                ] + steps_to_run + [
                    lambda data, rank, world_size: islice(data, 100),
                    lambda data, rank, world_size: map(output_docs.append, data),
                ],
                logging_dir="logs",
                skip_completed=False,
            )

        # Run the pipeline in a background thread and poll it every second,
        # streaming partial results to the UI while it is alive.
        thread = Thread(target=pipeline_executor.run)
        thread.start()
        while thread.is_alive():
            thread.join(timeout=1)
            if num_warc_samples:
                yield {
                    output_tab: gr.Tab(f"Output ({len(output_docs)/num_warc_samples*100:.03f}%)"),
                    excluded_tab: gr.Tab(f"Excluded ({100 - len(output_docs)/num_warc_samples*100:.03f}%)"),
                    output_dataframe: pd.DataFrame({"text": [doc.text for doc in output_docs]}),
                    **{
                        excluded_dataframes[type(step_to_run)]: pd.DataFrame({"text": [doc.text for doc in step_to_run.exclusion_writer.docs]})
                        for step_to_run in pipeline_executor.pipeline
                        if isinstance(step_to_run, BaseFilter) and type(step_to_run) in excluded_dataframes
                    },
                    **{
                        excluded_tabs[type(step_to_run)]: gr.Tab(f"{type(step_to_run).__name__} ({len(step_to_run.exclusion_writer.docs)/num_warc_samples*100:.03f}%)")
                        for step_to_run in pipeline_executor.pipeline
                        if isinstance(step_to_run, BaseFilter) and type(step_to_run) in excluded_dataframes
                    },
                }
            else:
                yield {
                    output_tab: gr.Tab("Output (loading...)"),
                    excluded_tab: gr.Tab("Excluded (loading...)"),
                    **{
                        excluded_dataframes[type(step_to_run)]: pd.DataFrame({"text": []})
                        for step_to_run in pipeline_executor.pipeline
                        if isinstance(step_to_run, BaseFilter) and type(step_to_run) in excluded_dataframes
                    },
                    **{
                        excluded_tabs[type(step_to_run)]: gr.Tab(f"{type(step_to_run).__name__}")
                        for step_to_run in pipeline_executor.pipeline
                        if isinstance(step_to_run, BaseFilter) and type(step_to_run) in excluded_dataframes
                    },
                }

        # Final update once the pipeline thread has finished
        yield {
            output_tab: gr.Tab(f"Output ({len(output_docs)/num_warc_samples*100:.03f}%)"),
            excluded_tab: gr.Tab(f"Excluded ({100 - len(output_docs)/num_warc_samples*100:.03f}%)"),
            output_dataframe: pd.DataFrame({"text": [doc.text for doc in output_docs]}),
            **{
                excluded_dataframes[type(step_to_run)]: pd.DataFrame({"text": [doc.text for doc in step_to_run.exclusion_writer.docs]})
                for step_to_run in pipeline_executor.pipeline
                if isinstance(step_to_run, BaseFilter) and type(step_to_run) in excluded_dataframes
            },
            **{
                excluded_tabs[type(step_to_run)]: gr.Tab(f"{type(step_to_run).__name__} ({len(step_to_run.exclusion_writer.docs)/num_warc_samples*100:.03f}%)")
                for step_to_run in pipeline_executor.pipeline
                if isinstance(step_to_run, BaseFilter) and type(step_to_run) in excluded_dataframes
            },
        }

    view_pipeline_results_button.click(view_pipeline_results, inputs=inputs, outputs=[output_tab, output_dataframe, excluded_tab, python_code_markdown] + list(excluded_dataframes.values()) + list(excluded_tabs.values()))


if __name__ == "__main__":
    demo.launch()