diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..212c21d6c94a3adda0a0bc2dd800f72d6e422016 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,10 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +*.vocab filter=lfs diff=lfs merge=lfs -text +texts/*.txt filter=lfs diff=lfs merge=lfs -text +*.arpa* filter=lfs diff=lfs merge=lfs -text +kenlm/*.bin filter=lfs diff=lfs merge=lfs -text +kenlm/*.arpa filter=lfs diff=lfs merge=lfs -text +samples/*.jsonl filter=lfs diff=lfs merge=lfs -text +*.jsonl filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..9bbb8ec2cc9ba27a85312ab631b344968a13e606 --- /dev/null +++ b/.gitignore @@ -0,0 +1,17 @@ +tmp/ +__pycache__/ +*.pyc +.ipynb_checkpoints + + +samples/restricted* +samples/*.json* +kenlm/wikipedia/* +!kenlm/wikipedia/.keep +kenlm/harmful/* +!kenlm/harmful/.keep +spm/wikipedia/* +!spm/wikipedia/.keep +spm/*.txt +texts/* +!texts/.keep \ No newline at end of file diff --git a/README.md b/README.md index 154df8298fab5ecf322016157858e08cd1bccbe1..9ddeb9ef308bb6658c810301d89adbbc4cd1e764 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,39 @@ --- license: apache-2.0 --- + +# Perplexity tools + +## 1. Create samples from `clean_json_3` sources + +Between 1k and 1M documents. Read [samples/README.md](./samples/README.md). Output files must be prefixed by `doc_type` and suffixed by language code (2 letters). For example: + +```bash +$ cat /nfsmounts/datastore/ncc_corpus/mimir/jsonl_2/nrk/nrk-articles.jsonl | shuf -n 100000 > samples/restricted-newspapers_nrk_no.json +``` + +## 2. 
Create the perplexity scores for each file + +Example of how to create scores only for `doc_type` `restricted-newspapers_*` samples: + +```bash +$ ls samples/restricted-newspapers_* | parallel --lb --jobs 5 python samples_scores.py {} --output_path scores/ --jobs 15 +``` + +## 3. Create the quartiles CSV needed for segmenting and downsampling + +The different `doc_type`s will be grouped together. By passing the flag `--group_by_prefix_lang`, the grouping will happen on the pair `doc_type` prefix and language code, e.g., `wikipedia_en`. + +Different downsampling ratios can be specified by using the `--sampling_ratio_per_lang` flag. For `mimir-base`, the downsampling by language is defined as follows: `"da:0.23,en:0.21,sv:0.08,is:0.50"`. + +```bash +$ python samples_quartiles.py scores/ --group_by_prefix_lang --sampling_ratio_per_lang "da:0.23,en:0.21,sv:0.08,is:0.50" --output_file csv/base-perplexity_quartiles_sampling.csv +``` + +For `mimir-extended`, the downsampling by language is defined as follows: `"da:0.43,en:0.81,sv:0.15,code:0.62"`. + +```bash +$ python samples_quartiles.py scores/ --group_by_prefix_lang --sampling_ratio_per_lang "da:0.43,en:0.81,sv:0.15,code:0.62" --output_file csv/extended-perplexity_quartiles_sampling.csv --overwrite_prefix_lang "starcoder_en:starcoder_code" +``` + +More information in the [spreadsheet](https://docs.google.com/spreadsheets/d/108oGVVN-Ml-TDN59UXR96oeBBt2FbgT81zt8_1y9PUw/edit?usp=sharing). 
\ No newline at end of file diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/csv/base-perplexity_quartiles_sampling.csv b/csv/base-perplexity_quartiles_sampling.csv new file mode 100644 index 0000000000000000000000000000000000000000..16602387836aa08b82dd2eb0e5e60f233d06eb91 --- /dev/null +++ b/csv/base-perplexity_quartiles_sampling.csv @@ -0,0 +1,33 @@ +doc_type,model,language,reject,bad,medium,good,norm,mean,std +books,books,no,542.15,301.25,219.3,165.12,0.0032422660862847633,208.18464621605895,68.02897458931068 +culturax,wikipedia,nn,1113.2,753.4,559.9,387.7,0.001172357337862289,487.27059437715525,185.90322713836343 +culturax,wikipedia,sv,1118.6,772.2,606.9,479.8,0.01968171485234145,580.0945047395821,142.99911605358275 +culturax,wikipedia,da,1012.9,648.2,503.3,397.98,0.007997295965244292,488.615463864415,124.17368632962524 +digimanus,wikipedia,no,1991.88,1226.65,989.1,830.35,0.0011146154086008851,974.7133669943673,209.08555530030617 +culturax,wikipedia,no,1073.1,691.1,538.2,430.0,0.0017538216816248486,523.6960713940705,130.62730440702228 +culturax,wikipedia,is,1420.0,884.5,720.2,594.5,0.0030935154995326906,693.7606785221377,147.6241796866134 +evalueringsrapport,maalfrid,no,268.25,163.5,127.8,98.3,0.006540788722418088,117.29318501940242,34.47568292096079 +hplt,wikipedia,nn,1539.1,980.6,772.7,627.5,0.0012826369023540814,752.0725635933572,179.13196906762977 +lovdata,maalfrid,no,457.9,162.9,84.6,41.6,0.0038894207845140477,96.06375056993284,58.30277337274196 +maalfrid,maalfrid,no,686.5,286.9,164.8,87.3,0.0022814356724527207,164.0258389923656,101.07016579025363 +hplt,wikipedia,da,1445.5,829.3,616.3,493.5,0.00597386636355673,630.7049612170936,168.77191092534918 +book,books,no,636.48,302.58,187.4,67.0,0.002034229155801576,158.1210630456195,109.45691866057511 +hplt,wikipedia,sv,1398.0,910.9,715.8,578.5,0.0173199443667263,698.8065459625257,165.03293101814995 
+hplt,wikipedia,no,1589.0,880.7,668.5,532.6,0.0013206924407238364,671.3073940020074,174.52833317000255 +newspapers,newspapers,nn,1685.4,1221.9,1005.4,825.2,0.0011282397163826917,951.0683339330576,197.39448542294457 +newspaper,newspapers,no,2308.6,792.3,475.2,307.9,0.0009454270671058767,526.1389696700705,244.38110894007406 +parlamint,maalfrid,no,129.23,104.0,93.8,84.6,0.02105587174354365,89.24246433500929,10.099393230392014 +newspapers,newspapers,no,782.3,466.7,336.0,243.5,0.002096701749929567,326.7656126928853,108.47162732564873 +wikipedia,wikipedia,da,1226.31,470.7,278.7,127.0,0.006116042100206995,272.5462428872027,159.55477781562297 +wikipedia,wikipedia,is,1893.3,740.7,449.1,174.5,0.001640993616891793,429.6854374438017,283.9768832443661 +wikipedia,wikipedia,nn,1159.86,494.45,283.1,123.6,0.0013200962342698906,280.91195392289364,167.82742834163992 +wikipedia,wikipedia,no,2058.62,612.2,324.6,139.3,0.0009966961122328344,363.387061861549,229.2323512781706 +slimpajama,wikipedia,en,2259.2,756.5,534.4,418.5,0.006212514831977225,569.5492667529695,179.9279253054439 +wikipedia,wikipedia,sv,1586.56,521.5,304.0,165.4,0.016951427796527165,325.8191384990417,163.13795554088844 +wikipedia,wikipedia,en,1815.4,671.6,455.7,331.2,0.006112968939492834,470.655891042871,184.96531992400435 +hplt,wikipedia,is,2310.06,1484.7,1160.3,921.3,0.001632796440658896,1119.008609637535,278.2396677657607 +pg19,wikipedia,en,865.84,540.3,473.2,419.1,0.017132607020012576,460.9763713901977,63.76307180686858 +starcoder,wikipedia,en,6898.5,2724.5,1603.4,972.4,0.0012712203723443047,1734.1527299358695,858.6110807589087 +slimpajama,wikipedia,no,2259.2,756.5,534.4,418.5,0.006212514831977225,569.5492667529695,179.9279253054439 +starcoder,wikipedia,no,6898.5,2724.5,1603.4,972.4,0.0012712203723443047,1734.1527299358695,858.6110807589087 +pg19,wikipedia,no,865.84,540.3,473.2,419.1,0.017132607020012576,460.9763713901977,63.76307180686858 \ No newline at end of file diff --git 
a/csv/extended-perplexity_quartiles_sampling.csv b/csv/extended-perplexity_quartiles_sampling.csv new file mode 100644 index 0000000000000000000000000000000000000000..438364d47d9618847a1cd42e8d6f54704995226c --- /dev/null +++ b/csv/extended-perplexity_quartiles_sampling.csv @@ -0,0 +1,37 @@ +doc_type,model,language,reject,bad,medium,good,norm,mean,std +books,books,no,542.15,301.25,219.3,165.12,0.0032422660862847633,208.18464621605895,68.02897458931068 +culturax,wikipedia,nn,1113.2,753.4,559.9,387.7,0.001172357337862289,487.27059437715525,185.90322713836343 +culturax,wikipedia,sv,1118.6,772.2,606.9,479.8,0.01049691458791544,580.0945047395821,142.99911605358275 +culturax,wikipedia,da,1012.9,648.2,503.3,397.98,0.004277623423270203,488.615463864415,124.17368632962524 +digimanus,wikipedia,no,1991.88,1226.65,989.1,830.35,0.0011146154086008851,974.7133669943673,209.08555530030617 +culturax,wikipedia,no,1073.1,691.1,538.2,430.0,0.0017538216816248486,523.6960713940705,130.62730440702228 +culturax,wikipedia,is,1420.0,884.5,720.2,594.5,0.0015467577497663453,693.7606785221377,147.6241796866134 +evalueringsrapport,maalfrid,no,268.25,163.5,127.8,98.3,0.006540788722418088,117.29318501940242,34.47568292096079 +hplt,wikipedia,nn,1539.1,980.6,772.7,627.5,0.0012826369023540814,752.0725635933572,179.13196906762977 +lovdata,maalfrid,no,457.9,162.9,84.6,41.6,0.0038894207845140477,96.06375056993284,58.30277337274196 +maalfrid,maalfrid,no,686.5,286.9,164.8,87.3,0.0022814356724527207,164.0258389923656,101.07016579025363 +hplt,wikipedia,da,1445.5,829.3,616.3,493.5,0.0031953238688791816,630.7049612170936,168.77191092534918 +book,books,no,636.48,302.58,187.4,67.0,0.002034229155801576,158.1210630456195,109.45691866057511 +hplt,wikipedia,sv,1398.0,910.9,715.8,578.5,0.009237303662254026,698.8065459625257,165.03293101814995 +hplt,wikipedia,no,1589.0,880.7,668.5,532.6,0.0013206924407238364,671.3073940020074,174.52833317000255 
+newspapers,newspapers,nn,1685.4,1221.9,1005.4,825.2,0.0011282397163826917,951.0683339330576,197.39448542294457 +newspaper,newspapers,no,2308.6,792.3,475.2,307.9,0.0009454270671058767,526.1389696700705,244.38110894007406 +parlamint,maalfrid,no,129.23,104.0,93.8,84.6,0.02105587174354365,89.24246433500929,10.099393230392014 +newspapers,newspapers,no,782.3,466.7,336.0,243.5,0.002096701749929567,326.7656126928853,108.47162732564873 +wikipedia,wikipedia,da,1226.31,470.7,278.7,127.0,0.003271371355924672,272.5462428872027,159.55477781562297 +wikipedia,wikipedia,is,1893.3,740.7,449.1,174.5,0.0008204968084458965,429.6854374438017,283.9768832443661 +wikipedia,wikipedia,nn,1159.86,494.45,283.1,123.6,0.0013200962342698906,280.91195392289364,167.82742834163992 +wikipedia,wikipedia,no,2058.62,612.2,324.6,139.3,0.0009966961122328344,363.387061861549,229.2323512781706 +slimpajama,wikipedia,en,2259.2,756.5,534.4,418.5,0.0016106519934755766,569.5492667529695,179.9279253054439 +wikipedia,wikipedia,sv,1586.56,521.5,304.0,165.4,0.009040761491481156,325.8191384990417,163.13795554088844 +wikipedia,wikipedia,en,1815.4,671.6,455.7,331.2,0.0015848437991277716,470.655891042871,184.96531992400435 +hplt,wikipedia,is,2310.06,1484.7,1160.3,921.3,0.000816398220329448,1119.008609637535,278.2396677657607 +pg19,wikipedia,en,865.84,540.3,473.2,419.1,0.004441787005188445,460.9763713901977,63.76307180686858 +starcoder,wikipedia,code,6898.5,2724.5,1603.4,972.4,0.0004305746422456516,1734.1527299358695,858.6110807589087 +restricted-newspapers,newspapers,no,847.7,451.7,328.5,246.5,0.002248478883149024,325.7155732204811,102.50329419364242 +restricted-books,books,no,636.88,375.5,282.8,216.8,0.0028282201514638694,272.19155841413874,81.36986186892527 +restricted-book,books,no,569.8,365.9,281.7,218.6,0.0030429861768025046,267.8089800338991,74.79791679414626 +slimpajama,wikipedia,no,2259.2,756.5,534.4,418.5,0.0016106519934755766,569.5492667529695,179.9279253054439 
+starcoder,wikipedia,no,6898.5,2724.5,1603.4,972.4,0.0004305746422456516,1734.1527299358695,858.6110807589087 +starcoder,wikipedia,code,6898.5,2724.5,1603.4,972.4,0.0004305746422456516,1734.1527299358695,858.6110807589087 +pg19,wikipedia,no,865.84,540.3,473.2,419.1,0.004441787005188445,460.9763713901977,63.76307180686858 diff --git a/download_all.sh b/download_all.sh new file mode 100644 index 0000000000000000000000000000000000000000..8e73bb5a79d8082ed9f9b405e3e81e1d5fc840b4 --- /dev/null +++ b/download_all.sh @@ -0,0 +1,40 @@ +mkdir kenlm +mv *arpa* kenlm/ + +mkdir spm +mv *.model spm/ +mv *.vocab spm/ + +mkdir kenlm/harmful +wget -O kenlm/harmful/da.arpa https://huggingface.co/oscar-corpus/harmful-kenlms/resolve/main/da.arpa +wget -O kenlm/harmful/da.bin https://huggingface.co/oscar-corpus/harmful-kenlms/resolve/main/da.binary +wget -O kenlm/harmful/sv.arpa https://huggingface.co/oscar-corpus/harmful-kenlms/resolve/main/sv.arpa +wget -O kenlm/harmful/sv.bin https://huggingface.co/oscar-corpus/harmful-kenlms/resolve/main/sv.binary +wget -O kenlm/harmful/is.arpa https://huggingface.co/oscar-corpus/harmful-kenlms/resolve/main/is.arpa +wget -O kenlm/harmful/is.bin https://huggingface.co/oscar-corpus/harmful-kenlms/resolve/main/is.binary +wget -O kenlm/harmful/no.arpa https://huggingface.co/oscar-corpus/harmful-kenlms/resolve/main/no.arpa +wget -O kenlm/harmful/no.bin https://huggingface.co/oscar-corpus/harmful-kenlms/resolve/main/no.binary +wget -O kenlm/harmful/en.arpa https://huggingface.co/oscar-corpus/harmful-kenlms/resolve/main/en.arpa +wget -O kenlm/harmful/en.bin https://huggingface.co/oscar-corpus/harmful-kenlms/resolve/main/en.binary + +mkdir kenlm/wikipedia +wget --header="Authorization: Bearer $(cat $HOME/.cache/huggingface/token)" -O kenlm/wikipedia/da.arpa.bin https://huggingface.co/uonlp/kenlm/resolve/main/wikipedia_20230501/da.arpa.bin +wget --header="Authorization: Bearer $(cat $HOME/.cache/huggingface/token)" -O kenlm/wikipedia/sv.arpa.bin 
https://huggingface.co/uonlp/kenlm/resolve/main/wikipedia_20230501/sv.arpa.bin +wget --header="Authorization: Bearer $(cat $HOME/.cache/huggingface/token)" -O kenlm/wikipedia/is.arpa.bin https://huggingface.co/uonlp/kenlm/resolve/main/wikipedia_20230501/is.arpa.bin +wget --header="Authorization: Bearer $(cat $HOME/.cache/huggingface/token)" -O kenlm/wikipedia/no.arpa.bin https://huggingface.co/uonlp/kenlm/resolve/main/wikipedia_20230501/no.arpa.bin +wget --header="Authorization: Bearer $(cat $HOME/.cache/huggingface/token)" -O kenlm/wikipedia/nn.arpa.bin https://huggingface.co/uonlp/kenlm/resolve/main/wikipedia_20230501/nn.arpa.bin +wget --header="Authorization: Bearer $(cat $HOME/.cache/huggingface/token)" -O kenlm/wikipedia/en.arpa.bin https://huggingface.co/uonlp/kenlm/resolve/main/wikipedia_20230501/en.arpa.bin + +mkdir spm/wikipedia +wget --header="Authorization: Bearer $(cat $HOME/.cache/huggingface/token)" -O spm/wikipedia/da.sp.model https://huggingface.co/uonlp/kenlm/resolve/main/wikipedia_20230501/da.sp.model +wget --header="Authorization: Bearer $(cat $HOME/.cache/huggingface/token)" -O spm/wikipedia/sv.sp.model https://huggingface.co/uonlp/kenlm/resolve/main/wikipedia_20230501/sv.sp.model +wget --header="Authorization: Bearer $(cat $HOME/.cache/huggingface/token)" -O spm/wikipedia/is.sp.model https://huggingface.co/uonlp/kenlm/resolve/main/wikipedia_20230501/is.sp.model +wget --header="Authorization: Bearer $(cat $HOME/.cache/huggingface/token)" -O spm/wikipedia/no.sp.model https://huggingface.co/uonlp/kenlm/resolve/main/wikipedia_20230501/no.sp.model +wget --header="Authorization: Bearer $(cat $HOME/.cache/huggingface/token)" -O spm/wikipedia/nn.sp.model https://huggingface.co/uonlp/kenlm/resolve/main/wikipedia_20230501/nn.sp.model +wget --header="Authorization: Bearer $(cat $HOME/.cache/huggingface/token)" -O spm/wikipedia/en.sp.model https://huggingface.co/uonlp/kenlm/resolve/main/wikipedia_20230501/en.sp.model +wget --header="Authorization: Bearer 
"""Plot perplexity histograms, KDE curves and quartile markers per document type.

Input files are JSON-lines; every record is expected to provide the keys
``doctype``, ``lang`` and ``perplexities`` (a mapping from score name such as
``wikipedia_pp`` to a number), because those are the fields read below.
"""
import argparse
import json
import os

import numpy as np
import pandas as pd
from scipy.stats import gaussian_kde

# Document-type prefix -> name of the perplexity score plotted for it.
_MODEL_BY_DOC_TYPE = {
    "book": "books_pp",
    "books": "books_pp",
    "pg19": "books_pp",
    "culturax": "wikipedia_pp",
    "slimpajama": "wikipedia_pp",
    "wikipedia": "wikipedia_pp",
    "digimanus": "wikipedia_pp",
    "newspaper": "newspapers_pp",
    "newspapers": "newspapers_pp",
    "evalueringsrapport": "maalfrid_pp",
    "lovdata": "maalfrid_pp",
    "maalfrid": "maalfrid_pp",
    "parlamint": "maalfrid_pp",
}
# Score used for any document type that is not listed above.
_DEFAULT_MODEL = "wikipedia_pp"


def get_model_for(doc_type: str, override_model: str) -> str:
    """Return the perplexity score name to plot for *doc_type*.

    Args:
        doc_type: Document type; only the part before the first ``_`` is used.
        override_model: When non-empty, returned as-is for every document type.

    Returns:
        The override, the mapped score name, or ``"wikipedia_pp"`` for
        unknown document types.
    """
    if override_model:
        return override_model
    prefix = doc_type.split("_", 1)[0]
    return _MODEL_BY_DOC_TYPE.get(prefix, _DEFAULT_MODEL)


def load_data(files):
    """Load every JSON-lines record from *files* into a single DataFrame.

    Files are streamed line by line (no full ``readlines()`` copy), decoded as
    UTF-8, and blank lines are skipped instead of raising ``JSONDecodeError``.

    Args:
        files: Iterable of paths to JSON-lines files.

    Returns:
        A ``pandas.DataFrame`` with one row per record, in file order.
    """
    records = []
    for file_path in files:
        with open(file_path, "r", encoding="utf-8") as handle:
            records.extend(json.loads(line) for line in handle if line.strip())
    return pd.DataFrame(records)


def _mark_quartiles(ax, values, color):
    """Draw dashed Q1/Q2/Q3 markers on *ax*, as tall as the KDE at each quartile."""
    values = np.asarray(values, dtype=float)
    # gaussian_kde needs at least two samples and a non-zero spread; skip the
    # markers for degenerate groups instead of aborting the whole figure.
    if values.size < 2 or np.ptp(values) == 0:
        return
    kde = gaussian_kde(values)
    quartiles = np.quantile(values, [0.25, 0.5, 0.75])
    # Evaluate the density exactly at the quartiles rather than at the nearest
    # point of a grid clipped to the x-axis limit.
    heights = kde.evaluate(quartiles)
    for label, quartile, height in zip(("Q1", "Q2", "Q3"), quartiles, heights):
        ax.plot([quartile, quartile], [0, height], color=color, linestyle='--', linewidth=1)
        ax.text(quartile, height, f'{label}: {quartile:.2f}', verticalalignment='bottom',
                horizontalalignment='right', color=color, fontsize=6)


def plot_histograms(files, output_folder, xlim, override_model):
    """Render one subplot per document type and save them as a single PNG.

    Each subplot overlays, per language, a density histogram, a KDE curve and
    dashed quartile markers of the selected perplexity score.

    Args:
        files: Paths of the JSON-lines score files to load.
        output_folder: Existing directory receiving ``all_doc_types_plots.png``.
        xlim: Right limit of the x axis.
        override_model: Score name forced for all plots; empty to pick per type.

    Raises:
        ValueError: If the input files contain no records.
    """
    # Imported here so the pure helpers above stay importable on machines
    # without the plotting stack installed.
    import matplotlib.pyplot as plt
    import seaborn as sns

    df = load_data(files)
    if df.empty:
        raise ValueError("No records found in the input files; nothing to plot.")

    doc_types = df['doctype'].unique()
    fig, axes = plt.subplots(len(doc_types), 1, figsize=(12, 4 * len(doc_types)), squeeze=False)

    for i, doc_type in enumerate(doc_types):
        ax = axes[i, 0]
        group = df[df['doctype'] == doc_type]
        languages = group['lang'].unique()
        # The score depends only on the document type, so resolve it once.
        perplexity_model = get_model_for(doc_type, override_model)

        # One distinct colour per language within the document type.
        colors = sns.color_palette("husl", len(languages))

        for series_color, lang in zip(colors, languages):
            lang_group = group[group['lang'] == lang]
            perplexity_values = lang_group['perplexities'].apply(lambda x: x[perplexity_model]).values

            # Translucent histogram underneath, unfilled KDE line on top.
            sns.histplot(perplexity_values, ax=ax, color=series_color, alpha=0.3, element="step",
                         fill=True, stat="density", binwidth=30)
            sns.kdeplot(perplexity_values, ax=ax, bw_adjust=2, color=series_color,
                        label=f"{lang} - {doc_type} ({perplexity_model})", linewidth=1.5)

            _mark_quartiles(ax, perplexity_values, series_color)

        ax.set_title(f'Document Type: {doc_type} ({perplexity_model})')
        ax.set_xlabel('Perplexity Value')
        ax.set_ylabel('Density')
        ax.legend()
        ax.set_xlim(left=0, right=xlim)

    plt.tight_layout()
    output_filename = os.path.join(output_folder, "all_doc_types_plots.png")
    plt.savefig(output_filename, dpi=300)
    plt.close(fig)
    print(f"All document type plots saved to {output_filename}")


def main():
    """Parse the command line and write the combined histogram figure."""
    parser = argparse.ArgumentParser(description="Plot histograms from JSON lines files.")
    parser.add_argument('files', nargs='+', help="Path to the JSON lines files")
    parser.add_argument('-o', '--output_folder', default=".", help="Output folder for the plots")
    parser.add_argument('--xlim', type=int, default=2500, help="Maximum x-axis limit for the plots")
    parser.add_argument('--model', default="", help="Override the perplexity model for all plots")

    args = parser.parse_args()

    # exist_ok already covers the "directory is present" case.
    os.makedirs(args.output_folder, exist_ok=True)

    plot_histograms(args.files, args.output_folder, args.xlim, args.model)


if __name__ == "__main__":
    main()
b/kenlm/books.norm.sp.arpa.bin new file mode 100644 index 0000000000000000000000000000000000000000..bb4e9e18b51c2b721d53061daaec76d689e0155a --- /dev/null +++ b/kenlm/books.norm.sp.arpa.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:582210ccef9a44feb2dde5029e3b02986ba3bb50d06152e2850a863fee8df16d +size 27269792294 diff --git a/kenlm/books.norm.sp.arpa.zip b/kenlm/books.norm.sp.arpa.zip new file mode 100644 index 0000000000000000000000000000000000000000..05a3b1c59c36b1225eb044e5b7868c4404250816 --- /dev/null +++ b/kenlm/books.norm.sp.arpa.zip @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c67bb924e8d2e0515037b1aca7c381267e7363d95ae4c5a773ae8517f9c34f81 +size 14081165146 diff --git a/kenlm/harmful/.keep b/kenlm/harmful/.keep new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/kenlm/maalfrid.norm.arpa b/kenlm/maalfrid.norm.arpa new file mode 100644 index 0000000000000000000000000000000000000000..3be72a73edd467b7777bd8bb42408754f0a3613e --- /dev/null +++ b/kenlm/maalfrid.norm.arpa @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9964b5a0a25e8d8f352bd85ee3de5cea80cd56cb033f4831c83e450ef42ee9b2 +size 14095675125 diff --git a/kenlm/maalfrid.norm.arpa.bin b/kenlm/maalfrid.norm.arpa.bin new file mode 100644 index 0000000000000000000000000000000000000000..ae65a1c9a5df17c840a7010183ac98e35c93be1e --- /dev/null +++ b/kenlm/maalfrid.norm.arpa.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4468f452cd224c25a7ab125f930692d415ba9a44564b6d8590ae60a697021ff8 +size 6334870758 diff --git a/kenlm/maalfrid.norm.sp.arpa b/kenlm/maalfrid.norm.sp.arpa new file mode 100644 index 0000000000000000000000000000000000000000..6e8a0026876c1a95788a8c33fe10ed1e9277a541 --- /dev/null +++ b/kenlm/maalfrid.norm.sp.arpa @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid 
sha256:12acfaf2360adec24e0456c0c9ab2a3199eda397dddb8c6b194ac7376d0811d5 +size 15096276243 diff --git a/kenlm/maalfrid.norm.sp.arpa.bin b/kenlm/maalfrid.norm.sp.arpa.bin new file mode 100644 index 0000000000000000000000000000000000000000..8466cd868b6d9662b46aa16dd60a2212f6277fcf --- /dev/null +++ b/kenlm/maalfrid.norm.sp.arpa.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:05f2b5ee9ad6f953bcfb6ed31584706225d8390275fb78b4848b1dd697fbedb6 +size 5938309481 diff --git a/kenlm/newspapers.norm.arpa b/kenlm/newspapers.norm.arpa new file mode 100644 index 0000000000000000000000000000000000000000..77816213c75b7daad86dc45dd4757f0337158b15 --- /dev/null +++ b/kenlm/newspapers.norm.arpa @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d87f6044f5f3b58b94c23e556ef2fef1f2f5cee4f27f0bd81293e6d6bb2579ff +size 2151432996 diff --git a/kenlm/newspapers.norm.arpa.bin b/kenlm/newspapers.norm.arpa.bin new file mode 100644 index 0000000000000000000000000000000000000000..6f8e48a63b793fe813d534afb009f8d0555c727f --- /dev/null +++ b/kenlm/newspapers.norm.arpa.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e63eef20ccd2a4977f1cd314e3d42ec3c04fe68ec5fb3a5ff37e2af64d966c9a +size 1095860943 diff --git a/kenlm/newspapers.norm.sp.arpa b/kenlm/newspapers.norm.sp.arpa new file mode 100644 index 0000000000000000000000000000000000000000..f339371a6ae43ff8c677a29dd600dfdf239993a7 --- /dev/null +++ b/kenlm/newspapers.norm.sp.arpa @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:65bb2007e807efcb548f51c18b9c7791606bd11807e292d250051efd4529ee7b +size 2660277943 diff --git a/kenlm/newspapers.norm.sp.arpa.bin b/kenlm/newspapers.norm.sp.arpa.bin new file mode 100644 index 0000000000000000000000000000000000000000..daa71a8aba09e9f5178ea5222637767f0ff2bb8d --- /dev/null +++ b/kenlm/newspapers.norm.sp.arpa.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid 
"""Line-level text normalization helpers and a small CLI to normalize a file.

The first part derives from Facebook's cc_net text normalizer (see the license
note below); the functions after the "START OF MIMIR CODE" marker were added
for Mimir.
"""
import argparse
import itertools
import re
import unicodedata

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#

PUNCTS = '!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~«»'

# Non-ASCII punctuation -> ASCII replacement.
# The fullwidth forms are written as escapes on purpose: when they are folded
# to their ASCII look-alikes the table maps "1" to '"' and "." to ". ", and the
# removal regexes below strip ordinary commas, periods and the digit 1.
# NOTE(review): if existing language models were trained on text normalized
# with an ASCII-keyed table, confirm the impact before re-scoring with this one.
UNICODE_PUNCT = {
    "\uff0c": ",",  # FULLWIDTH COMMA
    "。": ".",
    "、": ",",
    "„": '"',
    "”": '"',
    "“": '"',
    "«": '"',
    "»": '"',
    "\uff11": '"',  # FULLWIDTH DIGIT ONE (mapping kept as in the upstream table)
    "」": '"',
    "「": '"',
    "《": '"',
    "》": '"',
    "´": "'",
    "∶": ":",
    "\uff1a": ":",  # FULLWIDTH COLON
    "\uff1f": "?",  # FULLWIDTH QUESTION MARK
    "\uff01": "!",  # FULLWIDTH EXCLAMATION MARK
    "\uff08": "(",  # FULLWIDTH LEFT PARENTHESIS
    "\uff09": ")",  # FULLWIDTH RIGHT PARENTHESIS
    "\uff1b": ";",  # FULLWIDTH SEMICOLON
    "–": "-",
    "—": " - ",
    "\uff0e": ". ",  # FULLWIDTH FULL STOP
    "\uff5e": "~",  # FULLWIDTH TILDE
    "’": "'",
    "…": "...",
    "━": "-",
    "〈": "<",
    "〉": ">",
    "【": "[",
    "】": "]",
    "\uff05": "%",  # FULLWIDTH PERCENT SIGN
    "►": "-",
    "■": " ",  # added for Mimir
}

UNICODE_PUNCT_RE = re.compile(f"[{''.join(UNICODE_PUNCT.keys())}]")


def replace_unicode_punct(text: str) -> str:
    """Replace every character listed in UNICODE_PUNCT by its ASCII mapping."""
    return "".join(UNICODE_PUNCT.get(c, c) for c in text)


def remove_unicode_punct(text: str) -> str:
    """More aggressive version of replace_unicode_punct but also faster."""
    return UNICODE_PUNCT_RE.sub("", text)


def strip_accents(line: str) -> str:
    """Return *line* NFD-decomposed with all combining marks (category Mn) removed.

    The former early return compared ``len(output)`` with the string itself,
    so it could never trigger; the result is unchanged without it.
    """
    nfd = unicodedata.normalize("NFD", line)
    return "".join(c for c in nfd if unicodedata.category(c) != "Mn")


# Build a regex matching all control characters.
NON_PRINTING_CHARS_RE = re.compile(
    f"[{''.join(map(chr, list(range(0,32)) + list(range(127,160))))}]"
)
DIGIT_RE = re.compile(r"\d")
# Single character class = union of the two classes above ("][" joins them).
PUNCT_OR_NON_PRINTING_CHARS_RE = re.compile(
    (UNICODE_PUNCT_RE.pattern + NON_PRINTING_CHARS_RE.pattern).replace("][", "")
)


def remove_non_printing_char(text: str) -> str:
    """Delete the control characters U+0000-U+001F and U+007F-U+009F."""
    return NON_PRINTING_CHARS_RE.sub("", text)


def normalize(line: str, accent=True, case=True, numbers=True, punct=1) -> str:
    """Normalize a single line of text.

    Args:
        line: Text to normalize; surrounding whitespace is stripped first.
        accent: Strip combining accents.
        case: Lowercase the text.
        numbers: Replace every digit by ``0``.
        punct: ``1`` replaces UNICODE_PUNCT characters by their ASCII mapping,
            ``2`` removes them, any other value leaves them untouched.

    Returns:
        The normalized line with control characters removed; an empty or
        whitespace-only input yields ``""``.
    """
    line = line.strip()
    if not line:
        return line
    if case:
        line = line.lower()
    if accent:
        line = strip_accents(line)
    if numbers:
        line = DIGIT_RE.sub("0", line)
    if punct == 1:
        line = replace_unicode_punct(line)
    elif punct == 2:
        line = remove_unicode_punct(line)
    line = remove_non_printing_char(line)
    return line


def slow_normalize_for_dedup(line: str) -> str:
    """Reference implementation of normalize_for_dedup built on normalize()."""
    return normalize(line, accent=False, case=True, numbers=True, punct=2)


def normalize_for_dedup(line: str) -> str:
    """Lowercase, zero out digits and drop Unicode punctuation and control chars."""
    line = line.strip()
    if not line:
        return line
    # case
    line = line.lower()
    # numbers
    line = DIGIT_RE.sub("0", line)
    line = PUNCT_OR_NON_PRINTING_CHARS_RE.sub("", line)
    return line


## START OF MIMIR CODE
def normalize_text(line):
    """NFKC-normalize and lowercase *line*, ensure final punctuation, then normalize().

    A trailing ``.`` is appended when the right-stripped line is non-empty and
    does not already end with a character from PUNCTS. Accents are kept.
    """
    normalized_line = unicodedata.normalize('NFKC', line).lower()

    # Add a trailing dot if the line does not end with a punctuation mark
    normalized_line = normalized_line.rstrip()
    if normalized_line and normalized_line[-1] not in PUNCTS:
        normalized_line += '.'

    normalized_line = normalize(normalized_line, accent=False, case=True, numbers=True, punct=1)
    return normalized_line


def normalize_file(input_file, output_file, cutoff=None):
    """Write the normalized form of each line of *input_file* to *output_file*.

    Args:
        input_file: Path of the UTF-8 text file to read.
        output_file: Path of the UTF-8 text file to (over)write.
        cutoff: Maximum number of lines to process. ``None`` or ``0`` means no
            limit. Exactly ``cutoff`` lines are written (the limit used to be
            checked after writing, which let one extra line through).
    """
    # Imported here so the normalization helpers stay importable without tqdm.
    from tqdm import tqdm

    limit = cutoff if cutoff else None
    with open(output_file, 'w', encoding='utf-8') as f, \
            open(input_file, 'r', encoding='utf-8') as lines:
        for line in tqdm(itertools.islice(lines, limit), desc="Processing"):
            f.write(normalize_text(line) + "\n")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Normalize text file line by line, ensure trailing punctuation, replace newlines with spaces, and show progress.')
    parser.add_argument('input_file', type=str, help='Input file path')
    parser.add_argument('output_file', type=str, help='Output file path')
    parser.add_argument('--cutoff', required=False, type=int, help='Max number of lines to process')

    args = parser.parse_args()

    normalize_file(args.input_file, args.output_file, args.cutoff)
"plt.style.use('fivethirtyeight')\n", + "mpl.rcParams['figure.figsize'] = (14,8)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "931db8d9", + "metadata": {}, + "outputs": [], + "source": [ + "plt.rcParams.update({'font.size': 12})\n", + "\n", + "SMALL_SIZE = 16\n", + "MEDIUM_SIZE = 18\n", + "BIGGER_SIZE = 20\n", + "\n", + "plt.rc('font', size=SMALL_SIZE) # controls default text sizes\n", + "plt.rc('axes', titlesize=SMALL_SIZE) # fontsize of the axes title\n", + "plt.rc('axes', labelsize=MEDIUM_SIZE) # fontsize of the x and y labels\n", + "plt.rc('xtick', labelsize=SMALL_SIZE) # fontsize of the tick labels\n", + "plt.rc('ytick', labelsize=SMALL_SIZE) # fontsize of the tick labels\n", + "plt.rc('legend', fontsize=MEDIUM_SIZE) # legend fontsize\n", + "plt.rc('figure', titlesize=BIGGER_SIZE)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "33cfca52", + "metadata": {}, + "outputs": [], + "source": [ + "from typing import List, Tuple" + ] + }, + { + "cell_type": "markdown", + "id": "44b29018", + "metadata": {}, + "source": [ + "# TL;DR\n", + "\n", + "* **Objective**: we have a dataset with an arbitrary perplexity distribution. We want to subsample that dataset in a way that we \n", + " - achieve a predefined sampling ratio R \n", + " - increase the share of the dataset having central perplexity distributions.\n", + " We define \"central\" as the perplexities in the two middle quartiles of the original distribution, i.e. interval $[p_{25}, p_{75}]$. 
\n", + " - In concrete terms, given the perplexity values at these quartiles, $X_a$ and $X_b$, we will want to modify the share of those regions from 25% to other values $p_a$ and $p_b$\n", + " \n", + "* **Method**:\n", + " - compute $X_a$ and $X_b$, the perplexity values for $p_{25}$ and $p_{75}$\n", + " - define an initial Gaussian weighting curve as a Gaussian PDF having its $p_{25}$ and $p_{75}$ values in the same $X_a$ and $X_b$ positions as the computed ones\n", + " - compute a histogram of the perplexities\n", + " - use the histogram + initial Gaussian weights to estimate the sampling ratio that would result, and extract from it the normalization factor needed to achieve R\n", + " - modify the parameters of the initial Gaussian curve by minimizing the error on the desired probabilities $p_a$ and $p_b$\n", + " - subsample the dataset by comparing the perplexity of each sample against the modified normalized Gaussian curve to estimate the probability of retaining it\n", + " \n", + "A final class that implements this procedure is defined in the [PerplexitySubsampler](../subsampler.py) file and it is used in another, [self-contained notebook](gaussian_subsampling.ipynb). The **Development** section in this notebook details the process step by step." + ] + }, + { + "cell_type": "markdown", + "id": "8c9d6270", + "metadata": {}, + "source": [ + "# Development" + ] + }, + { + "cell_type": "markdown", + "id": "79cff218", + "metadata": {}, + "source": [ + "## 1. 
Data loading \n", + "\n", + "We start by loading the computed perplexity values of the dataset (in this case computed over a 50M random sample)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "visible-acceptance", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "50000000" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "data = []\n", + "with open(\"../scores/culturax_da.jsonl\") as f:\n", + " for line in f:\n", + " data.append(json.loads(line)[\"perplexities\"][\"wikipedia_pp\"])\n", + "\n", + "data = np.array(data)\n", + "len(data)" + ] + }, + { + "cell_type": "markdown", + "id": "998a6d64", + "metadata": {}, + "source": [ + "Compute quartiles" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "88d8046a", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([542497.86733512, 679167.90234057, 998401.07723076])" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "qr = np.quantile(data, [0.25, 0.50, 0.75])\n", + "qr" + ] + }, + { + "cell_type": "markdown", + "id": "62c3e90d", + "metadata": {}, + "source": [ + "Plot the distribution, together with its quartiles" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "407c631c", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "fig, ax = plt.subplots()\n", + "ax.hist(data, bins=1000, range=[0, qr[2]*3]);\n", + "#ax.set_title(\"perplexity for a random sample of mC4-es (P95 of 44M values)\");\n", + "#ax.get_yticklabels().set_fontsize(9)\n", + "for q in qr:\n", + " ax.axvline(q, c='r', lw=1)" + ] + }, + { + "cell_type": "markdown", + "id": "b6b9ffa6", + "metadata": {}, + "source": [ + "## 2. Data subsampling procedure\n", + "\n", + "### 2.1 Overall objective\n" + ] + }, + { + "cell_type": "markdown", + "id": "contained-archives", + "metadata": {}, + "source": [ + "We define three regions in the perplexity distribution:\n", + "1. Low perplexity: the region below the $p_{25}$ percentile\n", + "2. Mid perplexity: the central region between the $p_{25}$ percentile and the $p_{75}$ percentile\n", + "3. High perplexity: the region beyond the $p_{75}$ percentile\n", + "\n", + "The objective is to reshape the dataset so that when subsampled we transfer probability mass from regions [1] and [3] to region [2]\n" + ] + }, + { + "cell_type": "markdown", + "id": "reduced-accreditation", + "metadata": {}, + "source": [ + "We use then two points in the perplexity distribution:\n", + "* $X_a$ is the perplexity value for the $p_{25}$ percentile\n", + "* $X_b$ is the perplexity value for the $p_{75}$ percentile" + ] + }, + { + "cell_type": "markdown", + "id": "decreased-premises", + "metadata": {}, + "source": [ + "### 2.2 Gaussian weighting\n", + "\n", + "With this procedure, we will use a Gaussian curve to extract the weights for subsampling" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "third-backing", + "metadata": {}, + "outputs": [], + "source": [ + "from scipy.stats import norm, uniform" + ] + }, + { + "cell_type": "markdown", + "id": "reported-yemen", + "metadata": {}, + "source": [ + "We now design a normal distribution having a probability distribution so that \n", + "* a certain probability 
mass (which will be less than the original 25%) is below $X_a$\n", + "* a certain probability mass (which will be less than the original 25%) is beyond $X_b$\n", + "\n", + "By moving probability away from the original $p_{25}$ (i.e. $X_a$) and $p_{75}$ (i.e. $X_b$) perplexities, we are going to achieve our aim of reweighting the dataset decreasing the amount of low and high perplexities" + ] + }, + { + "cell_type": "markdown", + "id": "optional-postage", + "metadata": {}, + "source": [ + "#### parameters of the initial Gaussian curve" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "homeless-garbage", + "metadata": {}, + "outputs": [], + "source": [ + "pa = 0.15 # probability fraction we will want below Xa -- should be less than 0.25\n", + "pb = 0.10 # probability fraction over Xb -- should be less than 0.25" + ] + }, + { + "cell_type": "markdown", + "id": "healthy-species", + "metadata": {}, + "source": [ + "The way of computing the desired normal distribution is by inserting those two values into the Gaussian CDF formula, and deducing from the two equations the values of the Gaussian parameters $\\mu$ and $\\sigma$\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
\n", + " \n", + "$ \\phi ( \\frac{X_{a} − \\mu}{\\sigma} ) = p_a $\n", + "\n", + "\n", + "\n", + "$\\Rightarrow$\n", + "\n", + "\n", + "\n", + "$X_{a} − \\mu = \\sigma \\cdot \\phi^{-1}(p_a)$\n", + "\n", + "
\n", + "\n", + "$ \\phi ( \\frac{X_{b} − \\mu}{\\sigma} ) = 1 - p_b $\n", + "\n", + "\n", + " \n", + "$\\Rightarrow$\n", + "\n", + "\n", + " \n", + "$X_{b} − \\mu = \\sigma \\cdot \\phi^{-1}(1 - p_b)$\n", + "\n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "visible-purse", + "metadata": {}, + "outputs": [], + "source": [ + "# Standard deviation\n", + "sdev = (qr[0] - qr[2]) / (norm.ppf(pa) - norm.ppf(1-pb))\n", + "\n", + "# Mean\n", + "mean = qr[0] - norm.ppf(pa)*sdev" + ] + }, + { + "cell_type": "markdown", + "id": "stylish-oregon", + "metadata": {}, + "source": [ + "Let's plot the CDF for the normal distribution we have created" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "alternative-addiction", + "metadata": {}, + "outputs": [], + "source": [ + "x = np.linspace(0, qr[2]+qr[0], 5000)\n", + "y = norm.cdf(x, loc=mean, scale=sdev)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "suspended-worst", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "fig, ax = plt.subplots()\n", + "ax.plot(x, y);\n", + "for q, t, s in zip(qr, [\"$X_a$\", \"mean\", \"$X_b$\"], [\"-\", \":\", \"-\"]):\n", + " ax.axvline(q, c='r', lw=1, ls=s)\n", + " ax.text(q, -0.01, t, color=\"r\")" + ] + }, + { + "cell_type": "markdown", + "id": "ambient-orientation", + "metadata": {}, + "source": [ + "We can verify that effectively the aggregated probabilities for our extreme intervals $X_a$ and $X_b$ are as defined by $p_a$ and $p_b$" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "unable-shift", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(0.1499999999999999, 0.8999999999999999)" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Check the probabilities for the regions delimited by Xa & Xb\n", + "norm.cdf(qr[0], loc=mean, scale=sdev), norm.cdf(qr[2], loc=mean, scale=sdev)" + ] + }, + { + "cell_type": "markdown", + "id": "congressional-heating", + "metadata": {}, + "source": [ + "And its PDF is as follows:" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "f00fb4c7", + "metadata": {}, + "outputs": [], + "source": [ + "def plot_quartiles(ax: plt.Axes):\n", + " for q, s in zip(qr, [\"-\", \":\", \"-\"]):\n", + " ax.axvline(q, c='r', lw=1, ls=s)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "quality-sister", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "gauss_pdf = norm.pdf(x, loc=mean, scale=sdev)\n", + "\n", + "fig, ax = plt.subplots()\n", + "ax.plot(x, gauss_pdf);\n", + "plot_quartiles(ax)" + ] + }, + { + "cell_type": "markdown", + "id": "possible-handy", + "metadata": {}, + "source": [ + "As it should be, the area under the PDF curve is 1 (since the total probability mass must sum to 1)" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "green-student", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.9998996644500662" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "sum(gauss_pdf)*(x[1]-x[0])" + ] + }, + { + "cell_type": "markdown", + "id": "cefcfe07", + "metadata": {}, + "source": [ + "And the maximum value of the Gaussian curve is at the mean of the PDF" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "37b3500f", + "metadata": {}, + "outputs": [], + "source": [ + "pdf_max = np.max(gauss_pdf)" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "postal-connecticut", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(2.0283737486698773e-06, 2.0283739702357006e-06)" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pdf_max, norm.pdf(mean, loc=mean, scale=sdev)" + ] + }, + { + "cell_type": "markdown", + "id": "purple-aviation", + "metadata": {}, + "source": [ + "In our initial try, the Gaussian curve that we will use for weighted resampling is the PDF of the gaussian distribution just computed, but normalized so that its maximum (the mean of the Gaussian) equals 1. 
That means that items with that perplexity will be sampled with probability 1 (i.e. will _not_ be subsampled), and the weight will decrease gradually as we move away from that mean" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "peaceful-default", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "fig, ax = plt.subplots()\n", + "ax.plot(x, gauss_pdf/pdf_max);\n", + "plot_quartiles(ax)" + ] + }, + { + "cell_type": "markdown", + "id": "registered-clothing", + "metadata": {}, + "source": [ + "## 3. Apply subsampling to uniform data\n", + "\n", + "To test the process, we are going to apply it to data with uniform probability. We'll use as working interval (0, $X_{a} + X_{b}$). This interval is chosen so that the uniform data has originally a 25% probability mass below $X_a$ and a 25% probability mass above $X_b$, i.e. just as the original dataset did" + ] + }, + { + "cell_type": "markdown", + "id": "fdbfef47", + "metadata": {}, + "source": [ + "### 3.1 Generate uniform data" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "preceding-contest", + "metadata": {}, + "outputs": [], + "source": [ + "interval = [0, qr[2]+qr[0]]\n", + "width = interval[1] - interval[0]" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "legitimate-crown", + "metadata": {}, + "outputs": [], + "source": [ + "# we generate random data on the interval \n", + "data_unif = uniform.rvs(*interval, 300000)" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "indie-wheel", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Let's plot the distribution for the test data. Should be flat\n", + "fig, ax = plt.subplots()\n", + "ax.hist(data_unif, bins=300);\n", + "plot_quartiles(ax)" + ] + }, + { + "cell_type": "markdown", + "id": "99fbe6ad", + "metadata": {}, + "source": [ + "### 3.2 subsample\n", + "We define now the subsampling function with gaussian weighting" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "ready-seventh", + "metadata": {}, + "outputs": [], + "source": [ + "def subsample_gauss(data: np.ndarray, mean: float, sdev: float, norm_factor: float) -> np.ndarray:\n", + " \"\"\"\n", + " Vectorized subsampling: process the whole dataset\n", + " \"\"\"\n", + " # Create the gaussian weight for each data point\n", + " p = norm.pdf(data, loc=mean, scale=sdev)/norm_factor\n", + " #print(p)\n", + " # Subsample data with probability according to the weight\n", + " return data[ uniform.rvs(size=len(p)) < p ]" + ] + }, + { + "cell_type": "markdown", + "id": "hidden-shelf", + "metadata": {}, + "source": [ + "After resampling, at each perplexity value, the probability of retaining a sample will be the value of the normalized gaussian PDF curve.\n", + "\n", + "This also means that the overall sample ratio will be:\n", + "\n", + "
\n", + "
\n", + "\n", + "$$\\text{R} = \\int{p(v) \\cdot w(v) dv} = \n", + "\\int{ p(v) \\cdot \\frac{1}{\\text{pdf}_{max}} \\text{pdf}(v) dv } = \n", + "$$\n", + "\n", + "
\n", + "\n", + "$$ = \\int{ \\frac{1}{W} \\cdot \\frac{1}{\\text{pdf}_{max}} \\text{pdf}(v) dv } = \n", + "\\frac{1}{W\\cdot\\text{pdf}_{max}}\\int{\\text{pdf}(v) dv} \\approx\n", + "\\frac{1}{W\\cdot\\text{pdf}_{max}}$$\n", + "\n", + "
\n", + " \n", + "where $W$ is the interval width; we use the value of the uniform probability density ($1/W$) and the fact that the integral of the Gaussian PDF over the interval is approximately 1 (given that the interval is wide enough to contain most of the Gaussian's probability mass)\n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "russian-sodium", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.31994686644840487" + ] + }, + "execution_count": 23, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ratio = 1/(width*pdf_max)\n", + "ratio" + ] + }, + { + "cell_type": "markdown", + "id": "fdcae693", + "metadata": {}, + "source": [ + "Let's do the resampling:" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "random-champion", + "metadata": {}, + "outputs": [], + "source": [ + "data_unif_sub = subsample_gauss(data_unif, mean, sdev, pdf_max)" + ] + }, + { + "cell_type": "markdown", + "id": "vietnamese-carry", + "metadata": {}, + "source": [ + "### 3.3 check results" + ] + }, + { + "cell_type": "markdown", + "id": "quarterly-february", + "metadata": {}, + "source": [ + "Check the obtained sampling ratio; it should be aproximately equal to the estimation:" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "specialized-penny", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.3208933333333333" + ] + }, + "execution_count": 25, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "len(data_unif_sub)/len(data_unif)" + ] + }, + { + "cell_type": "markdown", + "id": "thirty-mexican", + "metadata": {}, + "source": [ + "Check the shares of the distribution below $X_a$ and above $X_b$, they should match our objectives" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "intensive-fitness", + "metadata": {}, + "outputs": [], + "source": [ + "def check_regions(data: np.ndarray):\n", + " ra = len(data[data < qr[0]]) / len(data)\n", + " print(\"Probability mass below Pa:\", ra)\n", + " rb = len(data[data > qr[2]]) / len(data)\n", + " print(\"Probability mass above Pb:\", rb)" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "advisory-certification", + "metadata": {}, + 
"outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Probability mass below Pa: 0.14987327045331783\n", + "Probability mass above Pb: 0.1009785183030706\n" + ] + } + ], + "source": [ + "check_regions(data_unif_sub)" + ] + }, + { + "cell_type": "markdown", + "id": "brown-perception", + "metadata": {}, + "source": [ + "We can plot the resulting data distribution" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "id": "painted-madonna", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "fig, ax = plt.subplots()\n", + "ax.hist(data_unif_sub, bins=200);\n", + "plot_quartiles(ax)" + ] + }, + { + "cell_type": "markdown", + "id": "bored-feedback", + "metadata": {}, + "source": [ + "### 3.4 Adjust sampling ratio\n", + "\n", + "The \"natural\" sampling ratio for this Gaussian weighting obtained is, as computed above, of 32%. We might want to achieve a different sampling ratio; to achieve that we can modify the normalization factor so that we lower the Gaussian curve to sample less data (but still keep the relative weights); in this case the peak of the Gaussian would not sample at 100%, but less than that." + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "id": "confused-external", + "metadata": {}, + "outputs": [], + "source": [ + "# Let's seek a sampling ratio of 20%\n", + "desired_sampling_ratio = 0.20\n", + "\n", + "# With this desired fraction, we compute the new normalization factor\n", + "unif_norm_factor = 1/(width*desired_sampling_ratio)" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "id": "controversial-caribbean", + "metadata": {}, + "outputs": [], + "source": [ + "data_unif_sub2 = subsample_gauss(data_unif, mean, sdev, unif_norm_factor)" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "id": "angry-stomach", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.20069666666666666" + ] + }, + "execution_count": 31, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Now we have our desired sampling ratio\n", + "len(data_unif_sub2)/len(data_unif)" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "id": "transparent-width", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Probability mass below Pa: 0.14976166353867362\n", + "Probability mass above Pb: 0.1011808865784185\n" + ] + } + ], + "source": [ + "# And 
the probability masses stay as before\n", + "check_regions(data_unif_sub2)" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "id": "ethical-resolution", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "fig, ax = plt.subplots()\n", + "ax.hist(data_unif_sub2, bins=200);\n", + "plot_quartiles(ax)" + ] + }, + { + "cell_type": "markdown", + "id": "b55ef486", + "metadata": {}, + "source": [ + "## 4. Subsample the original dataset\n", + "\n", + "### 4.1 Direct approach\n", + "\n", + "We now apply the same procedure to our original (non-uniform) dataset, using the same normalization factor we used for the uniform dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "id": "9d0201e0", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU times: user 3.64 s, sys: 57.9 s, total: 1min 1s\n", + "Wall time: 1min 1s\n" + ] + } + ], + "source": [ + "%%time\n", + "data_sub = subsample_gauss(data, mean, sdev, unif_norm_factor)" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "id": "d2affc12", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.32652108" + ] + }, + "execution_count": 35, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Compute the sampling ratio we have achieved\n", + "len(data_sub)/len(data)" + ] + }, + { + "cell_type": "markdown", + "id": "c89e184a", + "metadata": {}, + "source": [ + "We did not actually achieve the desired sampling ratio. 
Let's take a look at the resulting perplexity distribution, comparing it with the original" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "id": "7906df3c", + "metadata": {}, + "outputs": [], + "source": [ + "def plot_subsample(orig: np.ndarray, sub: np.ndarray, \n", + " name: str = 'Gaussian subsampling'):\n", + " fig, ax = plt.subplots()\n", + " if orig is not None:\n", + " ax.hist(orig, bins=1000, range=[0, qr[2]*3]);\n", + " ax.hist(sub, bins=1000, range=[0, qr[2]*3], color=\"g\");\n", + " if orig is not None:\n", + " ax.legend(['original', 'subsampled'])\n", + " plot_quartiles(ax)\n", + " ax.set_title(\"Perplexity distribution \" + (\"before and after \" if orig is not None else \"for \") + name);" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "id": "435111d2", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "plot_subsample(data, data_sub)" + ] + }, + { + "cell_type": "markdown", + "id": "683c0096", + "metadata": {}, + "source": [ + "### 4.2 Adjust sampling ratio\n", + "\n", + "In order to achieve the desired sampling ratio, we now muct take into account that now the data distribution is not uniform, hence the previous simple computation no longer applies. \n", + "\n", + "Instead, we need to compute the sampling ratio from the full integral. In order to approximate it, we compute an histogram of the data and appliximate the integral through a sum over the histogram bins" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "id": "d7a1f784", + "metadata": {}, + "outputs": [], + "source": [ + "# Compute a histogram for the dataset\n", + "hbins = 1000\n", + "range_max = qr[2]*10\n", + "hcounts, hedges = np.histogram(data, bins=hbins, range=[0, range_max])\n", + "hcounts[-1] += len(data[data>range_max])\n", + "\n", + "hperp = (hedges[:-1] + hedges[1:])/2" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "id": "0cbd8b75", + "metadata": {}, + "outputs": [], + "source": [ + "# Now let's compute the gaussian weighting function over the histogram bins\n", + "gauss_weights = norm.pdf(hperp, loc=mean, scale=sdev)" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "id": "66c43147", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.3265025569497173" + ] + }, + "execution_count": 40, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# And we estimate the sampling ratio by a sum\n", + "samples = (hcounts*gauss_weights/unif_norm_factor).sum()\n", + "\n", + "samples/hcounts.sum()" + ] + }, + { + "cell_type": "markdown", + "id": "e0b28370", + "metadata": {}, + "source": [ + "This theoretical number is a good approximation to the one obtained from the actual sample. 
So, we use this procedure in the reverse, to compute the normalization factor we need to achieve the desired sampling ratio" + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "id": "d3cf86c1", + "metadata": {}, + "outputs": [], + "source": [ + "adjusted_norm_factor = (hcounts*gauss_weights).sum()/hcounts.sum()/desired_sampling_ratio" + ] + }, + { + "cell_type": "code", + "execution_count": 42, + "id": "40684c79", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU times: user 3.35 s, sys: 38.4 s, total: 41.8 s\n", + "Wall time: 42 s\n" + ] + } + ], + "source": [ + "%%time\n", + "data_sub_adjusted = subsample_gauss(data, mean, sdev, adjusted_norm_factor)" + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "id": "85088557", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.19998344" + ] + }, + "execution_count": 43, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Obtained ratio\n", + "len(data_sub_adjusted)/len(data)" + ] + }, + { + "cell_type": "markdown", + "id": "3a56780e", + "metadata": {}, + "source": [ + "... the ratio now is close to the desired factor. The resulting distribution is as follows:" + ] + }, + { + "cell_type": "code", + "execution_count": 44, + "id": "b2e956d2", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "plot_subsample(data, data_sub_adjusted)" + ] + }, + { + "cell_type": "markdown", + "id": "a45660a5", + "metadata": {}, + "source": [ + "The probability masses, though, have not been adjusted to our original objectives:" + ] + }, + { + "cell_type": "code", + "execution_count": 45, + "id": "d44c0178", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Probability mass below Pa: 0.18007621031021367\n", + "Probability mass above Pb: 0.0318859401558449\n" + ] + } + ], + "source": [ + "check_regions(data_sub_adjusted)" + ] + }, + { + "cell_type": "code", + "execution_count": 46, + "id": "6c96ae05", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.09827000000000002" + ] + }, + "execution_count": 46, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "abs(pa-0.1801) + abs(pb-0.03183)" + ] + }, + { + "cell_type": "markdown", + "id": "39781e00", + "metadata": {}, + "source": [ + "The transfer of probability mass is, in fact, greater than oour objectives of retaining a 15% and 10%. \n", + "\n", + "This is to be expected, given the non-uniform probability of the data. So the simple global computation we did is no longer valid. 
If we want to evaluate the actual probability masses in the regions given by $X_a$ and $X_b$ we need to compute the actual integral, or approximate it by histogram sums" + ] + }, + { + "cell_type": "markdown", + "id": "e3e765f6", + "metadata": {}, + "source": [ + "## 4.3 Modification of the Gaussian weighting curve\n", + "\n", + "If we modify the `mean` and `sdev` Gaussian parameters from the ones estimated for uniform data probability, we can change the probability masses to be transferred:\n", + " * modifying `sdev` changes the results for both regions, transferring more or less data from them\n", + " * while moving the mean changes the relation asymmetrically, giving more weight to the first or to the fourth quartile\n", + "\n", + " \n", + "### 4.3.1 Subsample with different parameter combinations\n", + "\n", + "Let's try a grid of values around the initial values" + ] + }, + { + "cell_type": "code", + "execution_count": 47, + "id": "5c3778ee", + "metadata": {}, + "outputs": [], + "source": [ + "def compute_subsample_results(data, m: float, s: float) -> Tuple[float, float, float]:\n", + " # Gaussian weigthing\n", + " gauss_weights = norm.pdf(hperp, loc=m, scale=s)\n", + " adjusted_norm_factor = (hcounts*gauss_weights).sum()/hcounts.sum()/desired_sampling_ratio\n", + " # Subsample\n", + " data_sub = subsample_gauss(data, m, s, adjusted_norm_factor)\n", + " # Compute result metrics\n", + " ra = len(data_sub[data_sub < qr[0]]) / len(data_sub)\n", + " rb = len(data_sub[data_sub > qr[2]]) / len(data_sub)\n", + " return len(data_sub)/len(data), ra, rb" + ] + }, + { + "cell_type": "code", + "execution_count": 48, + "id": "a4e7921f", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU times: user 1min 32s, sys: 14min 21s, total: 15min 53s\n", + "Wall time: 16min\n" + ] + } + ], + "source": [ + "%%time\n", + "df = pd.DataFrame(([f1, f2, mean*f1, sdev*f2, *compute_subsample_results(data, mean*f1, sdev*f2)]\n", + " for f1 in 
np.arange(0.8, 1.3, 0.1) for f2 in np.arange(0.8, 1.3, 0.1)), \n", + " columns=[\"factor1\", \"factor2\", \"mean\", \"sdev\", \"ratio\", \"pa\", \"pb\"])" + ] + }, + { + "cell_type": "code", + "execution_count": 49, + "id": "9778ae79", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
factor1factor2meansdevratiopapb
00.80.8597075.559089157344.6657290.1999870.3273610.001474
10.80.9597075.559089177012.7489460.2000600.3309270.003349
20.81.0597075.559089196680.8321620.2001150.3326190.006112
30.81.1597075.559089216348.9153780.2000930.3327100.009676
40.81.2597075.559089236016.9985940.2000260.3319740.013916
50.90.8671710.003975157344.6657290.1999820.2192790.005382
60.90.9671710.003975177012.7489460.2000270.2382700.009421
70.91.0671710.003975196680.8321620.1999960.2526140.014298
80.91.1671710.003975216348.9153780.2000720.2629470.019859
90.91.2671710.003975236016.9985940.1999870.2705640.025817
101.00.8746344.448861157344.6657290.2001250.1309400.017452
111.00.9746344.448861177012.7489460.1999840.1578200.024555
121.01.0746344.448861196680.8321620.2000450.1801630.031770
131.01.1746344.448861216348.9153780.2000710.1981400.039049
141.01.2746344.448861236016.9985940.1999730.2122290.046073
151.10.8820978.893747157344.6657290.1999790.0675920.049761
161.10.9820978.893747177012.7489460.1999390.0941720.058519
171.11.0820978.893747196680.8321620.2000170.1189650.066225
181.11.1820978.893747216348.9153780.1999600.1405010.073169
191.11.2820978.893747236016.9985940.1999180.1592360.079432
201.20.8895613.338633157344.6657290.1999710.0291370.121411
211.20.9895613.338633177012.7489460.1999140.0494250.125282
221.21.0895613.338633196680.8321620.1999610.0714710.127871
231.21.1895613.338633216348.9153780.1999230.0929870.129851
241.21.2895613.338633236016.9985940.2000010.1132400.131530
\n", + "
" + ], + "text/plain": [ + " factor1 factor2 mean sdev ratio pa \\\n", + "0 0.8 0.8 597075.559089 157344.665729 0.199987 0.327361 \n", + "1 0.8 0.9 597075.559089 177012.748946 0.200060 0.330927 \n", + "2 0.8 1.0 597075.559089 196680.832162 0.200115 0.332619 \n", + "3 0.8 1.1 597075.559089 216348.915378 0.200093 0.332710 \n", + "4 0.8 1.2 597075.559089 236016.998594 0.200026 0.331974 \n", + "5 0.9 0.8 671710.003975 157344.665729 0.199982 0.219279 \n", + "6 0.9 0.9 671710.003975 177012.748946 0.200027 0.238270 \n", + "7 0.9 1.0 671710.003975 196680.832162 0.199996 0.252614 \n", + "8 0.9 1.1 671710.003975 216348.915378 0.200072 0.262947 \n", + "9 0.9 1.2 671710.003975 236016.998594 0.199987 0.270564 \n", + "10 1.0 0.8 746344.448861 157344.665729 0.200125 0.130940 \n", + "11 1.0 0.9 746344.448861 177012.748946 0.199984 0.157820 \n", + "12 1.0 1.0 746344.448861 196680.832162 0.200045 0.180163 \n", + "13 1.0 1.1 746344.448861 216348.915378 0.200071 0.198140 \n", + "14 1.0 1.2 746344.448861 236016.998594 0.199973 0.212229 \n", + "15 1.1 0.8 820978.893747 157344.665729 0.199979 0.067592 \n", + "16 1.1 0.9 820978.893747 177012.748946 0.199939 0.094172 \n", + "17 1.1 1.0 820978.893747 196680.832162 0.200017 0.118965 \n", + "18 1.1 1.1 820978.893747 216348.915378 0.199960 0.140501 \n", + "19 1.1 1.2 820978.893747 236016.998594 0.199918 0.159236 \n", + "20 1.2 0.8 895613.338633 157344.665729 0.199971 0.029137 \n", + "21 1.2 0.9 895613.338633 177012.748946 0.199914 0.049425 \n", + "22 1.2 1.0 895613.338633 196680.832162 0.199961 0.071471 \n", + "23 1.2 1.1 895613.338633 216348.915378 0.199923 0.092987 \n", + "24 1.2 1.2 895613.338633 236016.998594 0.200001 0.113240 \n", + "\n", + " pb \n", + "0 0.001474 \n", + "1 0.003349 \n", + "2 0.006112 \n", + "3 0.009676 \n", + "4 0.013916 \n", + "5 0.005382 \n", + "6 0.009421 \n", + "7 0.014298 \n", + "8 0.019859 \n", + "9 0.025817 \n", + "10 0.017452 \n", + "11 0.024555 \n", + "12 0.031770 \n", + "13 0.039049 \n", + "14 0.046073 \n", + 
"15 0.049761 \n", + "16 0.058519 \n", + "17 0.066225 \n", + "18 0.073169 \n", + "19 0.079432 \n", + "20 0.121411 \n", + "21 0.125282 \n", + "22 0.127871 \n", + "23 0.129851 \n", + "24 0.131530 " + ] + }, + "execution_count": 49, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df" + ] + }, + { + "cell_type": "markdown", + "id": "3f1e0585", + "metadata": {}, + "source": [ + "As it can be seen, modifying the parameters of the Gaussian curve *does* change the obtained values of $p_a$ and $p_b$" + ] + }, + { + "cell_type": "markdown", + "id": "7d9f2276", + "metadata": {}, + "source": [ + "### 4.3.2 Use an estimation\n", + "\n", + "Extracting complete subsamples to analyze results is, as we have seen, too computationally intensive. Let's do an approximation instead, by estimating the results on a perplexity histogram" + ] + }, + { + "cell_type": "code", + "execution_count": 50, + "id": "74d58fb9", + "metadata": {}, + "outputs": [], + "source": [ + "def subsample_gauss_histo(hperp: np.ndarray, hcounts: np.ndarray, \n", + " mean: float, sdev: float, norm_factor: float) -> np.ndarray:\n", + " \"\"\"\n", + " Execute the subsampling on a histogram\n", + " \"\"\"\n", + " # Create the gaussian weight for each data point\n", + " p = norm.pdf(hperp, loc=mean, scale=sdev)/norm_factor\n", + " # Subsample data with probability according to each weight\n", + " return hcounts*p\n", + "\n", + "def compute_percentile_estimation(hedges: np.ndarray, hcounts: np.ndarray, perp_value) -> float:\n", + " \"\"\"\n", + " Estimate the percentile reached by a given perplexity value\n", + " \"\"\"\n", + " v = np.searchsorted(hedges, perp_value, side=\"right\")\n", + " return hcounts[:v-1].sum() + hcounts[v-1]*(perp_value - hedges[v-1])/(hedges[v] - hedges[v-1])\n", + "\n", + "def estimate_subsample_results(m: float, s: float) -> Tuple[float, float, float]:\n", + " # Gaussian weight on perplexity values\n", + " gauss_weights = norm.pdf(hperp, loc=m, scale=s)\n", + " # 
Normalization factor needed for the desired ratio\n", + " adjusted_norm_factor = (hcounts*gauss_weights).sum()/hcounts.sum()/desired_sampling_ratio\n", + " # Subsample the histogram\n", + " hcounts_sub = subsample_gauss_histo(hperp, hcounts, m, s, adjusted_norm_factor)\n", + " sub_size = hcounts_sub.sum()\n", + " # Compute the results of the subsampling\n", + " ratio = sub_size/hcounts.sum()\n", + " ra = compute_percentile_estimation(hedges, hcounts_sub, qr[0])/sub_size\n", + " rb = compute_percentile_estimation(hedges, hcounts_sub, qr[2])/sub_size\n", + " return ratio, ra, 1-rb" + ] + }, + { + "cell_type": "markdown", + "id": "006c4866", + "metadata": {}, + "source": [ + "Let's repeat the same process, but now with the estimator. It should be much faster" + ] + }, + { + "cell_type": "code", + "execution_count": 51, + "id": "1298ea05", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU times: user 12.6 ms, sys: 1.81 ms, total: 14.4 ms\n", + "Wall time: 19.8 ms\n" + ] + } + ], + "source": [ + "%%time\n", + "\n", + "result = ([f1, f2, mean*f1, sdev*f2, *estimate_subsample_results(mean*f1, sdev*f2)]\n", + " for f1 in np.arange(0.8, 1.3, 0.1) for f2 in np.arange(0.8, 1.3, 0.1))\n", + "\n", + "df_est = pd.DataFrame(result, columns=[\"factor1\", \"factor2\", \"mean\", \"sdev\", \"ratio\", \"pa\", \"pb\"])" + ] + }, + { + "cell_type": "markdown", + "id": "ebdceefe", + "metadata": {}, + "source": [ + "We can compare the actual quartile results with the ones we have estimated" + ] + }, + { + "cell_type": "code", + "execution_count": 52, + "id": "5c811407", + "metadata": {}, + "outputs": [], + "source": [ + "df['pa_est'] = df_est['pa']\n", + "df['pb_est'] = df_est['pb']\n", + "df['err_pa'] = df.pa - df.pa_est\n", + "df['err_pb'] = df.pb - df.pb_est" + ] + }, + { + "cell_type": "code", + "execution_count": 53, + "id": "550a7351", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
factor1factor2meansdevratiopapbpa_estpb_esterr_paerr_pb
00.80.8597075.559089157344.6657290.1999870.3273610.0014740.3273430.0014720.0000180.000002
10.80.9597075.559089177012.7489460.2000600.3309270.0033490.3310550.003324-0.0001280.000025
20.81.0597075.559089196680.8321620.2001150.3326190.0061120.3325720.0060820.0000470.000029
30.81.1597075.559089216348.9153780.2000930.3327100.0096760.3326950.0096740.0000140.000003
40.81.2597075.559089236016.9985940.2000260.3319740.0139160.3319330.0139560.000041-0.000040
50.90.8671710.003975157344.6657290.1999820.2192790.0053820.2194740.005363-0.0001950.000019
60.90.9671710.003975177012.7489460.2000270.2382700.0094210.2384660.009394-0.0001970.000027
70.91.0671710.003975196680.8321620.1999960.2526140.0142980.2524920.0143010.000123-0.000003
80.91.1671710.003975216348.9153780.2000720.2629470.0198590.2628600.0198160.0000870.000043
90.91.2671710.003975236016.9985940.1999870.2705640.0258170.2705350.0257110.0000290.000106
101.00.8746344.448861157344.6657290.2001250.1309400.0174520.1310280.017454-0.000088-0.000002
111.00.9746344.448861177012.7489460.1999840.1578200.0245550.1579820.024564-0.000162-0.000010
121.01.0746344.448861196680.8321620.2000450.1801630.0317700.1801180.0318180.000045-0.000048
131.01.1746344.448861216348.9153780.2000710.1981400.0390490.1980260.0389970.0001140.000052
141.01.2746344.448861236016.9985940.1999730.2122290.0460730.2124350.045980-0.0002060.000092
151.10.8820978.893747157344.6657290.1999790.0675920.0497610.0676930.049718-0.0001020.000043
161.10.9820978.893747177012.7489460.1999390.0941720.0585190.0943240.058571-0.000152-0.000052
171.11.0820978.893747196680.8321620.2000170.1189650.0662250.1189720.066312-0.000007-0.000087
181.11.1820978.893747216348.9153780.1999600.1405010.0731690.1407230.073189-0.000221-0.000020
191.11.2820978.893747236016.9985940.1999180.1592360.0794320.1594270.079397-0.0001910.000034
201.20.8895613.338633157344.6657290.1999710.0291370.1214110.0293310.121336-0.0001940.000075
211.20.9895613.338633177012.7489460.1999140.0494250.1252820.0495780.125308-0.000154-0.000026
221.21.0895613.338633196680.8321620.1999610.0714710.1278710.0715050.127902-0.000034-0.000030
231.21.1895613.338633216348.9153780.1999230.0929870.1298510.0931080.129839-0.0001210.000013
241.21.2895613.338633236016.9985940.2000010.1132400.1315300.1132300.1315090.0000100.000021
\n", + "
" + ], + "text/plain": [ + " factor1 factor2 mean sdev ratio pa \\\n", + "0 0.8 0.8 597075.559089 157344.665729 0.199987 0.327361 \n", + "1 0.8 0.9 597075.559089 177012.748946 0.200060 0.330927 \n", + "2 0.8 1.0 597075.559089 196680.832162 0.200115 0.332619 \n", + "3 0.8 1.1 597075.559089 216348.915378 0.200093 0.332710 \n", + "4 0.8 1.2 597075.559089 236016.998594 0.200026 0.331974 \n", + "5 0.9 0.8 671710.003975 157344.665729 0.199982 0.219279 \n", + "6 0.9 0.9 671710.003975 177012.748946 0.200027 0.238270 \n", + "7 0.9 1.0 671710.003975 196680.832162 0.199996 0.252614 \n", + "8 0.9 1.1 671710.003975 216348.915378 0.200072 0.262947 \n", + "9 0.9 1.2 671710.003975 236016.998594 0.199987 0.270564 \n", + "10 1.0 0.8 746344.448861 157344.665729 0.200125 0.130940 \n", + "11 1.0 0.9 746344.448861 177012.748946 0.199984 0.157820 \n", + "12 1.0 1.0 746344.448861 196680.832162 0.200045 0.180163 \n", + "13 1.0 1.1 746344.448861 216348.915378 0.200071 0.198140 \n", + "14 1.0 1.2 746344.448861 236016.998594 0.199973 0.212229 \n", + "15 1.1 0.8 820978.893747 157344.665729 0.199979 0.067592 \n", + "16 1.1 0.9 820978.893747 177012.748946 0.199939 0.094172 \n", + "17 1.1 1.0 820978.893747 196680.832162 0.200017 0.118965 \n", + "18 1.1 1.1 820978.893747 216348.915378 0.199960 0.140501 \n", + "19 1.1 1.2 820978.893747 236016.998594 0.199918 0.159236 \n", + "20 1.2 0.8 895613.338633 157344.665729 0.199971 0.029137 \n", + "21 1.2 0.9 895613.338633 177012.748946 0.199914 0.049425 \n", + "22 1.2 1.0 895613.338633 196680.832162 0.199961 0.071471 \n", + "23 1.2 1.1 895613.338633 216348.915378 0.199923 0.092987 \n", + "24 1.2 1.2 895613.338633 236016.998594 0.200001 0.113240 \n", + "\n", + " pb pa_est pb_est err_pa err_pb \n", + "0 0.001474 0.327343 0.001472 0.000018 0.000002 \n", + "1 0.003349 0.331055 0.003324 -0.000128 0.000025 \n", + "2 0.006112 0.332572 0.006082 0.000047 0.000029 \n", + "3 0.009676 0.332695 0.009674 0.000014 0.000003 \n", + "4 0.013916 0.331933 0.013956 0.000041 
-0.000040 \n", + "5 0.005382 0.219474 0.005363 -0.000195 0.000019 \n", + "6 0.009421 0.238466 0.009394 -0.000197 0.000027 \n", + "7 0.014298 0.252492 0.014301 0.000123 -0.000003 \n", + "8 0.019859 0.262860 0.019816 0.000087 0.000043 \n", + "9 0.025817 0.270535 0.025711 0.000029 0.000106 \n", + "10 0.017452 0.131028 0.017454 -0.000088 -0.000002 \n", + "11 0.024555 0.157982 0.024564 -0.000162 -0.000010 \n", + "12 0.031770 0.180118 0.031818 0.000045 -0.000048 \n", + "13 0.039049 0.198026 0.038997 0.000114 0.000052 \n", + "14 0.046073 0.212435 0.045980 -0.000206 0.000092 \n", + "15 0.049761 0.067693 0.049718 -0.000102 0.000043 \n", + "16 0.058519 0.094324 0.058571 -0.000152 -0.000052 \n", + "17 0.066225 0.118972 0.066312 -0.000007 -0.000087 \n", + "18 0.073169 0.140723 0.073189 -0.000221 -0.000020 \n", + "19 0.079432 0.159427 0.079397 -0.000191 0.000034 \n", + "20 0.121411 0.029331 0.121336 -0.000194 0.000075 \n", + "21 0.125282 0.049578 0.125308 -0.000154 -0.000026 \n", + "22 0.127871 0.071505 0.127902 -0.000034 -0.000030 \n", + "23 0.129851 0.093108 0.129839 -0.000121 0.000013 \n", + "24 0.131530 0.113230 0.131509 0.000010 0.000021 " + ] + }, + "execution_count": 53, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df" + ] + }, + { + "cell_type": "markdown", + "id": "cfe2ca39", + "metadata": {}, + "source": [ + "As we can see, the estimation produces the same results as the full computation. 
So we can use it safely" + ] + }, + { + "cell_type": "markdown", + "id": "bb325b33", + "metadata": {}, + "source": [ + "### 4.3.3 Optimization\n", + "\n", + "Now that we can produce a fast estimation of the results, we can try to optimize the arguments for the Gaussian curve by minimizing the error on the desired probability masses" + ] + }, + { + "cell_type": "code", + "execution_count": 54, + "id": "bec2f2c6", + "metadata": {}, + "outputs": [], + "source": [ + "import scipy as sp" + ] + }, + { + "cell_type": "code", + "execution_count": 55, + "id": "75ba883b", + "metadata": {}, + "outputs": [], + "source": [ + "def error(point, pa, pb):\n", + " \"\"\"\n", + " Compute the estimation error to minimize\n", + " \"\"\"\n", + " _, actual_pa, actual_pb = estimate_subsample_results(point[0], point[1])\n", + " return abs(pa-actual_pa) + abs(pb-actual_pb)" + ] + }, + { + "cell_type": "code", + "execution_count": 56, + "id": "4c7f409a", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Optimization terminated successfully.\n", + " Current function value: 0.000000\n", + " Iterations: 104\n", + " Function evaluations: 199\n" + ] + } + ], + "source": [ + "initial = np.array([mean, sdev])\n", + "r = sp.optimize.minimize(error, initial, args=(pa, pb), \n", + " method='nelder-mead', options={'xatol': 1e-8, 'disp': True})" + ] + }, + { + "cell_type": "markdown", + "id": "cfc4442c", + "metadata": {}, + "source": [ + "Minimization was successful, so we can use now the obtained `mean` and `sdev` values" + ] + }, + { + "cell_type": "code", + "execution_count": 57, + "id": "2372db3b", + "metadata": {}, + "outputs": [], + "source": [ + "def adjusted_gaussian_subsample(data: np.ndarray, mean: float, sdev: float, \n", + " ratio: float, hbins: int = 1000) -> np.ndarray:\n", + " \"\"\"\n", + " Subsample a dataset ensuring a given sampling ratio\n", + " \"\"\"\n", + " # Compute a histogram for this dataset\n", + " hcounts, hedges = 
np.histogram(data, bins=hbins, range=[0, qr[2]*10])\n", + " hperp = (hedges[:-1] + hedges[1:])/2\n", + " # Compute the weighting function over the histogram values\n", + " gauss_weights = norm.pdf(hperp, loc=mean, scale=sdev)\n", + " # Normalize the weights to achieve the desired sampling ratio\n", + " adjusted_norm_factor = (hcounts*gauss_weights).sum()/hcounts.sum()/ratio\n", + " # Subsample\n", + " return subsample_gauss(data, mean, sdev, adjusted_norm_factor)" + ] + }, + { + "cell_type": "code", + "execution_count": 58, + "id": "6813dc15", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU times: user 4.43 s, sys: 27.8 s, total: 32.2 s\n", + "Wall time: 32.4 s\n" + ] + } + ], + "source": [ + "%%time\n", + "data_sub_adj2 = adjusted_gaussian_subsample(data, r.x[0], r.x[1], desired_sampling_ratio)" + ] + }, + { + "cell_type": "markdown", + "id": "4d184d7e", + "metadata": {}, + "source": [ + "Let's check we *did* obtain the desired results and plot the distributions" + ] + }, + { + "cell_type": "code", + "execution_count": 59, + "id": "0d1ed557", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Ratio: 0.19997422\n", + "Probability mass below Pa: 0.14992142487166596\n", + "Probability mass above Pb: 0.10007179925492396\n" + ] + } + ], + "source": [ + "print(\"Ratio:\", len(data_sub_adj2)/len(data))\n", + "check_regions(data_sub_adj2)" + ] + }, + { + "cell_type": "code", + "execution_count": 60, + "id": "3d153b4a", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "plot_subsample(data, data_sub_adj2)" + ] + }, + { + "cell_type": "markdown", + "id": "af8c5b32", + "metadata": {}, + "source": [ + "To take a closer look, we show only the distribution for the subsampled data" + ] + }, + { + "cell_type": "code", + "execution_count": 61, + "id": "49a77f85", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "plot_subsample(None, data_sub_adj2)" + ] + }, + { + "cell_type": "markdown", + "id": "d287335e", + "metadata": {}, + "source": [ + "## 5. Annex: piecewise sampling\n", + "\n", + "A different, and more direct approach, is to subsample with different ratios on the three regions" + ] + }, + { + "cell_type": "code", + "execution_count": 62, + "id": "0d57d187", + "metadata": {}, + "outputs": [], + "source": [ + "def subsample_frac(data: np.ndarray, frac: float) -> np.ndarray:\n", + " return data[uniform.rvs(size=len(data)) < frac]\n", + "\n", + "def subsample_byregion(data: np.ndarray, frac: float, q1: float, q2: float) -> np.ndarray:\n", + " data1 = subsample_frac(data[data<qr[0]], q1/0.25*frac)\n", + " data2 = subsample_frac(data[(data>=qr[0]) & (data<=qr[2])], (1-q1-q2)/0.5*frac)\n", + " data3 = subsample_frac(data[data>qr[2]], q2/0.25*frac)\n", + " return np.hstack([data1, data2, data3])" + ] + }, + { + "cell_type": "code", + "execution_count": 63, + "id": "1cb0056e", + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU times: user 2.15 s, sys: 5.13 s, total: 7.28 s\n", + "Wall time: 7.29 s\n" + ] + } + ], + "source": [ + "%%time\n", + "data_piece = subsample_byregion(data, 0.20, 0.15, 0.10)" + ] + }, + { + "cell_type": "code", + "execution_count": 64, + "id": "3cd4b533", + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "data": { + "text/plain": [ + "0.20003226" + ] + }, + "execution_count": 64, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "len(data_piece)/len(data)" + ] + }, + { + "cell_type": "code", + "execution_count": 65, + "id": "1ec4f255", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Probability mass below Pa: 0.1500345994191137\n", + "Probability mass above Pb: 0.10010635284528606\n" + ] + } + ], + "source": [ + "check_regions(data_piece)" + ] + }, + { + "cell_type": 
"markdown", + "id": "df483672", + "metadata": {}, + "source": [ + "This method also obtains the desired values. However, the resulting distribution is quite different:" + ] + }, + { + "cell_type": "code", + "execution_count": 66, + "id": "f2231d55", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "plot_subsample(data, data_piece, 'piecewise subsampling')" + ] + }, + { + "cell_type": "markdown", + "id": "4692bdb9", + "metadata": {}, + "source": [ + "The main differences with Gaussian subsampling are:\n", + " * There are discontinuities at the region boundaries, so that perplexity values at both sides of the boundary get quite different sampling ratios. We lose the gradual transitions of Gaussian sampling\n", + " * Inside each region the pattern is roughly equivalent to the original shape, though obviously subsampled. This means that we keep the long tail behaviour of the original distribution. While for Gaussian weighting the shape of the weighting function cuts down the long tails, concentrating the distribution near the quartile boundaries" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/notebooks/gaussian_subsampling.ipynb b/notebooks/gaussian_subsampling.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..d7752500b3dea97b9da3588cba0b6d9262a5474c --- /dev/null +++ b/notebooks/gaussian_subsampling.ipynb @@ -0,0 +1,809 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "aef779f2-7e26-45d7-8ec2-593a571344aa", + "metadata": {}, + "source": [ + "# Perplexity-based subsampling of a dataset\n" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "44566054-ef4c-4dee-afac-1f5affd8e86f", + "metadata": {}, + "outputs": [], + "source": [ + "%matplotlib inline\n", + "\n", + "import lzma\n", + "import json\n", + "import tarfile\n", + "\n", + 
"import numpy as np\n", + "import pandas as pd\n", + "\n", + "import matplotlib as mpl\n", + "import matplotlib.pyplot as plt\n", + "import seaborn as sns\n", + "\n", + "#plt.style.use('ggplot')\n", + "#plt.style.use('bmh')\n", + "plt.style.use('fivethirtyeight')\n", + "mpl.rcParams['figure.figsize'] = (14,8)" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "1bf78228-6bf8-4148-9de3-9f4eeff3ccc9", + "metadata": {}, + "outputs": [], + "source": [ + "plt.rcParams.update({'font.size': 12})\n", + "\n", + "SMALL_SIZE = 16\n", + "MEDIUM_SIZE = 18\n", + "BIGGER_SIZE = 20\n", + "\n", + "plt.rc('font', size=SMALL_SIZE) # controls default text sizes\n", + "plt.rc('axes', titlesize=SMALL_SIZE) # fontsize of the axes title\n", + "plt.rc('axes', labelsize=MEDIUM_SIZE) # fontsize of the x and y labels\n", + "plt.rc('xtick', labelsize=SMALL_SIZE) # fontsize of the tick labels\n", + "plt.rc('ytick', labelsize=SMALL_SIZE) # fontsize of the tick labels\n", + "plt.rc('legend', fontsize=MEDIUM_SIZE) # legend fontsize\n", + "plt.rc('figure', titlesize=BIGGER_SIZE)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "00cc7e68-df5e-48ff-9897-7bc8646db1eb", + "metadata": {}, + "outputs": [], + "source": [ + "from typing import List, Tuple" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "9fe94ce2-53fc-4271-b664-e7b0ed5dbc17", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "1000000" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "data = []\n", + "with open(\"../scores/culturax_da.jsonl\") as f:\n", + " for line in f:\n", + " data.append(json.loads(line)[\"perplexities\"][\"wikipedia_pp\"])\n", + "\n", + "data = np.array(data)\n", + "len(data)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "79e4d349-016f-4fac-95c8-be11ea11585c", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([397.975, 503.3 , 648.2 
])" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "qr = np.quantile(data, [0.25, 0.50, 0.75])\n", + "qr\n" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "0c38e201-dcd6-4326-843b-52200c974312", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "fig, ax = plt.subplots()\n", + "ax.hist(data, bins=1000, range=[0, qr[2]*3]);\n", + "#ax.set_title(\"perplexity for a random sample of mC4-es (P95 of 44M values)\");\n", + "#ax.get_yticklabels().set_fontsize(9)\n", + "for q in qr:\n", + " ax.axvline(q, c='r', lw=1)\n" + ] + }, + { + "cell_type": "markdown", + "id": "6756a67c-538e-45db-be9e-729b4d2a3e8d", + "metadata": {}, + "source": [ + "## Testing gaussian sampling" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "ec9c4d5f-c676-46f5-8a3b-4d3f3773ef7d", + "metadata": {}, + "outputs": [], + "source": [ + "from scipy.stats import norm, uniform" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "9b8c74ca-9343-4779-b058-2006eee7e999", + "metadata": {}, + "outputs": [], + "source": [ + "pa = 0.15 # probability fraction we will want below Xa -- should be less than 0.25\n", + "pb = 0.10 # probability fraction over Xb -- should be less than 0.25" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "09c0c2ee-6875-4d49-8475-a96ff27de0d1", + "metadata": {}, + "outputs": [], + "source": [ + "# Standard deviation\n", + "sdev = (qr[0] - qr[2]) / (norm.ppf(pa) - norm.ppf(1-pb))\n", + "\n", + "# Mean\n", + "mean = qr[0] - norm.ppf(pa)*sdev\n" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "e5e60e04-d5f1-4d4e-a1e5-598eb58b1b91", + "metadata": {}, + "outputs": [], + "source": [ + "x = np.linspace(0, qr[2]+qr[0], 5000)\n", + "y = norm.cdf(x, loc=mean, scale=sdev)" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "e78dc8bc-8ff1-451f-a41d-290c4074cef7", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "fig, ax = plt.subplots()\n", + "ax.plot(x, y);\n", + "for q, t, s in zip(qr, [\"$X_a$\", \"mean\", \"$X_b$\"], [\"-\", \":\", \"-\"]):\n", + " ax.axvline(q, c='r', lw=1, ls=s)\n", + " ax.text(q, -0.01, t, color=\"r\")" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "13da581c-e016-4566-a657-37a52f185970", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(0.1499999999999999, 0.8999999999999999)" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Check the probabilities for the regions delimited by Xa & Xb\n", + "norm.cdf(qr[0], loc=mean, scale=sdev), norm.cdf(qr[2], loc=mean, scale=sdev)" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "7dd7b607-2d21-4354-b1d9-cbf238ab9c0f", + "metadata": {}, + "outputs": [], + "source": [ + "def plot_quartiles(ax: plt.Axes):\n", + " for q, s in zip(qr, [\"-\", \":\", \"-\"]):\n", + " ax.axvline(q, c='r', lw=1, ls=s)" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "39d29de0-2c60-44e6-81a9-a2c0e3598909", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "gauss_pdf = norm.pdf(x, loc=mean, scale=sdev)\n", + "\n", + "fig, ax = plt.subplots()\n", + "ax.plot(x, gauss_pdf);\n", + "plot_quartiles(ax)" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "5f48fd92-9dde-4424-866d-45db2f0a67f2", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.9999985080935955" + ] + }, + "execution_count": 25, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# area should be 1\n", + "sum(gauss_pdf)*(x[1]-x[0])" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "86326b43-6e15-447d-8862-b01391563e2e", + "metadata": {}, + "outputs": [], + "source": [ + "pdf_max = np.max(gauss_pdf)" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "6896daba-3ffa-4e79-a4fe-4e827add12c0", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(0.0036956421869327907, 0.003695642737133491)" + ] + }, + "execution_count": 27, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pdf_max, norm.pdf(mean, loc=mean, scale=sdev)" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "id": "608e0200-a168-40ac-969b-78498379174f", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "fig, ax = plt.subplots()\n", + "ax.plot(x, gauss_pdf/pdf_max);\n", + "plot_quartiles(ax)" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "id": "c4731ed9-de6c-4b9c-bd68-1b2c08f67d06", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "interval = [0, qr[2]+qr[0]]\n", + "width = interval[1] - interval[0]\n", + "\n", + "# we generate random data on the interval \n", + "data_unif = uniform.rvs(*interval, 300000)\n", + "\n", + "# Let's plot the distribution for the test data. Should be flat\n", + "fig, ax = plt.subplots()\n", + "ax.hist(data_unif, bins=300);\n", + "plot_quartiles(ax)\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "id": "6de1d210-28cd-4308-b4db-270d5f34b996", + "metadata": {}, + "outputs": [], + "source": [ + "def subsample_gauss(data: np.ndarray, mean: float, sdev: float, norm_factor: float) -> np.ndarray:\n", + " \"\"\"\n", + " Vectorized subsampling: process the whole dataset\n", + " \"\"\"\n", + " # Create the gaussian weight for each data point\n", + " p = norm.pdf(data, loc=mean, scale=sdev)/norm_factor\n", + " #print(p)\n", + " # Subsample data with probability according to the weight\n", + " return data[ uniform.rvs(size=len(p)) < p ]" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "id": "0485d7fb-bc65-4cf0-b9c1-848c67bd885e", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.2586459880256598" + ] + }, + "execution_count": 31, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ratio = 1/(width*pdf_max)\n", + "ratio\n" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "id": "3cac0c03-45b9-487f-80a8-bd73d6d86bcf", + "metadata": {}, + "outputs": [], + "source": [ + "data_unif_sub = subsample_gauss(data_unif, mean, sdev, pdf_max)" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "id": "c8a461df-801a-4dcf-a33b-6b9d39a5bea5", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.25728" + ] + }, + "execution_count": 33, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "len(data_unif_sub)/len(data_unif)" + ] + }, + { + "cell_type": 
"code", + "execution_count": 34, + "id": "cbbccdf5-a5cb-47b6-843d-1ecacfea6db2", + "metadata": {}, + "outputs": [], + "source": [ + "def check_regions(data: np.ndarray):\n", + " ra = len(data[data < qr[0]]) / len(data)\n", + " print(\"Probability mass below Pa:\", ra)\n", + " rb = len(data[data > qr[2]]) / len(data)\n", + " print(\"Probability mass above Pb:\", rb)" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "id": "b9270b56-f415-4e50-ad1e-c9f3215bb6aa", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Probability mass below Pa: 0.15082141376451078\n", + "Probability mass above Pb: 0.10037054311774461\n" + ] + } + ], + "source": [ + "check_regions(data_unif_sub)" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "id": "ed1c063a-e712-49db-8d4a-8cfd38b95c4c", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "fig, ax = plt.subplots()\n", + "ax.hist(data_unif_sub, bins=200);\n", + "plot_quartiles(ax)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "id": "e3265730-13b3-4f2e-af55-60ea13d69453", + "metadata": {}, + "outputs": [], + "source": [ + "# Let's seek a sampling ratio of 20%\n", + "desired_sampling_ratio = 0.20\n", + "\n", + "# With this desired fraction, we compute the new normalization factor\n", + "unif_norm_factor = 1/(width*desired_sampling_ratio)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "id": "d37df65d-ea77-4743-b424-5d94dedc873d", + "metadata": {}, + "outputs": [], + "source": [ + "data_unif_sub2 = subsample_gauss(data_unif, mean, sdev, unif_norm_factor)" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "id": "afddd196-15aa-42f8-88b5-aca67385dee9", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.19953666666666667" + ] + }, + "execution_count": 40, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Now we have our desired sampling ratio\n", + "len(data_unif_sub2)/len(data_unif)" + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "id": "dceb49cc-2b80-4490-b01c-4f65d979ec34", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Probability mass below Pa: 0.152854112026194\n", + "Probability mass above Pb: 0.09983127578891098\n" + ] + } + ], + "source": [ + "# And the probability masses stay as before\n", + "check_regions(data_unif_sub2)" + ] + }, + { + "cell_type": "code", + "execution_count": 42, + "id": "eb4a19f7-591a-454a-b077-9311d9617b13", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "fig, ax = plt.subplots()\n", + "ax.hist(data_unif_sub2, bins=200);\n", + "plot_quartiles(ax)" + ] + }, + { + "cell_type": "markdown", + "id": "cee4297c-dfad-494f-9696-f10dbc346691", + "metadata": {}, + "source": [ + "## Applying gaussian sampling" + ] + }, + { + "cell_type": "code", + "execution_count": 54, + "id": "1e270d91-f208-4a04-a296-c150cd11da4b", + "metadata": {}, + "outputs": [], + "source": [ + "# Let's seek a sampling ratio of 20%\n", + "desired_sampling_ratio = 0.20\n", + "\n", + "# With this desired fraction, we compute the new normalization factor\n", + "unif_norm_factor = 1/(width*desired_sampling_ratio)" + ] + }, + { + "cell_type": "code", + "execution_count": 55, + "id": "a20f8ef7-5cd5-4f6f-bc82-5e418bbdf992", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU times: user 44.1 ms, sys: 12.2 ms, total: 56.3 ms\n", + "Wall time: 54.3 ms\n" + ] + } + ], + "source": [ + "%%time\n", + "data_sub = subsample_gauss(data, mean, sdev, unif_norm_factor)" + ] + }, + { + "cell_type": "code", + "execution_count": 56, + "id": "614012b5-900b-4e17-bf4d-2b679bd8cef0", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.396047" + ] + }, + "execution_count": 56, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Compute the sampling ratio we have achieved\n", + "len(data_sub)/len(data)" + ] + }, + { + "cell_type": "code", + "execution_count": 57, + "id": "8eb74f54-c7e9-45dd-ba1a-1cb7ba1bc9f4", + "metadata": {}, + "outputs": [], + "source": [ + "def plot_subsample(orig: np.ndarray, sub: np.ndarray, \n", + " name: str = 'Gaussian subsampling'):\n", + " fig, ax = plt.subplots()\n", + " if orig is not None:\n", + " ax.hist(orig, bins=1000, range=[0, qr[2]*3]);\n", + " ax.hist(sub, bins=1000, range=[0, qr[2]*3], color=\"g\");\n", + " if orig is not None:\n", + " ax.legend(['original', 
'subsampled'])\n", + " plot_quartiles(ax)\n", + " ax.set_title(\"Perplexity distribution \" + (\"before and after \" if orig is not None else \"for \") + name);" + ] + }, + { + "cell_type": "code", + "execution_count": 58, + "id": "1cab5b83-1f40-41db-87ae-33d67775142e", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "plot_subsample(data, data_sub)" + ] + }, + { + "cell_type": "markdown", + "id": "c444f307-d5a3-4ec1-845d-340e3de9ed06", + "metadata": {}, + "source": [ + "## Using PerplexitySubsampler" + ] + }, + { + "cell_type": "code", + "execution_count": 60, + "id": "a6d52c72-b54f-43d6-91f2-99a771335bee", + "metadata": {}, + "outputs": [], + "source": [ + "from subsampler import PerplexitySubsampler" + ] + }, + { + "cell_type": "code", + "execution_count": 62, + "id": "b305f48b-e44c-4585-9c93-a1d9b0ac41f8", + "metadata": {}, + "outputs": [], + "source": [ + "subsampler = PerplexitySubsampler(data)" + ] + }, + { + "cell_type": "code", + "execution_count": 64, + "id": "3f3dd948-d3fa-4168-9653-cf6135c89490", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Optimization terminated successfully.\n", + " Current function value: 0.000000\n", + " Iterations: 81\n", + " Function evaluations: 156\n" + ] + } + ], + "source": [ + "subsampler.set(ratio=0.30, pa=0.20, pb=0.05)" + ] + }, + { + "cell_type": "code", + "execution_count": 66, + "id": "0a3232d5-ac30-4118-9ac5-6278da937416", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(0.006131260240020625, 488.615463864415, 124.17368632962524)" + ] + }, + "execution_count": 66, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "subsampler.norm, subsampler.mean, subsampler.sdev" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b82b081d-5f9e-4dd1-b5f3-c6aa2b2b4aa0", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", 
"""Perplexity scoring of documents with KenLM models and SentencePiece tokenizers."""
import argparse
import json
import re
import os
from functools import cache
from pathlib import Path
from typing import Iterator, List, NoReturn, Optional, Tuple, Union

import kenlm
import msgspec
import sentencepiece
from numpy.random import default_rng
from scipy.stats import norm
from tqdm import tqdm

from normalization import normalize_text


RNG = default_rng()
LANGS = ("no", "nn", "nob", "nno", "da", "sv", "is", "en")
DEFAULT_LANG = "no"
# Root folder holding the `kenlm/` and `spm/` model folders; override it with
# the PERPLEXITY_BASEPATH environment variable.
BASEPATH = Path(os.environ.get("PERPLEXITY_BASEPATH", "/nfsmounts/datastore/mimir/perplexity"))
# Keyword arguments for get_perplexity(), per model family. "harmful" and
# "wikipedia" are keyed by language code first; the other families are not.
CONFIG = {
    "harmful": {
        "no": {"model": BASEPATH / "kenlm" / "harmful" / "no.bin", "normalize": True},
        "nn": {"model": BASEPATH / "kenlm" / "harmful" / "no.bin", "normalize": True},
        "nob": {"model": BASEPATH / "kenlm" / "harmful" / "no.bin", "normalize": True},
        "nno": {"model": BASEPATH / "kenlm" / "harmful" / "no.bin", "normalize": True},
        "da": {"model": BASEPATH / "kenlm" / "harmful" / "da.bin", "normalize": True},
        "sv": {"model": BASEPATH / "kenlm" / "harmful" / "sv.bin", "normalize": True},
        "is": {"model": BASEPATH / "kenlm" / "harmful" / "is.bin", "normalize": True},
        "en": {"model": BASEPATH / "kenlm" / "harmful" / "en.bin", "normalize": True},
    },
    "wikipedia": {
        "no": {
            "model": BASEPATH / "kenlm" / "wikipedia" / "no.arpa.bin",
            "tokenizer": BASEPATH / "spm" / "wikipedia" / "no.sp.model",
            "normalize": True
        },
        "nn": {
            "model": BASEPATH / "kenlm" / "wikipedia" / "nn.arpa.bin",
            "tokenizer": BASEPATH / "spm" / "wikipedia" / "nn.sp.model",
            "normalize": True
        },
        "nob": {
            "model": BASEPATH / "kenlm" / "wikipedia" / "no.arpa.bin",
            "tokenizer": BASEPATH / "spm" / "wikipedia" / "no.sp.model",
            "normalize": True
        },
        "nno": {
            "model": BASEPATH / "kenlm" / "wikipedia" / "nn.arpa.bin",
            "tokenizer": BASEPATH / "spm" / "wikipedia" / "nn.sp.model",
            "normalize": True
        },
        "da": {
            "model": BASEPATH / "kenlm" / "wikipedia" / "da.arpa.bin",
            "tokenizer": BASEPATH / "spm" / "wikipedia" / "da.sp.model",
            "normalize": True
        },
        "en": {
            "model": BASEPATH / "kenlm" / "wikipedia" / "en.arpa.bin",
            "tokenizer": BASEPATH / "spm" / "wikipedia" / "en.sp.model",
            "normalize": True
        },
        "is": {
            "model": BASEPATH / "kenlm" / "wikipedia" / "is.arpa.bin",
            "tokenizer": BASEPATH / "spm" / "wikipedia" / "is.sp.model",
            "normalize": True
        },
        "sv": {
            "model": BASEPATH / "kenlm" / "wikipedia" / "sv.arpa.bin",
            "tokenizer": BASEPATH / "spm" / "wikipedia" / "sv.sp.model",
            "normalize": True
        },
    },
    "books": {
        "model": BASEPATH / "kenlm" / "books.norm.sp.arpa.bin",
        "tokenizer": BASEPATH / "spm" / "books.norm.sp.model",
        "normalize": True
    },
    "newspapers": {
        "model": BASEPATH / "kenlm" / "newspapers.norm.sp.arpa.bin",
        "tokenizer": BASEPATH / "spm" / "newspapers.norm.sp.model",
        "normalize": True
    },
    "maalfrid": {
        "model": BASEPATH / "kenlm" / "maalfrid.norm.sp.arpa.bin",
        "tokenizer": BASEPATH / "spm" / "maalfrid.norm.sp.model",
        "normalize": True
    }
}

# Doc-type keyword -> (model family, whether the family is keyed by language).
# Anything not listed here falls back to ("wikipedia", True).
_DOC_TYPE_MODELS = {
    "book": ("books", False),
    "books": ("books", False),
    "newspaper": ("newspapers", False),
    "newspapers": ("newspapers", False),
    "evalueringsrapport": ("maalfrid", False),
    "lovdata": ("maalfrid", False),
    "maalfrid": ("maalfrid", False),
    "parlamint": ("maalfrid", False),
    "culturax": ("wikipedia", True),
    "slimpajama": ("wikipedia", True),
    "wikipedia": ("wikipedia", True),
    "digimanus": ("wikipedia", True),
    "pg19": ("wikipedia", True),
    "hplt": ("wikipedia", True),
    "starcoder": ("wikipedia", True),
}

# Not used anymore, speed is almost same as naive algorithm
# class PerplexityDoc(msgspec.Struct):
#     id: str
#     doc_type: str
#     publish_year: int
#     lang_fasttext: str
#     lang_fasttext_conf: Union[str, float]
#     text: str
#     perplexity: float | None = -1.0
#     perplexity_model: str | None = None
#     harmful_pp: float | None = None
#     # wikipedia_pp: float | None = None
#     # books_pp: float | None = None
#     # newspapers_pp: float | None = None
#     # maalfrid_pp: float | None = None


def should_keep(
    perp: float, dist_norm: float, dist_mean: float, dist_std: float
) -> bool:
    """Randomly decide whether a document is retained, given its perplexity.

    The keep probability is the Gaussian density with mean ``dist_mean`` and
    standard deviation ``dist_std`` evaluated at ``perp``, divided by
    ``dist_norm``. A uniform draw from the module-level ``RNG`` is compared
    against it, so the result is not deterministic.

    Args:
        perp: Perplexity of the document.
        dist_norm: Normalization factor dividing the density.
        dist_mean: Mean of the Gaussian.
        dist_std: Standard deviation of the Gaussian.

    Returns:
        True if the document should be kept.
    """
    # NOTE(review): the three dist_* values presumably come from a fitted
    # PerplexitySubsampler (norm, mean, sdev) -- confirm against the caller.
    p = norm.pdf(perp, loc=dist_mean, scale=dist_std) / dist_norm
    return RNG.uniform() < p


def fix_language(language: str) -> str:
    """Return ``language`` if it is in ``LANGS``, otherwise ``DEFAULT_LANG``."""
    return language if language in LANGS else DEFAULT_LANG


def pp(log_score, length):
    """Convert a summed log score over ``length`` tokens into a perplexity.

    The exponent uses base 10, so ``log_score`` is expected to be a base-10
    log probability. Raises ZeroDivisionError when ``length`` is 0.
    """
    return 10.0 ** (-log_score / length)


@cache
def load_kenlm(model: str | Path) -> kenlm.Model:
    """Load the KenLM model at path ``model``, cached per path."""
    lm_config = kenlm.Config()
    # NOTE(review): 2 presumably selects one of kenlm's eager load methods
    # (POPULATE_OR_READ in the C++ enum) -- confirm against the kenlm version used.
    lm_config.load_method = 2
    return kenlm.Model(str(model), lm_config)


@cache
def load_sentencepiece(model: str | Path) -> sentencepiece.SentencePieceProcessor:
    """Load the SentencePiece model at path ``model``, cached per path."""
    sp = sentencepiece.SentencePieceProcessor()
    sp.load(str(model))
    return sp


def _document_perplexity(lines, model, tokenizer, normalize, skip_empty) -> float:
    """Score ``lines`` one by one with ``model`` and return the document perplexity."""
    doc_log_score, doc_length = 0.0, 0
    for line in lines:
        if skip_empty and not line:
            continue
        if normalize:
            line = normalize_text(line)
        if tokenizer is not None:
            line = " ".join(tokenizer.encode_as_pieces(line))
        doc_log_score += model.score(line)
        # One extra unit per line on top of its whitespace-separated tokens
        # (presumably the end-of-sentence token scored by kenlm).
        doc_length += len(line.split()) + 1
    if doc_length == 0:
        # Nothing was scored (e.g. an empty document); avoid a division by zero.
        return 0.0
    return round(pp(doc_log_score, doc_length), 1)


def get_perplexity(
    document: str,
    model: str | Path,
    tokenizer: str | Path | None = None,
    normalize: bool = False
) -> float:
    """Compute the perplexity of ``document`` using models loaded from disk.

    Args:
        document: Text to score; it is split on newlines and empty lines are
            skipped.
        model: Path to the KenLM model.
        tokenizer: Optional path to a SentencePiece model; when given, each
            line is encoded into pieces before scoring.
        normalize: Whether to run ``normalize_text`` on each line first.

    Returns:
        The perplexity rounded to one decimal, or 0.0 when no model is given
        or the document has no non-empty line.
    """
    if not model:
        return 0.0
    lm = load_kenlm(model)
    sp = load_sentencepiece(tokenizer) if tokenizer else None
    return _document_perplexity(
        document.split("\n"), lm, sp, normalize, skip_empty=True
    )


def get_perplexity_local(
    document: str,
    model: kenlm.Model,
    tokenizer: sentencepiece.SentencePieceProcessor | None = None,
    normalize: bool = False
) -> float:
    """Compute the perplexity of ``document`` with already-loaded objects.

    Unlike ``get_perplexity``, empty lines are scored too rather than skipped.

    Returns:
        The perplexity rounded to one decimal, or 0.0 when ``model`` is falsy.
    """
    if not model:
        return 0.0
    return _document_perplexity(
        document.split("\n"), model, tokenizer, normalize, skip_empty=False
    )


def harmful_perplexity(document: str, language: str) -> float:
    """Perplexity of ``document`` under the "harmful" model for ``language``."""
    params = CONFIG["harmful"][fix_language(language)]
    return get_perplexity(document=document, **params)


def wikipedia_perplexity(document: str, language: str) -> float:
    """Perplexity of ``document`` under the "wikipedia" model for ``language``."""
    params = CONFIG["wikipedia"][fix_language(language)]
    return get_perplexity(document=document, **params)


def books_perplexity(document: str) -> float:
    """Perplexity of ``document`` under the "books" model."""
    params = CONFIG["books"]
    return get_perplexity(document=document, **params)


def newspapers_perplexity(document: str) -> float:
    """Perplexity of ``document`` under the "newspapers" model."""
    params = CONFIG["newspapers"]
    return get_perplexity(document=document, **params)


def maalfrid_perplexity(document: str) -> float:
    """Perplexity of ``document`` under the "maalfrid" model."""
    params = CONFIG["maalfrid"]
    return get_perplexity(document=document, **params)


def _scoring_params(model_name: str, language: str) -> dict:
    """Return a copy of the CONFIG entry for ``model_name`` with normalization off.

    A copy is returned so the shared CONFIG is never mutated. Entries without
    a "model" key are language-keyed and are resolved with ``language``.
    """
    params = CONFIG[model_name]
    if "model" not in params:
        params = params[language]
    return {**params, "normalize": False}


def source_perplexities(
    document: str,
    language: str,
    model: str | None = None,
    include_harmful: bool = True
) -> dict[str, float]:
    """Calculate the perplexities of ``document`` under one or all models.

    Args:
        document: Raw text to score.
        language: Language code; unknown codes fall back to ``DEFAULT_LANG``.
        model: A key of ``CONFIG`` to score with only that model, or None to
            score with "wikipedia", "books", "newspapers" and "maalfrid".
        include_harmful: Whether to add the "harmful" model score.

    Returns:
        A dict mapping ``"<model>_pp"`` to the perplexity, plus
        ``"harmful_pp"`` when requested.

    Raises:
        KeyError: If ``model`` is not a key of ``CONFIG``.
    """
    # Every configured model normalizes its input, so normalize once here and
    # score with normalize=False.
    normalized_document = "\n".join(
        normalize_text(line) for line in document.split("\n")
    )
    language = fix_language(language)

    if model is not None:
        model_names = (model,)
    else:
        model_names = ("wikipedia", "books", "newspapers", "maalfrid")
    perplexities = {
        f"{name}_pp": get_perplexity(
            document=normalized_document, **_scoring_params(name, language)
        )
        for name in model_names
    }
    if include_harmful:
        perplexities["harmful_pp"] = get_perplexity(
            document=normalized_document, **_scoring_params("harmful", language)
        )
    return perplexities


def get_model_for(doc_type: str) -> tuple[str, bool]:
    """Return the model family for ``doc_type`` and whether it needs a language.

    Only the part of ``doc_type`` before the first "_" is considered, and
    within it only what follows the first "-" (if any). Unknown doc types map
    to ``("wikipedia", True)``.
    """
    doc_type = doc_type.split("_", 1)[0]
    if "-" in doc_type:
        doc_type = doc_type.split("-", 1)[-1]
    return _DOC_TYPE_MODELS.get(doc_type, ("wikipedia", True))


def preload_models_tokenizers() -> dict:
    """Load every configured model and tokenizer up front.

    Paths are taken from ``CONFIG`` so they match the ones used by
    ``get_perplexity`` and share the ``load_kenlm`` cache with it.

    Returns:
        A dict mapping "books", "newspapers", "maalfrid", "harmful-<lang>" and
        "wikipedia-<lang>" to ``(model, tokenizer)`` tuples; the tokenizer is
        None for the harmful models.
    """
    print("Preloading models...", end=" ")
    models = {
        name: (
            load_kenlm(CONFIG[name]["model"]),
            load_sentencepiece(CONFIG[name]["tokenizer"]),
        )
        for name in ("books", "newspapers", "maalfrid")
    }
    for lang, params in CONFIG["harmful"].items():
        models[f"harmful-{lang}"] = load_kenlm(params["model"]), None

    for lang, params in CONFIG["wikipedia"].items():
        models[f"wikipedia-{lang}"] = (
            load_kenlm(params["model"]),
            load_sentencepiece(params["tokenizer"]),
        )
    print("Done")
    return models


# Not used anymore, speed is almost same as naive algorithm
# def process_file_binary(input_file, output_path, cutoff=None,
#                         overwrite_output=True):
#     input_file = Path(input_file)
#     output_file = Path(output_path) / input_file.name
#     if not overwrite_output and output_file.exists():
#         print(f"Skipping {output_file} as it already exists")
#         return
#     models = preload_models_tokenizers()
#     encoder = msgspec.json.Encoder()
#     decoder = msgspec.json.Decoder(PerplexityDoc)
#     buffer = bytearray(64)
#     with (open(output_file, 'wb') as f,
#           open(input_file, 'r', encoding='utf-8') as lines):
#         for line_count, line in tqdm(enumerate(lines), desc=f"Processing {input_file.name}"):
#             doc = decoder.decode(line)
#             if "code" not in doc.doc_type:
#                 # Perplexity
#                 model_type, needs_lang = get_model_for(doc.doc_type)
#                 if needs_lang:
#                     model_key = f"{model_type}-{fix_language(doc.lang_fasttext)}"
#                 else:
#                     model_key = model_type
#                 model, tokenizer = models[model_key]
#                 text = "\n".join(normalize_text(line) for line in doc.text.split("\n"))
#                 score = get_perplexity_local(
#                     text, model=model, tokenizer=tokenizer, normalize=False
#                 )
#                 doc.perplexity = score
#                 doc.perplexity_model = model_type
#                 # Harmfulness
#                 harmful_key = f"harmful-{fix_language(doc.lang_fasttext)}"
#                 harmful_model, harmful_tokenizer = models[harmful_key]
#                 harmful_pp = get_perplexity_local(
#                     text, model=harmful_model, tokenizer=harmful_tokenizer, normalize=False
#                 )
#                 doc.harmful_pp = harmful_pp

#             encoder.encode_into(doc, buffer)
#             buffer.extend(b"\n")
#             f.write(buffer)
#             if cutoff is not None and line_count >= cutoff:
#                 break


def process_file(input_file, output_path, cutoff=None, model=None, overwrite_output=True):
    """
    Add perplexity scores to every document of a JSON Lines file.

    Each input line is parsed as a JSON document, enriched with the values
    returned by `source_perplexities`, and written as one JSON line to a file
    with the same name as the input inside `output_path`.

    The language is read from the document's `lang_fasttext` field (no
    language detection is run here); documents whose `doc_type` is exactly
    "starcoder" are scored as English.

    Parameters:
    - input_file (str or Path): Path to the input JSON Lines file.
    - output_path (str or Path): Directory where the output file is written.
      It is created if it does not exist.
    - cutoff (int, optional): Maximum number of lines to process. Defaults to
      None (process the whole file).
    - model (str, optional): If 'single', each document is scored with the
      model chosen by `get_model_for` for its doc type, and the score is
      stored under `perplexity` together with `perplexity_model`. Any other
      value is passed to `source_perplexities` as is; None scores with all
      the models. Defaults to None.
    - overwrite_output (bool): If False, processing is skipped when the output
      file already exists. Defaults to True.

    Returns:
    None.

    Raises:
    ValueError: If the output file would be the input file itself.
    """
    input_file = Path(input_file)
    output_dir = Path(output_path)
    output_file = output_dir / input_file.name
    if not overwrite_output and output_file.exists():
        print(f"Skipping {output_file} as it already exists")
        return
    # Opening the output for writing truncates it, so writing next to the
    # input under the same name would destroy the input before it is read.
    if output_file.resolve() == input_file.resolve():
        raise ValueError(f"Output file {output_file} would overwrite the input file")
    output_dir.mkdir(parents=True, exist_ok=True)
    with (open(output_file, 'w', encoding='utf-8') as f,
          open(input_file, 'r', encoding='utf-8') as lines):
        for line_count, line in tqdm(enumerate(lines), desc=f"Processing {input_file.name}"):
            # line_count is 0-based: stop before processing line number `cutoff`
            # so that exactly `cutoff` lines are written.
            if cutoff is not None and line_count >= cutoff:
                break
            doc = json.loads(line)
            language = doc["lang_fasttext"]
            if doc["doc_type"] == "starcoder":
                language = "en"
            if model == "single":
                doc_type_model, _ = get_model_for(doc["doc_type"])
                perplexities = source_perplexities(doc["text"], language, model=doc_type_model)
                perplexities["perplexity"] = perplexities.pop(f"{doc_type_model}_pp")
                perplexities["perplexity_model"] = doc_type_model
            else:
                perplexities = source_perplexities(doc["text"], language, model=model)
            doc.update(perplexities)
            f.write(json.dumps(doc) + "\n")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Calculate perplexity values for a given JSON Lines file and output the result to a new file.')
    parser.add_argument('-i', '--input_file', type=str, required=True,
                        help='Input file path')
    parser.add_argument('-o', '--output_path', type=str, required=True,
                        help='Output path to write enriched file')
    parser.add_argument('-c', '--cutoff', required=False, type=int,
                        help='Max number of lines to process')
    parser.add_argument('-m', '--model', required=False, type=str,
                        help='Run "single" model per doc type, "all" the models, '
                             'or a specific model to choose from '
                             '"books", "wikipedia", "newspapers" or "maalfrid". '
                             'Defaults to "all"')
    parser.add_argument('--overwrite_output',
                        action=argparse.BooleanOptionalAction, default=True,
                        help="Whether to overwrite the output file if exists.")

    args = parser.parse_args()

    # Any value other than the ones below (including "all" or no value at all)
    # runs every model.
    selectable_models = ("single", "books", "wikipedia", "newspapers", "maalfrid")
    process_file(
        args.input_file, args.output_path, args.cutoff,
        model=args.model if args.model in selectable_models else None,
        overwrite_output=args.overwrite_output,
    )
b/plots/culturax_nob_culturax.png new file mode 100644 index 0000000000000000000000000000000000000000..f67022fc3b4e9151256e0a88d3ae85187866904f Binary files /dev/null and b/plots/culturax_nob_culturax.png differ diff --git a/plots/plots_book.png b/plots/plots_book.png new file mode 100644 index 0000000000000000000000000000000000000000..19f30d2084e5aa1fc8d107897e01b749de2617b4 Binary files /dev/null and b/plots/plots_book.png differ diff --git a/plots/plots_books_pdf.png b/plots/plots_books_pdf.png new file mode 100644 index 0000000000000000000000000000000000000000..9b562fcb555def7435b05c41892a739d2e38cb18 Binary files /dev/null and b/plots/plots_books_pdf.png differ diff --git a/plots/plots_culturax.png b/plots/plots_culturax.png new file mode 100644 index 0000000000000000000000000000000000000000..2cd5207b90a055f1a0695f521e98cafce4b70ad0 Binary files /dev/null and b/plots/plots_culturax.png differ diff --git a/plots/plots_evalueringsrapport_pdf.png b/plots/plots_evalueringsrapport_pdf.png new file mode 100644 index 0000000000000000000000000000000000000000..89899b7768a0bc63bfd5d5e97eb7864b138e3a30 Binary files /dev/null and b/plots/plots_evalueringsrapport_pdf.png differ diff --git a/plots/plots_evalueringsrapport_pdf_no.png b/plots/plots_evalueringsrapport_pdf_no.png new file mode 100644 index 0000000000000000000000000000000000000000..48bcd87823e2cbfd56e7b6cc01f9f7b3485b1ce8 Binary files /dev/null and b/plots/plots_evalueringsrapport_pdf_no.png differ diff --git a/plots/plots_lovdata_cd_lokaleforskrifter_2005.png b/plots/plots_lovdata_cd_lokaleforskrifter_2005.png new file mode 100644 index 0000000000000000000000000000000000000000..022ca0456fffa242097e0ca8cbfac07a9f87304a Binary files /dev/null and b/plots/plots_lovdata_cd_lokaleforskrifter_2005.png differ diff --git a/plots/plots_lovdata_cd_lokaleforskrifter_2005_no.png b/plots/plots_lovdata_cd_lokaleforskrifter_2005_no.png new file mode 100644 index 
0000000000000000000000000000000000000000..2fdf3746457fd8f3a8a1c21c44410f1cbec63001 Binary files /dev/null and b/plots/plots_lovdata_cd_lokaleforskrifter_2005_no.png differ diff --git a/plots/plots_lovdata_cd_norgeslover_2005.png b/plots/plots_lovdata_cd_norgeslover_2005.png new file mode 100644 index 0000000000000000000000000000000000000000..03a1f54493992f2902fca6693b6231aed63e7d49 Binary files /dev/null and b/plots/plots_lovdata_cd_norgeslover_2005.png differ diff --git a/plots/plots_lovdata_cd_norgeslover_2005_no.png b/plots/plots_lovdata_cd_norgeslover_2005_no.png new file mode 100644 index 0000000000000000000000000000000000000000..a26537580997ec61af8ff402176d9c259ae2d54f Binary files /dev/null and b/plots/plots_lovdata_cd_norgeslover_2005_no.png differ diff --git a/plots/plots_lovdata_cd_odelsting_2005.png b/plots/plots_lovdata_cd_odelsting_2005.png new file mode 100644 index 0000000000000000000000000000000000000000..b4df64bc8e7d2a59dcf4ddad8ae0b28503e8fb01 Binary files /dev/null and b/plots/plots_lovdata_cd_odelsting_2005.png differ diff --git a/plots/plots_lovdata_cd_odelsting_2005_no.png b/plots/plots_lovdata_cd_odelsting_2005_no.png new file mode 100644 index 0000000000000000000000000000000000000000..3422dc69a6e4c9ece7e788c0b7a8fb8e3de216ce Binary files /dev/null and b/plots/plots_lovdata_cd_odelsting_2005_no.png differ diff --git a/plots/plots_lovdata_cd_rtv_rundskriv_2005.png b/plots/plots_lovdata_cd_rtv_rundskriv_2005.png new file mode 100644 index 0000000000000000000000000000000000000000..c92d0bbc3a972b783ff1f7dc80fba912fbe5417b Binary files /dev/null and b/plots/plots_lovdata_cd_rtv_rundskriv_2005.png differ diff --git a/plots/plots_lovdata_cd_rtv_rundskriv_2005_no.png b/plots/plots_lovdata_cd_rtv_rundskriv_2005_no.png new file mode 100644 index 0000000000000000000000000000000000000000..162cf5190bcae8a6fee828f81c4f10f9f82dc3e3 Binary files /dev/null and b/plots/plots_lovdata_cd_rtv_rundskriv_2005_no.png differ diff --git 
a/plots/plots_lovdata_cd_rundskriv_lovavdeling_2005.png b/plots/plots_lovdata_cd_rundskriv_lovavdeling_2005.png new file mode 100644 index 0000000000000000000000000000000000000000..d7ddbd05189bd8817c3a83f85d71a9db93cb9fa7 Binary files /dev/null and b/plots/plots_lovdata_cd_rundskriv_lovavdeling_2005.png differ diff --git a/plots/plots_lovdata_cd_rundskriv_lovavdeling_2005_no.png b/plots/plots_lovdata_cd_rundskriv_lovavdeling_2005_no.png new file mode 100644 index 0000000000000000000000000000000000000000..854093f8c70366dcc62e68f0504387f909b2f6ef Binary files /dev/null and b/plots/plots_lovdata_cd_rundskriv_lovavdeling_2005_no.png differ diff --git a/plots/plots_lovdata_cd_sentrale_forskrifter_2005.png b/plots/plots_lovdata_cd_sentrale_forskrifter_2005.png new file mode 100644 index 0000000000000000000000000000000000000000..bf575557900247926cf433866d25d0a1e44f05d1 Binary files /dev/null and b/plots/plots_lovdata_cd_sentrale_forskrifter_2005.png differ diff --git a/plots/plots_lovdata_cd_sentrale_forskrifter_2005_no.png b/plots/plots_lovdata_cd_sentrale_forskrifter_2005_no.png new file mode 100644 index 0000000000000000000000000000000000000000..62a31669fd4b7c01cd8da911c6c14091fc9bfbc7 Binary files /dev/null and b/plots/plots_lovdata_cd_sentrale_forskrifter_2005_no.png differ diff --git a/plots/plots_lovdata_cd_skatt_rundskriv_2005.png b/plots/plots_lovdata_cd_skatt_rundskriv_2005.png new file mode 100644 index 0000000000000000000000000000000000000000..da6dfbd48901da00fd61d80d8f841fbffb7dc4c8 Binary files /dev/null and b/plots/plots_lovdata_cd_skatt_rundskriv_2005.png differ diff --git a/plots/plots_lovdata_cd_skatt_rundskriv_2005_no.png b/plots/plots_lovdata_cd_skatt_rundskriv_2005_no.png new file mode 100644 index 0000000000000000000000000000000000000000..93e926d6b3fc2b433918c7c6b8a214fd7c2c217e Binary files /dev/null and b/plots/plots_lovdata_cd_skatt_rundskriv_2005_no.png differ diff --git a/plots/plots_lovdata_cd_somb_rundskriv_2005.png 
b/plots/plots_lovdata_cd_somb_rundskriv_2005.png new file mode 100644 index 0000000000000000000000000000000000000000..d9ae825140a28b8dc7aed90cc05632ad97c3c639 Binary files /dev/null and b/plots/plots_lovdata_cd_somb_rundskriv_2005.png differ diff --git a/plots/plots_lovdata_cd_somb_rundskriv_2005_no.png b/plots/plots_lovdata_cd_somb_rundskriv_2005_no.png new file mode 100644 index 0000000000000000000000000000000000000000..0a2725722de25659576fa2c06b3340f230217044 Binary files /dev/null and b/plots/plots_lovdata_cd_somb_rundskriv_2005_no.png differ diff --git a/plots/plots_maalfrid_crawl_doc.png b/plots/plots_maalfrid_crawl_doc.png new file mode 100644 index 0000000000000000000000000000000000000000..b115c2ddc49b53bee067a4304102b59adaa2ef91 Binary files /dev/null and b/plots/plots_maalfrid_crawl_doc.png differ diff --git a/plots/plots_maalfrid_crawl_doc_no.png b/plots/plots_maalfrid_crawl_doc_no.png new file mode 100644 index 0000000000000000000000000000000000000000..03e863f4a0eba147fa491710beec32429863015e Binary files /dev/null and b/plots/plots_maalfrid_crawl_doc_no.png differ diff --git a/plots/plots_maalfrid_crawl_html.png b/plots/plots_maalfrid_crawl_html.png new file mode 100644 index 0000000000000000000000000000000000000000..a23ac4bfd8e7ecab0b473c0c3a58090abd39dc86 Binary files /dev/null and b/plots/plots_maalfrid_crawl_html.png differ diff --git a/plots/plots_maalfrid_crawl_html_no.png b/plots/plots_maalfrid_crawl_html_no.png new file mode 100644 index 0000000000000000000000000000000000000000..d2b895ae7f1d05b4add5e2cae55156df65b85cc2 Binary files /dev/null and b/plots/plots_maalfrid_crawl_html_no.png differ diff --git a/plots/plots_maalfrid_crawl_pdf.png b/plots/plots_maalfrid_crawl_pdf.png new file mode 100644 index 0000000000000000000000000000000000000000..b132b09cb42cc0f04c0264c8680c6f5e0fbd7cff Binary files /dev/null and b/plots/plots_maalfrid_crawl_pdf.png differ diff --git a/plots/plots_maalfrid_crawl_pdf_no.png b/plots/plots_maalfrid_crawl_pdf_no.png 
new file mode 100644 index 0000000000000000000000000000000000000000..d7e2dbf6d32c34623fc22eb7330ec5e9efb25e7a Binary files /dev/null and b/plots/plots_maalfrid_crawl_pdf_no.png differ diff --git a/plots/plots_maalfrid_government_html.png b/plots/plots_maalfrid_government_html.png new file mode 100644 index 0000000000000000000000000000000000000000..d83a24c4babc314109c209da8da37228ca47135b Binary files /dev/null and b/plots/plots_maalfrid_government_html.png differ diff --git a/plots/plots_maalfrid_government_html_no.png b/plots/plots_maalfrid_government_html_no.png new file mode 100644 index 0000000000000000000000000000000000000000..f29dc046d95e50709ca4c99d99d7b73925b77faf Binary files /dev/null and b/plots/plots_maalfrid_government_html_no.png differ diff --git a/plots/plots_maalfrid_government_pdf.png b/plots/plots_maalfrid_government_pdf.png new file mode 100644 index 0000000000000000000000000000000000000000..31c164aa457267dbbbddadf1039ad9812f5a5de4 Binary files /dev/null and b/plots/plots_maalfrid_government_pdf.png differ diff --git a/plots/plots_maalfrid_government_pdf_no.png b/plots/plots_maalfrid_government_pdf_no.png new file mode 100644 index 0000000000000000000000000000000000000000..37648b087cf6c57aa67be49adb3e332b2884e8b7 Binary files /dev/null and b/plots/plots_maalfrid_government_pdf_no.png differ diff --git a/plots/plots_maalfrid_pp_no.png b/plots/plots_maalfrid_pp_no.png new file mode 100644 index 0000000000000000000000000000000000000000..0038c648f3c178918e453731669e02e8e8bf783e Binary files /dev/null and b/plots/plots_maalfrid_pp_no.png differ diff --git a/plots/plots_parlamint_xml.png b/plots/plots_parlamint_xml.png new file mode 100644 index 0000000000000000000000000000000000000000..dc37049042d2b89b2e8ee0c3b6c4a3478a577c90 Binary files /dev/null and b/plots/plots_parlamint_xml.png differ diff --git a/plots/plots_parlamint_xml_no.png b/plots/plots_parlamint_xml_no.png new file mode 100644 index 
0000000000000000000000000000000000000000..9557a880bc654bbfd653a9b0aeac4d91215ec832 Binary files /dev/null and b/plots/plots_parlamint_xml_no.png differ diff --git a/samples/README.md b/samples/README.md new file mode 100644 index 0000000000000000000000000000000000000000..601c5ff3188b7d18302d57d8c6de256893043503 --- /dev/null +++ b/samples/README.md @@ -0,0 +1,19 @@ +Samples can be extracted using `shuf` from `clean_jsonl_3` folders. The output files need to be prefixed by the doc_type and suffixed by the language code. + +For example: + +```bash +cat /nfsmounts/datastore/ncc_corpus/mimir/clean_jsonl_3/external-hplt*nn.jsonl |shuf -n 1000000 > hplt_nno.jsonl +cat /nfsmounts/datastore/ncc_corpus/mimir/clean_jsonl_3/external-hplt*nb.jsonl |shuf -n 1000000 > hplt_nob.jsonl +cat /nfsmounts/datastore/ncc_corpus/mimir/clean_jsonl_3/external-hplt*da.jsonl |shuf -n 1000000 > hplt_da.jsonl +cat /nfsmounts/datastore/ncc_corpus/mimir/clean_jsonl_3/external-hplt*sv.jsonl |shuf -n 1000000 > hplt_sv.jsonl +cat /nfsmounts/datastore/ncc_corpus/mimir/clean_jsonl_3/external-hplt*is.jsonl |shuf -n 1000000 > hplt_is.jsonl +``` + +Or for the restricted books (as they are longer per document, 100,000 documents should be enough): + +```bash +cat /nfsmounts/datastore/ncc_corpus/mimir/clean_jsonl_3/restricted_books/restricted_books.*.jsonl|shuf -n 100000 > restricted_books_no.jsonl +``` + +Monitor memory usage while doing `shuf`. 
diff --git a/samples_quartiles.py b/samples_quartiles.py new file mode 100644 index 0000000000000000000000000000000000000000..087509d80730aa9ed7bb10ca5c8ff745b8e89c44 --- /dev/null +++ b/samples_quartiles.py @@ -0,0 +1,169 @@ +import argparse +import os +from collections import defaultdict +from io import StringIO + +import pandas as pd +from tqdm import tqdm + +from perplexity import get_model_for +from subsampler import PerplexitySubsampler + + +def process_files( + directory, + reject_level, + model_override, + output_file, + group_by_prefix_lang, + prefix_lang_mapping=None, + ratio=None, + ratio_per_lang=None, + pa=None, + pb=None, + include=None, +): + if ratio or ratio_per_lang: + rows = ["doc_type,model,language,reject,bad,medium,good,norm,mean,std"] + else: + rows = ["doc_type,model,language,reject,bad,medium,good"] + files = os.listdir(directory) + grouped_files = defaultdict(list) + if prefix_lang_mapping is None: + prefix_lang_mapping = {} + + # Group files by prefix and language if the option is enabled + description = "Processing files" + if group_by_prefix_lang: + description = "Processing files in groups" + for file in files: + parts = file.split('_') + prefix = parts[0] + if include and prefix not in include: + continue + lang = parts[-1].split(".")[0][:2] + group_key = prefix_lang_mapping.get(f"{prefix}_{lang}", f"{prefix}_{lang}") + grouped_files[group_key].append(file) + file_groups = grouped_files.values() + else: + file_groups = [] + for file in files: # Each file is its own group + if include and not any(file.startswith(prefix) for prefix in include): + continue + file_groups.append([file]) + + if output_file: + progress = tqdm(file_groups, desc=description) + else: + progress = file_groups + print(rows[0]) + # Process each group of files + for group in progress: + combined_perplexities = pd.DataFrame() + doc_type, lang = None, None + + for file in group: + if not doc_type or not lang: # Set doc_type and lang based on the first file + parts = 
file.split('_') + doc_type = file.split('_')[0] + lang = parts[-1].split(".")[0][:2] + doc_type, lang = prefix_lang_mapping.get(f"{doc_type}_{lang}", f"{doc_type}_{lang}").rsplit("_", 1) + perp = pd.read_json(os.path.join(directory, file), lines=True) + perplexities = pd.read_json(StringIO(perp["perplexities"].to_json(lines=True, orient="records")), lines=True) + combined_perplexities = pd.concat([combined_perplexities, perplexities], ignore_index=True) + + if model_override: + model = model_override + else: + model, _ = get_model_for(doc_type) + model_with_suffix = f"{model}_pp" + + # Calculate quantiles for the combined perplexities of the group + reject = round(combined_perplexities[model_with_suffix].quantile(q=reject_level), 2) + bad = round(combined_perplexities[model_with_suffix].quantile(q=0.75), 2) + medium = round(combined_perplexities[model_with_suffix].quantile(q=0.50), 2) + good = round(combined_perplexities[model_with_suffix].quantile(q=0.25), 2) + + if ratio: + subsampler = PerplexitySubsampler(combined_perplexities[model_with_suffix].values) + subsampler.set(ratio=ratio, pa=pa, pb=pb) + norm, mean, std = subsampler.norm, subsampler.mean, subsampler.sdev + sampling_stats = f",{norm},{mean},{std}" + elif ratio_per_lang: + subsampler = PerplexitySubsampler(combined_perplexities[model_with_suffix].values) + subsampler.set(ratio=ratio_per_lang.get(lang, ratio or 1.0), pa=pa, pb=pb) + norm, mean, std = subsampler.norm, subsampler.mean, subsampler.sdev + sampling_stats = f",{norm},{mean},{std}" + else: + sampling_stats = "" + + row = f"{doc_type},{model},{lang},{reject},{bad},{medium},{good}{sampling_stats}" + if output_file: + rows.append(row) + else: + print(row) + + + if output_file: + with open(output_file, "w") as f: + for row in rows: + f.write(f"{row}\n") + + +def main(): + """ + Each doc_type prefix needs to have a "no" lang, even if there's no real data. + These rows are crucial for the rest of the process. 
+ """ + parser = argparse.ArgumentParser(description="Process files and compute perplexity metrics.") + parser.add_argument('directory', type=str, help='Directory containing the files to process') + parser.add_argument('--reject_level', type=float, default=0.95, help='Rejection quantile level (default: 0.95)') + parser.add_argument('--model_override', type=str, help='Override the model used') + parser.add_argument('--output_file', type=str, help='Output file in CSV format. If not given, prints to standard output.') + parser.add_argument('--group_by_prefix_lang', action='store_true', help='Group and calculate quantiles for files with the same prefix and language') + parser.add_argument('--overwrite_prefix_lang', type=str, help='Overwrite the assignment of languages to doc_type prefixes, e.g., "starcoder_en:starcoder_code,hplt_en:hplt_no"') + parser.add_argument('--sampling_ratio', type=float, help='Ratio of documents to keep for sampling. If passed, it generate distribution statistics (norm, mean, std) needed for sampling') + parser.add_argument('--sampling_ratio_per_lang', type=str, help='Ratio of documents per lang, e.g., "en:0.25,sv:0.34"') + parser.add_argument('--sampling_q1_prob', type=float, default=0.20, help='Probabilty for keeping documents in the Q1 range') + parser.add_argument('--sampling_q3_prob', type=float, default=0.05, help='Probabilty for keeping documents in the Q3 range') + parser.add_argument('--include', type=str, help='Comma separeted list of doc type prefixes to include') + + args = parser.parse_args() + + if args.sampling_ratio_per_lang: + # Turns "en: 0.25, sv : 0.34" into {'en': 0.25, 'sv': 0.34} + ratio_per_lang = dict( + (k.strip(), float(v.strip())) + for k, v in (item.split(":") + for item in args.sampling_ratio_per_lang.split(",") + ) + ) + else: + ratio_per_lang = None + if args.overwrite_prefix_lang: + # Turns "starcoder_en:starcoder_code,hplt_en:hplt_no" into {'starcoder_en': 'starcoder_code', 'hplt_en': 'hplt_no'} + 
prefix_lang_mapping = dict( + (k.strip(), v.strip()) + for k, v in (item.split(":") + for item in args.overwrite_prefix_lang.split(",") + ) + ) + else: + prefix_lang_mapping = {} + + process_files( + args.directory, + args.reject_level, + args.model_override, + args.output_file, + group_by_prefix_lang=args.group_by_prefix_lang, + prefix_lang_mapping=prefix_lang_mapping, + pa=args.sampling_q1_prob, + pb=args.sampling_q3_prob, + ratio=args.sampling_ratio, + ratio_per_lang=ratio_per_lang, + include=args.include.split(",") if args.include else None + ) + +if __name__ == "__main__": + main() diff --git a/samples_scores.py b/samples_scores.py new file mode 100644 index 0000000000000000000000000000000000000000..5b810577617a2d50595d91543a0e23742f181ef2 --- /dev/null +++ b/samples_scores.py @@ -0,0 +1,47 @@ +import pandas as pd +import argparse +import perplexity +from tqdm import tqdm +from joblib import Parallel, delayed +import os + + +def calculate_doc_perplexity(row_id, doc, doctype, lang): + id = doc["id"] + paragraphs = doc["paragraphs"] + doc_text = "\n".join(para["text"] for para in paragraphs) + perplexities = perplexity.source_perplexities(doc_text, lang, include_harmful=False) + return [id, doctype, lang, perplexities] + + +def main(args): + file = args.file + doctype = '_'.join(os.path.basename(file).split('_')[:-1]) + lang = os.path.basename(file).split('_')[-1].split('.')[0] + chunks = pd.read_json(file, lines=True, chunksize=1000) + rows = [] + for chunk in tqdm(chunks, desc=f"Processing chunks of {args.file}"): + rows.extend(Parallel(n_jobs=args.jobs)( + delayed(calculate_doc_perplexity)(row_id, doc, doctype, lang) + for row_id, doc in chunk.iterrows() + if doc["paragraphs"] + )) + df = pd.DataFrame(rows, columns=["id", "doc_type", "lang", "perplexities"]) + + # Ensure the output directory exists + os.makedirs(args.output_path, exist_ok=True) + + # Save the DataFrame with the ".jsonl" extension + output_file = os.path.join(args.output_path, 
f"{doctype}_{lang}.jsonl") + df.to_json(output_file, lines=True, orient="records") + + +if __name__ == "__main__": + # Set up argument parsing outside of the main() function + parser = argparse.ArgumentParser(description='Process documents to calculate perplexity.') + parser.add_argument('file', type=str, help='Input file path') + parser.add_argument('--output_path', type=str, default='tmp/', help='Output file path') + parser.add_argument('--jobs', type=int, default=10, help='Number of jobs to use for parallel processing') + args = parser.parse_args() + + main(args) diff --git a/scores/book_no.jsonl b/scores/book_no.jsonl new file mode 100644 index 0000000000000000000000000000000000000000..4ec572805994fe35028935b5fde88b6b42ed8240 --- /dev/null +++ b/scores/book_no.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e8aae3dd33d686be5c3203781cbc6593f90f26ef199b8552e05b2e16d9719d1b +size 11325991 diff --git a/scores/books_pdf_no.jsonl b/scores/books_pdf_no.jsonl new file mode 100644 index 0000000000000000000000000000000000000000..55b33d6b681dbc0b2a1e716e3862bc1393cc26f1 --- /dev/null +++ b/scores/books_pdf_no.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7f424799588d6a9bace6751a3c0130716adbf601b8aa9d10469c1d8c187aaedc +size 70980 diff --git a/scores/culturax_da.jsonl b/scores/culturax_da.jsonl new file mode 100644 index 0000000000000000000000000000000000000000..a8a5ce8f7e532bd90f65581f388ed7bdef124ad3 --- /dev/null +++ b/scores/culturax_da.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:076ce2540dea80d3b70a0ad6bf0b89c3f1dc8b8d8c7ae92d711f1811351c0d50 +size 186957751 diff --git a/scores/culturax_is.jsonl b/scores/culturax_is.jsonl new file mode 100644 index 0000000000000000000000000000000000000000..9b29ab0b615a9dbe9bbedf9cfafe9ed3edc21d67 --- /dev/null +++ b/scores/culturax_is.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid 
sha256:059faa1d8facabde167e9b4a77fcf8ab759fb7f62b4f8b0525c9b9744b48ef00 +size 189711856 diff --git a/scores/culturax_nno.jsonl b/scores/culturax_nno.jsonl new file mode 100644 index 0000000000000000000000000000000000000000..e0c74ef63b0fca1070cc582d8d11dfad2dfd6bbd --- /dev/null +++ b/scores/culturax_nno.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4203b071daa6ef2050b952ae936af65c9887617d2397c52b911993b6605f42c2 +size 23340707 diff --git a/scores/culturax_nob.jsonl b/scores/culturax_nob.jsonl new file mode 100644 index 0000000000000000000000000000000000000000..664bca7cf7ded2cabf814f997ebbd507ad9a5dfd --- /dev/null +++ b/scores/culturax_nob.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a60a36a7dac765fd9aa3703503da74eba7d13420cb30bd69de105ac1b6cda231 +size 187737650 diff --git a/scores/culturax_sv.jsonl b/scores/culturax_sv.jsonl new file mode 100644 index 0000000000000000000000000000000000000000..a613d4daca96fef9df91db8ed97857693f3ccb52 --- /dev/null +++ b/scores/culturax_sv.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:825bb6c7ebbe471f341477c232acf2f15f5435791813e0a296177f203753c11a +size 187913372 diff --git a/scores/digimanus_ocr_no.jsonl b/scores/digimanus_ocr_no.jsonl new file mode 100644 index 0000000000000000000000000000000000000000..8da600889b6e0bb6ddc9cc4e81da3c45ce403408 --- /dev/null +++ b/scores/digimanus_ocr_no.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2edff4b2c531bd2fa92a48c8efdf4d9d4ec300b37e5bc6ff777993c5414b07c8 +size 936772 diff --git a/scores/evalueringsrapport_pdf_no.jsonl b/scores/evalueringsrapport_pdf_no.jsonl new file mode 100644 index 0000000000000000000000000000000000000000..6b5e73444c4bfd49c6beae7a182361085c3fd3ac --- /dev/null +++ b/scores/evalueringsrapport_pdf_no.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:70cd4683a5b8dc5d542bb7c50b4da19b55fddbe2e926ae6527420ee1ffbd24ca 
+size 621232 diff --git a/scores/hplt_da.jsonl b/scores/hplt_da.jsonl new file mode 100644 index 0000000000000000000000000000000000000000..d10d76f3c54a71ca27183b0bc4ba6fe658f4a3b8 --- /dev/null +++ b/scores/hplt_da.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:29af33daae286d4c09ef40a4e7a8ba84e7945a77e08dba56f1ed3e865fae87a6 +size 171044178 diff --git a/scores/hplt_is.jsonl b/scores/hplt_is.jsonl new file mode 100644 index 0000000000000000000000000000000000000000..86fed8579f060d5b3cffe911e876c5573730dafb --- /dev/null +++ b/scores/hplt_is.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:876a3bf83a742378658419c2653ed7c995e9458e065d9bd6792f556284d2bd96 +size 83002563 diff --git a/scores/hplt_nno.jsonl b/scores/hplt_nno.jsonl new file mode 100644 index 0000000000000000000000000000000000000000..a40e54cc22d31ce477f2ed4cc0a9ae1b65784139 --- /dev/null +++ b/scores/hplt_nno.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7e3ed48eca8b1141ec51551133f0d7fa212689c59995455182f7778b96cd7206 +size 37998921 diff --git a/scores/hplt_nob.jsonl b/scores/hplt_nob.jsonl new file mode 100644 index 0000000000000000000000000000000000000000..a2f85c8d3405ceff37dd626783729b5ab2382966 --- /dev/null +++ b/scores/hplt_nob.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c44c989a06146ef6c7de7f4fc728e20374e475dae8659638902c91a35beaee27 +size 171422934 diff --git a/scores/hplt_sv.jsonl b/scores/hplt_sv.jsonl new file mode 100644 index 0000000000000000000000000000000000000000..2fe5da6ba7fab2d42f5a277267458d1eecfbfbfa --- /dev/null +++ b/scores/hplt_sv.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5e3f55974157c89177c1df8a51298b102c892c074e231ff070e95a80f3a35ae0 +size 171561680 diff --git a/scores/lovdata_cd_lokaleforskrifter_2005_no.jsonl b/scores/lovdata_cd_lokaleforskrifter_2005_no.jsonl new file mode 100644 index 
0000000000000000000000000000000000000000..7c8c9c1130fc251b24aa5ff64eabac8bf0b419c8 --- /dev/null +++ b/scores/lovdata_cd_lokaleforskrifter_2005_no.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ad69f5d1a04a6b17189bb100d3c5053872397f04b25686954893609b2a2f3cf8 +size 4621545 diff --git a/scores/lovdata_cd_norgeslover_2005_no.jsonl b/scores/lovdata_cd_norgeslover_2005_no.jsonl new file mode 100644 index 0000000000000000000000000000000000000000..3309446990ef3e54783ad3a05201c88f35141ed9 --- /dev/null +++ b/scores/lovdata_cd_norgeslover_2005_no.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c3ca595f8cf74bcbcc1578c2aa898f9f6555af368f207ead11f8bc675a7c06a9 +size 276413 diff --git a/scores/lovdata_cd_odelsting_2005_no.jsonl b/scores/lovdata_cd_odelsting_2005_no.jsonl new file mode 100644 index 0000000000000000000000000000000000000000..1fa2260380ad0d871d045cc390dc3c1a116d474f --- /dev/null +++ b/scores/lovdata_cd_odelsting_2005_no.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:83e64dbc1588a9105a9f3e06d44c803d4d61239a393267c1e5155c5ca6bcfec4 +size 377267 diff --git a/scores/lovdata_cd_rtv_rundskriv_2005_no.jsonl b/scores/lovdata_cd_rtv_rundskriv_2005_no.jsonl new file mode 100644 index 0000000000000000000000000000000000000000..6d3037537845ddca8af3562341733e328361a06a --- /dev/null +++ b/scores/lovdata_cd_rtv_rundskriv_2005_no.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ee95ad8b5b0cf98927ab6a6e3e003edc62ee875b6b278eeb859018623fbeca75 +size 1943796 diff --git a/scores/lovdata_cd_rundskriv_lovavdeling_2005_no.jsonl b/scores/lovdata_cd_rundskriv_lovavdeling_2005_no.jsonl new file mode 100644 index 0000000000000000000000000000000000000000..b9740bccb85b4d706f33b2ad8f5392bc200b84a3 --- /dev/null +++ b/scores/lovdata_cd_rundskriv_lovavdeling_2005_no.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid 
sha256:e028fd32b2fb1cb51643c7074cd9a4b30171ee8bc50e1019a5118cfe6fcfcb9c +size 85578 diff --git a/scores/lovdata_cd_sentrale_forskrifter_2005_no.jsonl b/scores/lovdata_cd_sentrale_forskrifter_2005_no.jsonl new file mode 100644 index 0000000000000000000000000000000000000000..3e748ce6d8f6c7770e8170d64af22e6a932e34a5 --- /dev/null +++ b/scores/lovdata_cd_sentrale_forskrifter_2005_no.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:92a8bbb6bae41aeefb5ebe0ffc654e08a6a85aa9aa6b6fd188f275f6ae8ab8cd +size 2413077 diff --git a/scores/lovdata_cd_skatt_rundskriv_2005_no.jsonl b/scores/lovdata_cd_skatt_rundskriv_2005_no.jsonl new file mode 100644 index 0000000000000000000000000000000000000000..4a20515b2c18b1454500bf544bdbd075e05c5e93 --- /dev/null +++ b/scores/lovdata_cd_skatt_rundskriv_2005_no.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:af2b87178c2080d372b56a9ecea902dd91d70a1c36163add8fb33d4904eae990 +size 81870 diff --git a/scores/lovdata_cd_somb_rundskriv_2005_no.jsonl b/scores/lovdata_cd_somb_rundskriv_2005_no.jsonl new file mode 100644 index 0000000000000000000000000000000000000000..7698b1eb786ac97beb7d7b246e37416be2ae0489 --- /dev/null +++ b/scores/lovdata_cd_somb_rundskriv_2005_no.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c045fe7d4f7d27c09a89870580c09def9e8ba82fe2f258a7a4ac3cbd42bc3b82 +size 652618 diff --git a/scores/maalfrid_crawl_doc_no.jsonl b/scores/maalfrid_crawl_doc_no.jsonl new file mode 100644 index 0000000000000000000000000000000000000000..12e272d719e02858bec57f28dfdcf3a4f28e946b --- /dev/null +++ b/scores/maalfrid_crawl_doc_no.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e5cadd7d9ac1a480fdeab330b254c9786d63edd838f8df77206c00993ca0711f +size 3422190 diff --git a/scores/maalfrid_crawl_html_no.jsonl b/scores/maalfrid_crawl_html_no.jsonl new file mode 100644 index 
0000000000000000000000000000000000000000..3186cc5834b4e522526e3df35dbba05bb89ca73b --- /dev/null +++ b/scores/maalfrid_crawl_html_no.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:21f2cc083cac84e70b6c3d8118e70cc88f14cf5386fbfd16af8af53beb17c220 +size 159121552 diff --git a/scores/maalfrid_crawl_pdf_no.jsonl b/scores/maalfrid_crawl_pdf_no.jsonl new file mode 100644 index 0000000000000000000000000000000000000000..1000894ea91b46b653de353bcd788f2698d73a6e --- /dev/null +++ b/scores/maalfrid_crawl_pdf_no.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bbdc15951e19457ca9e50c40179ca45e36d7f5b44fb2392efae5da4d4262115d +size 63212717 diff --git a/scores/maalfrid_government_html_no.jsonl b/scores/maalfrid_government_html_no.jsonl new file mode 100644 index 0000000000000000000000000000000000000000..b4cd8427b7b25b7ed7a4fa734b9075c472925f99 --- /dev/null +++ b/scores/maalfrid_government_html_no.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e2042e34eedaeca14df5c3dbd652a8967cd94d0fb3ba05b632f1d6ce24e2b5e4 +size 36810790 diff --git a/scores/maalfrid_government_pdf_no.jsonl b/scores/maalfrid_government_pdf_no.jsonl new file mode 100644 index 0000000000000000000000000000000000000000..5f12caa0f8a52b253b84ecedada25ffa646379c6 --- /dev/null +++ b/scores/maalfrid_government_pdf_no.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:401a7bc73c5ba7bf19199ca3512fa02fc11c422cc92e11fbe53fec6984e80f28 +size 7431132 diff --git a/scores/newspaper_ocr_no.jsonl b/scores/newspaper_ocr_no.jsonl new file mode 100644 index 0000000000000000000000000000000000000000..224affbc1d420324a1e8102fba270cab3ee144fe --- /dev/null +++ b/scores/newspaper_ocr_no.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e0c2f5ea1ee8a1351256b30495f0e0df844c6727e788b98b23391267726eb985 +size 224844673 diff --git a/scores/newspaper_pdf_no.jsonl b/scores/newspaper_pdf_no.jsonl new 
file mode 100644 index 0000000000000000000000000000000000000000..71253610fb69a3559210944186188f69054b8be1 --- /dev/null +++ b/scores/newspaper_pdf_no.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0ec18cc1849c04aa52c06c3e33e90a1473380c0338cf9d0ca94fc57d7544a89f +size 217335950 diff --git a/scores/newspapers_online_nno.jsonl b/scores/newspapers_online_nno.jsonl new file mode 100644 index 0000000000000000000000000000000000000000..8c0355fb3734c0cb0da192a9160a7fc4957dc601 --- /dev/null +++ b/scores/newspapers_online_nno.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7ec898d11bbc90248b93a3857e53f9a6ce669ce5b7304570b8b994ee9870767b +size 32755280 diff --git a/scores/newspapers_online_nob.jsonl b/scores/newspapers_online_nob.jsonl new file mode 100644 index 0000000000000000000000000000000000000000..5fecf42c72224e69b5b59bea19c7c08a4038baee --- /dev/null +++ b/scores/newspapers_online_nob.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:39f6aff94875e18cbcf165f2d2e4468bebc8fd75db421d70787dedef9776fd3d +size 198955993 diff --git a/scores/parlamint_xml_no.jsonl b/scores/parlamint_xml_no.jsonl new file mode 100644 index 0000000000000000000000000000000000000000..e6a0d1c54977977ee78faa926bceededcd9cf310 --- /dev/null +++ b/scores/parlamint_xml_no.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:66a92226f8c9f0e6f38238809acb9dbafbb4e529928f4443639a1c4dcf3898fc +size 594737 diff --git a/scores/pg19_en.jsonl b/scores/pg19_en.jsonl new file mode 100644 index 0000000000000000000000000000000000000000..88124691ce2a2b3e525ee0d0a0c4dba6085ee021 --- /dev/null +++ b/scores/pg19_en.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1c77335a3a8dec5e52a442a45647ff8dfd71333458ff37dcd70c65250eb96f80 +size 4757688 diff --git a/scores/restricted-book_ocr_no.jsonl b/scores/restricted-book_ocr_no.jsonl new file mode 100644 index 
0000000000000000000000000000000000000000..f283fae421b30caafa76c176de310f53038738da --- /dev/null +++ b/scores/restricted-book_ocr_no.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:734e0052bb0b6fe841e2f0f71bc4ef57f92c245bec4b52e5c8df6b4d6baaad15 +size 17325447 diff --git a/scores/restricted-books_no.jsonl b/scores/restricted-books_no.jsonl new file mode 100644 index 0000000000000000000000000000000000000000..c9890a1c7c6510f5fcc134575200dbd9a902a843 --- /dev/null +++ b/scores/restricted-books_no.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4922a90cf3edecb0d434ea8f345c2c9215b8b70c53f5c9042117914f66efb621 +size 17033509 diff --git a/scores/restricted-books_pdf_no.jsonl b/scores/restricted-books_pdf_no.jsonl new file mode 100644 index 0000000000000000000000000000000000000000..23e3f20143e43b4084451b6022d82ef1ed186322 --- /dev/null +++ b/scores/restricted-books_pdf_no.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ba1aa4a109ea6a78e36c79547a192a5c248b9b1360f8848486b7dd66c2491304 +size 3363926 diff --git a/scores/restricted-newspapers_mediafutures_amedia_no.jsonl b/scores/restricted-newspapers_mediafutures_amedia_no.jsonl new file mode 100644 index 0000000000000000000000000000000000000000..5c1e30c6d134b3b46dd1d745d736261f8f1f1b0f --- /dev/null +++ b/scores/restricted-newspapers_mediafutures_amedia_no.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8614af91af9de6eac43193c8f34db47b6cbd0d1f05cedb2976922d86311d5e83 +size 210078668 diff --git a/scores/restricted-newspapers_mediafutures_schibsted_no.jsonl b/scores/restricted-newspapers_mediafutures_schibsted_no.jsonl new file mode 100644 index 0000000000000000000000000000000000000000..3788b3b8c6eb8c997fbf5be4927263a6ed889390 --- /dev/null +++ b/scores/restricted-newspapers_mediafutures_schibsted_no.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid 
sha256:12927e7ee715db39a0323207629d98c996e2707ab1340efcd831c3b562369513 +size 179386324 diff --git a/scores/restricted-newspapers_mediafutures_tv2_no.jsonl b/scores/restricted-newspapers_mediafutures_tv2_no.jsonl new file mode 100644 index 0000000000000000000000000000000000000000..92b7262676d22104af1a478b706cb2f85f27c60a --- /dev/null +++ b/scores/restricted-newspapers_mediafutures_tv2_no.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4e801d138a1a806f00feaa5e83f42e841c7fbae3a1d348e13a001fb421189461 +size 101505344 diff --git a/scores/restricted-newspapers_nrk_no.jsonl b/scores/restricted-newspapers_nrk_no.jsonl new file mode 100644 index 0000000000000000000000000000000000000000..df18941f9e01a80d08da0787a4cf8d6f2928f765 --- /dev/null +++ b/scores/restricted-newspapers_nrk_no.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1f8ff1c2c7acb1be49f38a6bf5f562d855ba572717539ab0631e93bf8b472794 +size 17139876 diff --git a/scores/slimpajama_en.jsonl b/scores/slimpajama_en.jsonl new file mode 100644 index 0000000000000000000000000000000000000000..fdcf47840ae8930a4954f0a2cd17ee75541c54c3 --- /dev/null +++ b/scores/slimpajama_en.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:31d8fc98d07df8709c57566f4cbbdd671d7349e476b1e38f899792343be172a8 +size 200105310 diff --git a/scores/starcoder_en.jsonl b/scores/starcoder_en.jsonl new file mode 100644 index 0000000000000000000000000000000000000000..64a0a8ac105b6384fe4512524ca217144c4d2b11 --- /dev/null +++ b/scores/starcoder_en.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d8d45519d5a7f04792f4de807449c75fbf6060f97b2d4aeee47efc04606e8464 +size 218627579 diff --git a/scores/wikipedia_download_da.jsonl b/scores/wikipedia_download_da.jsonl new file mode 100644 index 0000000000000000000000000000000000000000..ae1dca5f0f63191d99de9e699bdc49c504667fdb --- /dev/null +++ b/scores/wikipedia_download_da.jsonl @@ -0,0 +1,3 @@ 
+version https://git-lfs.github.com/spec/v1 +oid sha256:70c6a4599b756fd2a02f91d1fdfee1ea6b94698c5a123bd4052399fe0c6cfe07 +size 46169353 diff --git a/scores/wikipedia_download_en.jsonl b/scores/wikipedia_download_en.jsonl new file mode 100644 index 0000000000000000000000000000000000000000..71a74b1b0e5a283a012680cd8638b889198dd293 --- /dev/null +++ b/scores/wikipedia_download_en.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:77b03e8806883dc1823fcda96c6fe410acef583caec0141aed1d06e05f3565d1 +size 222350696 diff --git a/scores/wikipedia_download_is.jsonl b/scores/wikipedia_download_is.jsonl new file mode 100644 index 0000000000000000000000000000000000000000..cd76f597e7de790947661d173d36801d981f7b00 --- /dev/null +++ b/scores/wikipedia_download_is.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f0919ccb40b7e464e2846c32bca0e54a5195b8f82732bcfcd1844262cfa4515f +size 7854490 diff --git a/scores/wikipedia_download_nno.jsonl b/scores/wikipedia_download_nno.jsonl new file mode 100644 index 0000000000000000000000000000000000000000..06cefc58786f8e90e2e4a86540aa96e10449fbe1 --- /dev/null +++ b/scores/wikipedia_download_nno.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a25f4a41214f0a49c8368219738d22d967bb22716547225a2512e20fd13358dd +size 23833551 diff --git a/scores/wikipedia_download_nob.jsonl b/scores/wikipedia_download_nob.jsonl new file mode 100644 index 0000000000000000000000000000000000000000..6856aea86b09eb6a4d6215d598a732b966541187 --- /dev/null +++ b/scores/wikipedia_download_nob.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3205a3a3a6cb8db1388284638c04f48c6fa24f254455bf3b83a70e01f9090e7b +size 97647834 diff --git a/scores/wikipedia_download_sv.jsonl b/scores/wikipedia_download_sv.jsonl new file mode 100644 index 0000000000000000000000000000000000000000..d46482da0b35f5c67857c60371e22436ad943472 --- /dev/null +++ b/scores/wikipedia_download_sv.jsonl 
@@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3c7ae93b0e13d8cb26ae77bd706895c566301e25d42730719432f37d85ca2ade +size 201159029 diff --git a/spm/books.norm.sp.model b/spm/books.norm.sp.model new file mode 100644 index 0000000000000000000000000000000000000000..f0f08b54f28e3275ad5e0970201b824f524880eb --- /dev/null +++ b/spm/books.norm.sp.model @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fd8923187f09e14d2f81b76576effbc890469c624e4fb19293d51358d0b11cdb +size 1377561 diff --git a/spm/books.norm.sp.vocab b/spm/books.norm.sp.vocab new file mode 100644 index 0000000000000000000000000000000000000000..c1bb958cf39a0c75662e0a20b51654f6bdd222c3 --- /dev/null +++ b/spm/books.norm.sp.vocab @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4a72c1a45c42fc08499a99133a4ccff17a11337b72a774982d692f57478e5346 +size 1196708 diff --git a/spm/maalfrid.norm.sp.model b/spm/maalfrid.norm.sp.model new file mode 100644 index 0000000000000000000000000000000000000000..ccbf06c26d7ad69461b8d832663239dd46769ec7 --- /dev/null +++ b/spm/maalfrid.norm.sp.model @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1dc6b530ea9f4de8b53fa2dc19a546d13f644cf6df3fe435a3d24c6bbabe5adc +size 1550471 diff --git a/spm/maalfrid.norm.sp.vocab b/spm/maalfrid.norm.sp.vocab new file mode 100644 index 0000000000000000000000000000000000000000..37c1ec3992730668535a6bbc2e555ed6b882f1e6 --- /dev/null +++ b/spm/maalfrid.norm.sp.vocab @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d9d7372035c8cc51ab4a358787ffc9dc1275fc4eda4d856880f22e56a101bfd2 +size 1369413 diff --git a/spm/newspapers.norm.sp.model b/spm/newspapers.norm.sp.model new file mode 100644 index 0000000000000000000000000000000000000000..0454d3c17ad340993dd5adf85c8c2f4fc97e21df --- /dev/null +++ b/spm/newspapers.norm.sp.model @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid 
"""
PerplexitySubsampler class: define and execute subsampling on a dataset,
weighted by perplexity values
"""

from collections import namedtuple
from typing import Iterable, List, Optional, Tuple

import numpy as np
import scipy as sp
import scipy.optimize  # noqa: F401  (guarantees that ``sp.optimize`` is loaded)
from numpy.random import default_rng
from scipy.stats import norm, uniform

# Histogram container: per-bin counts, bin edges (one more than counts) and
# bin centers
Histo = namedtuple("HISTO", "counts edges centers")

rng = default_rng()


def histo_quantile(hcounts: np.ndarray, hedges: np.ndarray,
                   perc_values: Iterable[float]) -> List[float]:
    """
    Compute quantile values by using a histogram

    The value is linearly interpolated inside the bin in which the cumulative
    distribution reaches the requested probability.

     :param hcounts: histogram counts, one per bin
     :param hedges: histogram bin edges (one more element than `hcounts`)
     :param perc_values: the probabilities (in [0, 1]) whose quantiles are
       wanted
     :return: one quantile value per element of `perc_values`
     :raises ValueError: if counts and edges do not match in size, or if the
       histogram is empty
    """
    counts = np.asarray(hcounts, dtype=float)
    edges = np.asarray(hedges, dtype=float)
    if len(edges) != len(counts) + 1:
        raise ValueError("Histogram needs exactly one more edge than counts")
    total = counts.sum()
    if total <= 0:
        raise ValueError("Cannot compute quantiles of an empty histogram")

    # Cumulative distribution evaluated *at the bin edges*: the leading 0
    # makes cdf[i] line up with edges[i], so that the first bin can be
    # interpolated like any other one (without it, index -1 wraps around)
    cdf = np.concatenate(([0.0], np.cumsum(counts) / total))
    last = len(cdf) - 1

    out = []
    for p in perc_values:
        if p <= 0:
            out.append(float(edges[0]))
            continue
        # First edge at which the cumulative distribution reaches p
        idx = int(np.searchsorted(cdf, p, side="left"))
        if idx > last:
            # p is above the last cumulative value (p > 1 or rounding)
            out.append(float(edges[-1]))
            continue
        # Here cdf[idx-1] < p <= cdf[idx], so the denominator is positive
        frac = (p - cdf[idx-1]) / (cdf[idx] - cdf[idx-1])
        out.append(float(edges[idx-1] + (edges[idx] - edges[idx-1])*frac))
    return out


def _histo_inv_quantile(hedges: np.ndarray, hcounts: np.ndarray,
                        perp_value: float) -> float:
    """
    Using an histogram of values, estimate the cumulated count lying below
    a given value (linearly interpolating inside its bin)

    The result is expressed in the units of `hcounts`; dividing it by
    `hcounts.sum()` gives the quantile occupied by the value, i.e. the
    inverse function of histo_quantile()

     :param hedges: histogram bin edges
     :param hcounts: histogram counts (possibly fractional weights)
     :param perp_value: the value to locate
     :return: the cumulated count below `perp_value`; values outside the
       histogram range are clamped to 0 or to the total count
    """
    hcounts = np.asarray(hcounts)
    # Clamp out-of-range values instead of indexing outside the edges
    if perp_value <= hedges[0]:
        return 0.0
    if perp_value >= hedges[-1]:
        return float(hcounts.sum())
    # v-1 is the bin holding the value: hedges[v-1] <= perp_value < hedges[v]
    v = np.searchsorted(hedges, perp_value, side="right")
    frac = (perp_value - hedges[v-1]) / (hedges[v] - hedges[v-1])
    return hcounts[:v-1].sum() + hcounts[v-1]*frac


def subsample_frac(data: np.ndarray, frac: float) -> np.ndarray:
    """
    Subsample an array to a given fraction

    Each element is kept independently with probability `frac`, so the size
    of the result is only approximately `frac*len(data)`
    """
    data = np.asarray(data)
    return data[uniform.rvs(size=len(data)) < frac]


# -------------------------------------------------------------------------


class PerplexitySubsampler:

    def __init__(self, perp_data: Optional[np.ndarray] = None,
                 perp_histogram: Optional[Tuple[np.ndarray, np.ndarray]] = None,
                 hbins: int = 1000):
        """
         :param perp_data: a dataset of perplexity values
         :param perp_histogram: a histogram computed over a dataset of
           perplexity values, passed as a tuple (counts, edges)
         :param hbins: number of bins to use for the histogram approximation
           (only used if `perp_data` is passed)
         :raises ValueError: if no input is given, or `perp_data` is empty

        Either `perp_data` or `perp_histogram` must be passed; if both are,
        `perp_data` takes precedence
        """
        if perp_data is not None:

            perp_data = np.asarray(perp_data)
            if perp_data.size == 0:
                raise ValueError("Empty perplexity sample provided")
            # Get the P25 and P75 quartiles
            self.qr = np.quantile(perp_data, [0.25, 0.75])
            # Build an histogram of perplexities, spanning up to 10 times
            # the third quartile
            range_max = self.qr[1]*10
            counts, edges = np.histogram(perp_data, bins=hbins,
                                         range=[0, range_max])
            # Values above the range are accumulated into the last bin
            counts[-1] += np.count_nonzero(perp_data > range_max)

        elif perp_histogram is not None:

            counts = np.asarray(perp_histogram[0])
            edges = np.asarray(perp_histogram[1])
            # Quartiles are estimated from the histogram itself
            self.qr = np.asarray(histo_quantile(counts, edges, [0.25, 0.75]))

        else:
            raise ValueError("Neither sample nor histogram provided")

        self.histo = Histo(counts, edges, (edges[:-1] + edges[1:])/2)
        # Sampling parameters, filled in by set()
        self.mean = None
        self.sdev = None
        self.norm = None
        self.ratio = None


    def _require_set(self):
        """
        Ensure set() has been called, so that sampling parameters exist
        """
        if self.mean is None or self.sdev is None or self.norm is None:
            raise RuntimeError("set() must be called before subsampling")


    def _estimate(self, m: float, s: float,
                  ratio: float) -> Tuple[float, float]:
        """
        Estimate the quantiles to be retained in the 1st & 4th original
        quartiles

         :param m: mean of the gaussian weighting function
         :param s: standard deviation of the gaussian weighting function
         :param ratio: the desired sampling ratio
         :return: the fractions of the subsampled histogram lying below the
           original first quartile and above the original third quartile
        """
        # Compute the normalization factor
        gauss_weights = norm.pdf(self.histo.centers, loc=m, scale=s)
        hcounts = self.histo.counts
        adjusted_norm = (hcounts*gauss_weights).sum()/hcounts.sum()/ratio
        # Subsample the histogram
        hcounts_sub = hcounts*gauss_weights/adjusted_norm
        sub_size = hcounts_sub.sum()
        # Estimate the quantiles at Xa & Xb
        edges = self.histo.edges
        ra = _histo_inv_quantile(edges, hcounts_sub, self.qr[0])/sub_size
        rb = _histo_inv_quantile(edges, hcounts_sub, self.qr[1])/sub_size
        return ra, 1-rb


    def _error(self, point: np.ndarray, ratio: float,
               pa: float, pb: float) -> float:
        """
        Estimate the error in probability mass results

         :param point: candidate (mean, standard deviation) of the gaussian
         :param ratio: the desired sampling ratio
         :param pa: target probability mass in the first original quartile
         :param pb: target probability mass in the fourth original quartile
         :return: sum of absolute deviations from both targets; infinite for
           candidates that are invalid or not numerically evaluable
        """
        # A gaussian needs a positive scale: reject the candidate outright so
        # the optimizer is steered back, instead of being fed NaN
        if not point[1] > 0:
            return np.inf
        # Weights may underflow to zero for far-off candidates (0/0 below)
        with np.errstate(divide="ignore", invalid="ignore"):
            actual_pa, actual_pb = self._estimate(point[0], point[1], ratio)
            err = abs(pa-actual_pa) + abs(pb-actual_pb)
        return err if np.isfinite(err) else np.inf


    def set(self, ratio: float, pa: float, pb: float):
        """
        Compute the parameters needed to achieve a desired sampling ratio &
        probability distribution
         :param ratio: the desired sampling ratio
         :param pa: the probability mass to be left in the first original
           perplexity quartile
         :param pb: the probability mass to be left in the fourth original
           perplexity quartile
         :raises ValueError: if the arguments are out of range, or the
           quartiles do not allow an initial estimation
        """
        if not ratio > 0:
            raise ValueError("Sampling ratio must be positive")
        if not (0 < pa < 1 and 0 < pb < 1 and pa + pb < 1):
            raise ValueError("pa & pb must be in (0, 1) and sum less than 1")
        # Obtain the initial parameters for the gaussian weighting function
        # (assuming uniform data)
        sdev = (self.qr[0] - self.qr[1]) / (norm.ppf(pa) - norm.ppf(1-pb))
        if not (np.isfinite(sdev) and sdev > 0):
            raise ValueError("Degenerate quartiles: cannot estimate weights")
        mean = self.qr[0] - norm.ppf(pa)*sdev
        # Optimize for the real data distribution
        initial = np.array([mean, sdev])
        result = sp.optimize.minimize(self._error, initial,
                                      args=(ratio, pa, pb),
                                      method='nelder-mead',
                                      options={'xatol': 1e-8, 'disp': False})
        self.mean, self.sdev = result.x
        # Keep the ratio: subsample_piecewise() uses it as default
        self.ratio = ratio
        # Now that we have the final parameters, compute the weighting
        # function over the histogram values
        gauss_weights = norm.pdf(self.histo.centers, loc=self.mean,
                                 scale=self.sdev)
        # Find the normalization needed to achieve the desired sampling ratio
        counts = self.histo.counts
        self.norm = (counts*gauss_weights).sum()/counts.sum()/ratio


    def subsample(self, data: np.ndarray) -> np.ndarray:
        """
        Subsample a dataset according to the defined conditions
        Note: set() must have been called previously

         :param data: perplexity values to subsample
         :return: the retained values
         :raises RuntimeError: if set() has not been called
        """
        self._require_set()
        data = np.asarray(data)
        # Create the gaussian weight for each data point
        p = norm.pdf(data, loc=self.mean, scale=self.sdev)/self.norm
        # Subsample data with probability according to the weight
        return data[uniform.rvs(size=len(p)) < p]


    def retain(self, perp: float) -> bool:
        """
        Decide if a sample is to be retained based on its perplexity value
        Note: set() must have been called previously

         :param perp: perplexity value of the sample
         :return: True if the sample is to be kept
         :raises RuntimeError: if set() has not been called
        """
        self._require_set()
        p = norm.pdf(perp, loc=self.mean, scale=self.sdev)/self.norm
        return bool(rng.uniform() < p)


    def subsample_piecewise(self, data: np.ndarray,
                            pa: float, pb: float,
                            ratio: Optional[float] = None) -> np.ndarray:
        """
        Create a subsample by directly subsampling each region (below the
        first quartile, between quartiles, above the third quartile)

         :param data: perplexity values to subsample
         :param pa: the probability mass to be left in the first original
           perplexity quartile
         :param pb: the probability mass to be left in the fourth original
           perplexity quartile
         :param ratio: the desired sampling ratio; defaults to the one
           passed to set()
         :return: the retained values, region by region
         :raises ValueError: if no ratio is available
        """
        if ratio is None:
            ratio = self.ratio
        if ratio is None:
            raise ValueError("No sampling ratio: pass it or call set() first")
        data = np.asarray(data)
        qr = self.qr
        # Each region holds 25% / 50% / 25% of the original mass, hence the
        # rescaling of the target masses into per-region sampling fractions
        data1 = subsample_frac(data[data < qr[0]], pa/0.25*ratio)
        data2 = subsample_frac(data[(data >= qr[0]) & (data <= qr[1])],
                               (1-pa-pb)/0.5*ratio)
        data3 = subsample_frac(data[data > qr[1]], pb/0.25*ratio)
        return np.hstack([data1, data2, data3])


    def verify(self, data: np.ndarray, data_sub: np.ndarray) -> Tuple:
        """
        Check the statistics of a sample

         :param data: the full dataset
         :param data_sub: the subsample extracted from it
         :return: a tuple (sampling ratio, fraction of the subsample below
           the first original quartile, fraction above the third one);
           quantities that are undefined for empty inputs are NaN
        """
        data_sub = np.asarray(data_sub)
        n_full, n_sub = len(data), len(data_sub)
        ratio = n_sub/n_full if n_full else float("nan")
        if n_sub == 0:
            return ratio, float("nan"), float("nan")
        ra = np.count_nonzero(data_sub < self.qr[0]) / n_sub
        rb = np.count_nonzero(data_sub > self.qr[1]) / n_sub
        return ratio, ra, rb



def check_results(s: PerplexitySubsampler,
                  data_full: np.ndarray, data_sub: np.ndarray):
    """
    Compute and print out the results for a subsample

     :param s: the subsampler holding the reference quartiles
     :param data_full: the full dataset
     :param data_sub: the subsample extracted from it
    """
    r, ra, rb = s.verify(data_full, data_sub)
    print("Sampling ratio:", r)
    print("Probability mass below Pa:", ra)
    print("Probability mass above Pb:", rb)
/nfsmounts/datastore/ncc_corpus/extract/booktexts.txt books.norm.txt --cutoff 10000000 + +~/bin/spm_train --input newspapers.norm.txt --vocab_size 64000 --character_coverage=0.99 --input_sentence_size 1000000 --shuffle_input_sentence true --model_prefix newspapers.norm.sp --max_sentence_length 1000000000 --num_threads 64 +~/bin/spm_train --input books.norm.txt --vocab_size 64000 --character_coverage=0.99 --input_sentence_size 1000000 --shuffle_input_sentence true --model_prefix books.norm.sp --max_sentence_length 1000000000 --num_threads 64 + +~/bin/spm_encode --model newspapers.norm.sp.model --output_format=piece newspapers.norm.txt > newspapers.norm.sp.txt +~/bin/spm_encode --model books.norm.sp.model --output_format=piece books.norm.txt > books.norm.sp.txt + +~/bin/lmplz -o 5 -S 75% -T tmp --vocab_estimate 64000 --discount_fallback --skip_symbols < newspapers.norm.txt > newspapers.norm.arpa +~/bin/lmplz -o 5 -S 75% -T tmp --vocab_estimate 64000 --discount_fallback < newspapers.norm.sp.txt > newspapers.norm.sp.arpa + +~/bin/lmplz -o 5 -S 75% -T tmp --vocab_estimate 64000 --discount_fallback --skip_symbols < books.norm.txt > books.norm.arpa +~/bin/lmplz -o 5 -S 75% -T tmp --vocab_estimate 64000 --discount_fallback < books.norm.sp.txt > books.norm.sp.arpa + +~/bin/build_binary newspapers.norm.arpa > newspapers.norm.arpa.bin +~/bin/build_binary newspapers.norm.sp.arpa > newspapers.norm.sp.arpa.bin + +~/bin/build_binary books.norm.arpa > books.norm.arpa.bin +~/bin/build_binary books.norm.sp.arpa > books.norm.sp.arpa.bin