File size: 1,495 Bytes
ef09277
 
 
 
 
 
61bfd6f
ef09277
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61bfd6f
d916e25
ef09277
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import os
import sys
# Make the repository root importable when this file is run directly from
# the scripts/ directory (so `config` and `scripts` packages resolve).
sys.path.append(sys.path[0].replace('scripts', ''))
from urllib.request import urlretrieve
import pandas as pd

from config.data_paths import PROCESSED_DATA_PATH
import re

from scripts.utils import load_config

# Source of the raw prompt corpus. Read from the project config when
# 'prompts_corpus_url' is set; otherwise fall back to the public
# DiffusionDB metadata parquet hosted on Hugging Face.
PROMPTS_URL = load_config()['data'].get('prompts_corpus_url', 'https://huggingface.co/datasets/poloclub/diffusiondb/resolve/main/metadata.parquet')

def preprocess_text(text: str) -> str:
    """
    Normalize a raw prompt string.
    Args:
        text: Raw text prompt.
    Returns:
        The prompt with surrounding whitespace removed and every internal
        run of whitespace collapsed to a single space.
    """
    # Collapsing whitespace runs first, then trimming the ends, yields the
    # same result as trimming first — a single pass reads more directly.
    collapsed = re.sub(r'\s+', ' ', text)
    return collapsed.strip()

def clean_corpus(n_samples: int = 5000, random_state: int = 123) -> None:
    """
    Download, clean, and cache the prompt corpus.

    Writes the cleaned prompts to
    ``PROCESSED_DATA_PATH/prompt_corpus_clean.parquet``. If that file
    already exists, the function returns immediately so repeated calls
    skip the expensive download.

    Args:
        n_samples: Number of prompts to sample from the raw corpus.
        random_state: Seed for reproducible sampling.

    Raises:
        ValueError: If the downloaded parquet lacks a 'prompt' column.
    """
    # Build the output path once so the existence check and the final
    # write can never drift apart.
    out_path = os.path.join(PROCESSED_DATA_PATH, 'prompt_corpus_clean.parquet')
    if os.path.isfile(out_path):  # cached result available — skip the work
        return

    os.makedirs(PROCESSED_DATA_PATH, exist_ok=True)
    df = pd.read_parquet(PROMPTS_URL).sample(n_samples, random_state=random_state)
    # `assert` is stripped under `python -O`; validate explicitly instead.
    if 'prompt' not in df.columns:
        raise ValueError("Parquet file must contain a 'prompt' column.")
    df = df[df['prompt'].notna()][['prompt']]           # drop rows with missing prompts
    df['prompt'] = df['prompt'].apply(preprocess_text)  # normalize whitespace per prompt
    df = df.drop_duplicates()                           # remove identical prompts

    df.to_parquet(out_path)