diff --git a/flax_model.msgpack b/flax_model.msgpack index d7b3172a705c066aaf5eb63e89a07ecda3789e09..06e8410245f53b24f15363c853c91bd1f34ccfbe 100644 --- a/flax_model.msgpack +++ b/flax_model.msgpack @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:393c37966461709fe51a3b3f84befb7fa7e5030025856d171308efd40dbbc7da +oid sha256:3362fadd5306539e1775600df59349acf0fd260751129abaa45b997091d089ce size 249750019 diff --git a/mc4/mc4.py b/mc4/mc4.py index 9b5f7e6c1014f254c202cdd6c161bc3e598f245a..26aea80074696ec4231c992b4d71d7b63da198d1 100644 --- a/mc4/mc4.py +++ b/mc4/mc4.py @@ -1,11 +1,11 @@ -"""mC4 dataset based on Common Crawl.""" +"""Perplexity Sampled mC4 dataset based on Common Crawl.""" import gzip import json import datasets -import kenlm +import kenlm # pip install https://github.com/kpu/kenlm/archive/master.zip import numpy as np from numpy.random import default_rng @@ -289,6 +289,7 @@ class Mc4(datasets.GeneratorBasedBuilder): self.sampling_factor = kwargs.pop("sampling_factor", None) self.boundaries = kwargs.pop("boundaries", None) self.seed = kwargs.pop("seed", None) + self.kwargs = kwargs if self.sampling_method: if self.seed is not None: self.rng = default_rng(self.seed) @@ -316,7 +317,7 @@ class Mc4(datasets.GeneratorBasedBuilder): doc_length += length return 10.0 ** (-doc_log_score / doc_length) - def _should_keep_doc_step(self, doc, factor=1.5e5, boundaries=None): + def _should_keep_doc_step(self, doc, factor=1.5e5, boundaries=None, **kwargs): perplexity = self.get_perplexity(doc) if boundaries is None: boundaries = [536394.99320948, 662247.50212365, 919250.87225178] @@ -331,17 +332,18 @@ class Mc4(datasets.GeneratorBasedBuilder): probability = factor / quartile_range return self.rng.uniform() < probability - def _should_keep_doc_gaussian(self, doc, factor=0.78, boundaries=None): + def _should_keep_doc_gaussian(self, doc, factor=0.78, boundaries=None, **kwargs): + width = kwargs.get("width", 9 / 2) # width (spread) of the exponential curve perplexity = self.get_perplexity(doc) if boundaries is not None: m = boundaries[1] else: m = 662247.50212365 - exponential = np.exp(-9/2 * ((perplexity - m) / m) ** 2) + exponential = np.exp((-1 / width) * ((perplexity - m) / m) ** 2) weighted_perplexity = factor * exponential return self.rng.uniform() < weighted_perplexity - def _should_keep_doc_random(self, doc, factor=None, boundaries=None): + def _should_keep_doc_random(self, doc, factor=None, boundaries=None, **kwargs): if factor is None: factor = 0.5 return self.rng.uniform() <= factor @@ -374,13 +376,13 @@ class Mc4(datasets.GeneratorBasedBuilder): for lang in self.config.languages for index in range(_N_SHARDS_PER_SPLIT[lang][split]) ] - if "train" in self.data_files: + if self.data_files and "train" in self.data_files: train_downloaded_files = self.data_files["train"] if not isinstance(train_downloaded_files, (tuple, list)): train_downloaded_files = [train_downloaded_files] else: train_downloaded_files = dl_manager.download(data_urls["train"]) - if "validation" in self.data_files: + if self.data_files and "validation" in self.data_files: validation_downloaded_files = self.data_files["validation"] if not isinstance(validation_downloaded_files, (tuple, list)): validation_downloaded_files = [validation_downloaded_files] @@ -415,7 +417,8 @@ class Mc4(datasets.GeneratorBasedBuilder): if self.should_keep_doc( example["text"], factor=self.sampling_factor, - boundaries=self.boundaries): + boundaries=self.boundaries, + **self.kwargs): yield id_, example id_ += 1 else: diff 
--git a/outputs/checkpoints/checkpoint-140001/data_collator.joblib b/outputs/checkpoints/checkpoint-140001/data_collator.joblib deleted file mode 100644 index 3c029096ac0457bdb70428bf5341d77100a030c6..0000000000000000000000000000000000000000 --- a/outputs/checkpoints/checkpoint-140001/data_collator.joblib +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:e02a6e9cfa63cb321cac9402efd29841b652999fcbf787800ae050e747b161ee -size 1471394 diff --git a/outputs/checkpoints/checkpoint-140001/flax_model.msgpack b/outputs/checkpoints/checkpoint-140001/flax_model.msgpack deleted file mode 100644 index 2e86ead3d38c0a130e61a7eacc25b746df8b284d..0000000000000000000000000000000000000000 --- a/outputs/checkpoints/checkpoint-140001/flax_model.msgpack +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:bb3b6443b0b4e0fd6b95f7409525ddde51fb73dd99318041f2fecda9f547f5a6 -size 249750019 diff --git a/outputs/checkpoints/checkpoint-140001/optimizer_state.msgpack b/outputs/checkpoints/checkpoint-140001/optimizer_state.msgpack deleted file mode 100644 index 750c3998f9fdaa3e1372b62590d0f57c507f83d8..0000000000000000000000000000000000000000 --- a/outputs/checkpoints/checkpoint-140001/optimizer_state.msgpack +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:73ce4d1287008fdfac801ca7df44a0debe3e41f901970f3132f0cd49d2ad6bd0 -size 499500278 diff --git a/outputs/checkpoints/checkpoint-140001/training_args.joblib b/outputs/checkpoints/checkpoint-140001/training_args.joblib deleted file mode 100644 index 9ec9d51e87d8b26b07ba1d4300d451a9dfc4a8d8..0000000000000000000000000000000000000000 --- a/outputs/checkpoints/checkpoint-140001/training_args.joblib +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:8bc14fe16573d318dd510c7cfb42ebb7cc87b4dcf77e99247a2d1605cffd772b -size 1876 diff --git a/outputs/checkpoints/checkpoint-140001/training_state.json b/outputs/checkpoints/checkpoint-140001/training_state.json deleted file mode 100644 index 9f1d639ca3e47df4f853733cf007282ba99b130b..0000000000000000000000000000000000000000 --- a/outputs/checkpoints/checkpoint-140001/training_state.json +++ /dev/null @@ -1 +0,0 @@ -{"step": 140001} \ No newline at end of file diff --git a/outputs/checkpoints/checkpoint-150001/data_collator.joblib b/outputs/checkpoints/checkpoint-150001/data_collator.joblib deleted file mode 100644 index 3c029096ac0457bdb70428bf5341d77100a030c6..0000000000000000000000000000000000000000 --- a/outputs/checkpoints/checkpoint-150001/data_collator.joblib +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:e02a6e9cfa63cb321cac9402efd29841b652999fcbf787800ae050e747b161ee -size 1471394 diff --git a/outputs/checkpoints/checkpoint-150001/flax_model.msgpack b/outputs/checkpoints/checkpoint-150001/flax_model.msgpack deleted file mode 100644 index 1d6d488c4254073cee3e6e5dc14aa974bcb0892f..0000000000000000000000000000000000000000 --- a/outputs/checkpoints/checkpoint-150001/flax_model.msgpack +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:d9f2a38ac6c111d01809dd28ae9078aab932064126a7de753ce0d88bd60421e4 -size 249750019 diff --git a/outputs/checkpoints/checkpoint-150001/optimizer_state.msgpack b/outputs/checkpoints/checkpoint-150001/optimizer_state.msgpack deleted file mode 100644 index 793caa05259c0c7cd35a254079a0ee36b8118cb5..0000000000000000000000000000000000000000 --- 
a/outputs/checkpoints/checkpoint-150001/optimizer_state.msgpack +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:84f53f9b574ccfb97696f637d71903b9762ef2718c656bea201e5aeb9078c328 -size 499500278 diff --git a/outputs/checkpoints/checkpoint-150001/training_args.joblib b/outputs/checkpoints/checkpoint-150001/training_args.joblib deleted file mode 100644 index 9ec9d51e87d8b26b07ba1d4300d451a9dfc4a8d8..0000000000000000000000000000000000000000 --- a/outputs/checkpoints/checkpoint-150001/training_args.joblib +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:8bc14fe16573d318dd510c7cfb42ebb7cc87b4dcf77e99247a2d1605cffd772b -size 1876 diff --git a/outputs/checkpoints/checkpoint-150001/training_state.json b/outputs/checkpoints/checkpoint-150001/training_state.json deleted file mode 100644 index 9aa4ddd23663437dd3e5b6783fbcf664680bd421..0000000000000000000000000000000000000000 --- a/outputs/checkpoints/checkpoint-150001/training_state.json +++ /dev/null @@ -1 +0,0 @@ -{"step": 150001} \ No newline at end of file diff --git a/outputs/checkpoints/checkpoint-160001/data_collator.joblib b/outputs/checkpoints/checkpoint-160001/data_collator.joblib deleted file mode 100644 index 3c029096ac0457bdb70428bf5341d77100a030c6..0000000000000000000000000000000000000000 --- a/outputs/checkpoints/checkpoint-160001/data_collator.joblib +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:e02a6e9cfa63cb321cac9402efd29841b652999fcbf787800ae050e747b161ee -size 1471394 diff --git a/outputs/checkpoints/checkpoint-160001/flax_model.msgpack b/outputs/checkpoints/checkpoint-160001/flax_model.msgpack deleted file mode 100644 index 80c06d3cb54c1b4843e851835929a37bcb21ded6..0000000000000000000000000000000000000000 --- a/outputs/checkpoints/checkpoint-160001/flax_model.msgpack +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:b86d26169d8fb7bb58ae7fecd67ca557a0affc93bf2d5b5947af0070ee894ab9 -size 249750019 diff --git a/outputs/checkpoints/checkpoint-160001/optimizer_state.msgpack b/outputs/checkpoints/checkpoint-160001/optimizer_state.msgpack deleted file mode 100644 index c1cb0ea046ed5f5c29419eb313867c37c113f140..0000000000000000000000000000000000000000 --- a/outputs/checkpoints/checkpoint-160001/optimizer_state.msgpack +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:ea3a8f65ea9c3c6c3606f1167c4e54049784fa8b2a5ee3f4936563ecd4f811b6 -size 499500278 diff --git a/outputs/checkpoints/checkpoint-160001/training_args.joblib b/outputs/checkpoints/checkpoint-160001/training_args.joblib deleted file mode 100644 index 9ec9d51e87d8b26b07ba1d4300d451a9dfc4a8d8..0000000000000000000000000000000000000000 --- a/outputs/checkpoints/checkpoint-160001/training_args.joblib +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:8bc14fe16573d318dd510c7cfb42ebb7cc87b4dcf77e99247a2d1605cffd772b -size 1876 diff --git a/outputs/checkpoints/checkpoint-160001/training_state.json b/outputs/checkpoints/checkpoint-160001/training_state.json deleted file mode 100644 index 96dad52987058bcfff4a97ce4dc762b4a8b6efa6..0000000000000000000000000000000000000000 --- a/outputs/checkpoints/checkpoint-160001/training_state.json +++ /dev/null @@ -1 +0,0 @@ -{"step": 160001} \ No newline at end of file diff --git a/outputs/checkpoints/checkpoint-170001/data_collator.joblib b/outputs/checkpoints/checkpoint-170001/data_collator.joblib deleted file mode 100644 index 
3c029096ac0457bdb70428bf5341d77100a030c6..0000000000000000000000000000000000000000 --- a/outputs/checkpoints/checkpoint-170001/data_collator.joblib +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:e02a6e9cfa63cb321cac9402efd29841b652999fcbf787800ae050e747b161ee -size 1471394 diff --git a/outputs/checkpoints/checkpoint-170001/flax_model.msgpack b/outputs/checkpoints/checkpoint-170001/flax_model.msgpack deleted file mode 100644 index 9f79892d6f273785dac02b2912ba6efa687ab56d..0000000000000000000000000000000000000000 --- a/outputs/checkpoints/checkpoint-170001/flax_model.msgpack +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:c40291527e2cf6e418cf78bb9cd4eec53ac716230987ad7a0a447bf0ce041d4c -size 249750019 diff --git a/outputs/checkpoints/checkpoint-170001/optimizer_state.msgpack b/outputs/checkpoints/checkpoint-170001/optimizer_state.msgpack deleted file mode 100644 index 5d31b7814c1669cf3690031f6da7a278a449ec54..0000000000000000000000000000000000000000 --- a/outputs/checkpoints/checkpoint-170001/optimizer_state.msgpack +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:90dbe4fe7d7694dd86d13e9b075953620aa4dabb4fdc2023b6ede17aa720848e -size 499500278 diff --git a/outputs/checkpoints/checkpoint-170001/training_args.joblib b/outputs/checkpoints/checkpoint-170001/training_args.joblib deleted file mode 100644 index 9ec9d51e87d8b26b07ba1d4300d451a9dfc4a8d8..0000000000000000000000000000000000000000 --- a/outputs/checkpoints/checkpoint-170001/training_args.joblib +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:8bc14fe16573d318dd510c7cfb42ebb7cc87b4dcf77e99247a2d1605cffd772b -size 1876 diff --git a/outputs/checkpoints/checkpoint-170001/training_state.json b/outputs/checkpoints/checkpoint-170001/training_state.json deleted file mode 100644 index 52194acb9d196ce7a9f7a284390fb10f7174373c..0000000000000000000000000000000000000000 --- a/outputs/checkpoints/checkpoint-170001/training_state.json +++ /dev/null @@ -1 +0,0 @@ -{"step": 170001} \ No newline at end of file diff --git a/outputs/checkpoints/checkpoint-180001/data_collator.joblib b/outputs/checkpoints/checkpoint-180001/data_collator.joblib deleted file mode 100644 index 3c029096ac0457bdb70428bf5341d77100a030c6..0000000000000000000000000000000000000000 --- a/outputs/checkpoints/checkpoint-180001/data_collator.joblib +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:e02a6e9cfa63cb321cac9402efd29841b652999fcbf787800ae050e747b161ee -size 1471394 diff --git a/outputs/checkpoints/checkpoint-180001/flax_model.msgpack b/outputs/checkpoints/checkpoint-180001/flax_model.msgpack deleted file mode 100644 index d7b3172a705c066aaf5eb63e89a07ecda3789e09..0000000000000000000000000000000000000000 --- a/outputs/checkpoints/checkpoint-180001/flax_model.msgpack +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:393c37966461709fe51a3b3f84befb7fa7e5030025856d171308efd40dbbc7da -size 249750019 diff --git a/outputs/checkpoints/checkpoint-180001/optimizer_state.msgpack b/outputs/checkpoints/checkpoint-180001/optimizer_state.msgpack deleted file mode 100644 index 9ecf7634d332666cefc80f7f45d1d7c2245deb9a..0000000000000000000000000000000000000000 --- a/outputs/checkpoints/checkpoint-180001/optimizer_state.msgpack +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:3a33cad417a7e78eaafc1c041f93fd54ad9f63869d01e1351bac4abcd58e4eeb 
-size 499500278 diff --git a/outputs/checkpoints/checkpoint-180001/training_args.joblib b/outputs/checkpoints/checkpoint-180001/training_args.joblib deleted file mode 100644 index 9ec9d51e87d8b26b07ba1d4300d451a9dfc4a8d8..0000000000000000000000000000000000000000 --- a/outputs/checkpoints/checkpoint-180001/training_args.joblib +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:8bc14fe16573d318dd510c7cfb42ebb7cc87b4dcf77e99247a2d1605cffd772b -size 1876 diff --git a/outputs/checkpoints/checkpoint-180001/training_state.json b/outputs/checkpoints/checkpoint-180001/training_state.json deleted file mode 100644 index 75348a914ed8eb10df17cc2276ac75eb02bd2c8c..0000000000000000000000000000000000000000 --- a/outputs/checkpoints/checkpoint-180001/training_state.json +++ /dev/null @@ -1 +0,0 @@ -{"step": 180001} \ No newline at end of file diff --git a/outputs/checkpoints/checkpoint-140001/config.json b/outputs/checkpoints/checkpoint-182000/config.json similarity index 100% rename from outputs/checkpoints/checkpoint-140001/config.json rename to outputs/checkpoints/checkpoint-182000/config.json diff --git a/outputs/checkpoints/checkpoint-182000/data_collator.joblib b/outputs/checkpoints/checkpoint-182000/data_collator.joblib new file mode 100644 index 0000000000000000000000000000000000000000..4309dae74a6f7a811944b2d931b89db770fd4b70 --- /dev/null +++ b/outputs/checkpoints/checkpoint-182000/data_collator.joblib @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0321b1a9629e1be122045cd72470365a63c8496fec109fdeec34827f01ffbb9e +size 1471424 diff --git a/outputs/checkpoints/checkpoint-182000/flax_model.msgpack b/outputs/checkpoints/checkpoint-182000/flax_model.msgpack new file mode 100644 index 0000000000000000000000000000000000000000..78dd434977ab6d9ebf8bb6a09d2c8f80c2e692ed --- /dev/null +++ b/outputs/checkpoints/checkpoint-182000/flax_model.msgpack @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:73b385f39d585a15aeee595103523252872a4277da8c21c9bc9bbfd0232a72c9 +size 249750019 diff --git a/outputs/checkpoints/checkpoint-182000/optimizer_state.msgpack b/outputs/checkpoints/checkpoint-182000/optimizer_state.msgpack new file mode 100644 index 0000000000000000000000000000000000000000..8d43cc93757956823beeaf27893557ac3e987185 --- /dev/null +++ b/outputs/checkpoints/checkpoint-182000/optimizer_state.msgpack @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a6c074f7bc1330f514d95eb6277f0232abc85b0ec2f8c1c5972fd46af01985d4 +size 499500278 diff --git a/outputs/checkpoints/checkpoint-182000/training_args.joblib b/outputs/checkpoints/checkpoint-182000/training_args.joblib new file mode 100644 index 0000000000000000000000000000000000000000..5855f72bcc4215a2b17083c6e9cce26965c419b6 --- /dev/null +++ b/outputs/checkpoints/checkpoint-182000/training_args.joblib @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:79196d5f797a7527e60287e48f716fd4626c5d57186661d6a214a2027998c86d +size 1873 diff --git a/outputs/checkpoints/checkpoint-182000/training_state.json b/outputs/checkpoints/checkpoint-182000/training_state.json new file mode 100644 index 0000000000000000000000000000000000000000..82702578accb9b599472b105f896a38b0edba959 --- /dev/null +++ b/outputs/checkpoints/checkpoint-182000/training_state.json @@ -0,0 +1 @@ +{"step": 182001} \ No newline at end of file diff --git a/outputs/checkpoints/checkpoint-150001/config.json b/outputs/checkpoints/checkpoint-183000/config.json similarity index 100% rename 
from outputs/checkpoints/checkpoint-150001/config.json rename to outputs/checkpoints/checkpoint-183000/config.json diff --git a/outputs/checkpoints/checkpoint-183000/data_collator.joblib b/outputs/checkpoints/checkpoint-183000/data_collator.joblib new file mode 100644 index 0000000000000000000000000000000000000000..4309dae74a6f7a811944b2d931b89db770fd4b70 --- /dev/null +++ b/outputs/checkpoints/checkpoint-183000/data_collator.joblib @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0321b1a9629e1be122045cd72470365a63c8496fec109fdeec34827f01ffbb9e +size 1471424 diff --git a/outputs/checkpoints/checkpoint-183000/flax_model.msgpack b/outputs/checkpoints/checkpoint-183000/flax_model.msgpack new file mode 100644 index 0000000000000000000000000000000000000000..c51d44653664e633033cef88780ffb07fc6508af --- /dev/null +++ b/outputs/checkpoints/checkpoint-183000/flax_model.msgpack @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0059f1c021d4be233f41fc29384c42e8844d69c44e1d31806620a38d76d6aa00 +size 249750019 diff --git a/outputs/checkpoints/checkpoint-183000/optimizer_state.msgpack b/outputs/checkpoints/checkpoint-183000/optimizer_state.msgpack new file mode 100644 index 0000000000000000000000000000000000000000..53da142a30034e221bbf8dd4c6a4822a4f6f1e32 --- /dev/null +++ b/outputs/checkpoints/checkpoint-183000/optimizer_state.msgpack @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e99296d552d436c722721d611be2ed7cf4862c4b6a40abbc3464f7a75ca25c58 +size 499500278 diff --git a/outputs/checkpoints/checkpoint-183000/training_args.joblib b/outputs/checkpoints/checkpoint-183000/training_args.joblib new file mode 100644 index 0000000000000000000000000000000000000000..5855f72bcc4215a2b17083c6e9cce26965c419b6 --- /dev/null +++ b/outputs/checkpoints/checkpoint-183000/training_args.joblib @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:79196d5f797a7527e60287e48f716fd4626c5d57186661d6a214a2027998c86d +size 1873 diff --git a/outputs/checkpoints/checkpoint-183000/training_state.json b/outputs/checkpoints/checkpoint-183000/training_state.json new file mode 100644 index 0000000000000000000000000000000000000000..1a76b28408c6bbfc2fed1318d92193dd4472494b --- /dev/null +++ b/outputs/checkpoints/checkpoint-183000/training_state.json @@ -0,0 +1 @@ +{"step": 183001} \ No newline at end of file diff --git a/outputs/checkpoints/checkpoint-160001/config.json b/outputs/checkpoints/checkpoint-184000/config.json similarity index 100% rename from outputs/checkpoints/checkpoint-160001/config.json rename to outputs/checkpoints/checkpoint-184000/config.json diff --git a/outputs/checkpoints/checkpoint-184000/data_collator.joblib b/outputs/checkpoints/checkpoint-184000/data_collator.joblib new file mode 100644 index 0000000000000000000000000000000000000000..4309dae74a6f7a811944b2d931b89db770fd4b70 --- /dev/null +++ b/outputs/checkpoints/checkpoint-184000/data_collator.joblib @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0321b1a9629e1be122045cd72470365a63c8496fec109fdeec34827f01ffbb9e +size 1471424 diff --git a/outputs/checkpoints/checkpoint-184000/flax_model.msgpack b/outputs/checkpoints/checkpoint-184000/flax_model.msgpack new file mode 100644 index 0000000000000000000000000000000000000000..8e7ba961f5adf5e4af041d5828a2edfdf87bac3d --- /dev/null +++ b/outputs/checkpoints/checkpoint-184000/flax_model.msgpack @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid 
sha256:5063b47bd0228e701c1cbc12f4ce305c577868d34589323f88eb3ca2057b322d +size 249750019 diff --git a/outputs/checkpoints/checkpoint-184000/optimizer_state.msgpack b/outputs/checkpoints/checkpoint-184000/optimizer_state.msgpack new file mode 100644 index 0000000000000000000000000000000000000000..c7b42c6d68d88ac70c65eb7afb435c967a45b88b --- /dev/null +++ b/outputs/checkpoints/checkpoint-184000/optimizer_state.msgpack @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:03b1a223244cdeec59f4a501edcc3e29d75fa5a7467d514c3c30d9a2fcb49401 +size 499500278 diff --git a/outputs/checkpoints/checkpoint-184000/training_args.joblib b/outputs/checkpoints/checkpoint-184000/training_args.joblib new file mode 100644 index 0000000000000000000000000000000000000000..5855f72bcc4215a2b17083c6e9cce26965c419b6 --- /dev/null +++ b/outputs/checkpoints/checkpoint-184000/training_args.joblib @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:79196d5f797a7527e60287e48f716fd4626c5d57186661d6a214a2027998c86d +size 1873 diff --git a/outputs/checkpoints/checkpoint-184000/training_state.json b/outputs/checkpoints/checkpoint-184000/training_state.json new file mode 100644 index 0000000000000000000000000000000000000000..84c6e6ebd4ce4832410cb709b7060ba70fab1eff --- /dev/null +++ b/outputs/checkpoints/checkpoint-184000/training_state.json @@ -0,0 +1 @@ +{"step": 184001} \ No newline at end of file diff --git a/outputs/checkpoints/checkpoint-170001/config.json b/outputs/checkpoints/checkpoint-185000/config.json similarity index 100% rename from outputs/checkpoints/checkpoint-170001/config.json rename to outputs/checkpoints/checkpoint-185000/config.json diff --git a/outputs/checkpoints/checkpoint-185000/data_collator.joblib b/outputs/checkpoints/checkpoint-185000/data_collator.joblib new file mode 100644 index 0000000000000000000000000000000000000000..4309dae74a6f7a811944b2d931b89db770fd4b70 --- /dev/null +++ b/outputs/checkpoints/checkpoint-185000/data_collator.joblib @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0321b1a9629e1be122045cd72470365a63c8496fec109fdeec34827f01ffbb9e +size 1471424 diff --git a/outputs/checkpoints/checkpoint-185000/flax_model.msgpack b/outputs/checkpoints/checkpoint-185000/flax_model.msgpack new file mode 100644 index 0000000000000000000000000000000000000000..7d8f52acb5fa5b1aa5bfe2e376cd0d41cd7c7dc6 --- /dev/null +++ b/outputs/checkpoints/checkpoint-185000/flax_model.msgpack @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:53e61bbf07baa76be1c67dcb57fd59f87d0171d7ce5ebdff16e054e5e4c547fe +size 249750019 diff --git a/outputs/checkpoints/checkpoint-185000/optimizer_state.msgpack b/outputs/checkpoints/checkpoint-185000/optimizer_state.msgpack new file mode 100644 index 0000000000000000000000000000000000000000..e09acaa3b8bf177f102d470e65d1e959aacdddec --- /dev/null +++ b/outputs/checkpoints/checkpoint-185000/optimizer_state.msgpack @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1a2a3d946f859ac106a7ede9530e200364d0850a98d8c958672d441c57ab89bd +size 499500278 diff --git a/outputs/checkpoints/checkpoint-185000/training_args.joblib b/outputs/checkpoints/checkpoint-185000/training_args.joblib new file mode 100644 index 0000000000000000000000000000000000000000..5855f72bcc4215a2b17083c6e9cce26965c419b6 --- /dev/null +++ b/outputs/checkpoints/checkpoint-185000/training_args.joblib @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid 
sha256:79196d5f797a7527e60287e48f716fd4626c5d57186661d6a214a2027998c86d +size 1873 diff --git a/outputs/checkpoints/checkpoint-185000/training_state.json b/outputs/checkpoints/checkpoint-185000/training_state.json new file mode 100644 index 0000000000000000000000000000000000000000..dc73d846c65476bcde1bab8398f46158fd9b0ec9 --- /dev/null +++ b/outputs/checkpoints/checkpoint-185000/training_state.json @@ -0,0 +1 @@ +{"step": 185001} \ No newline at end of file diff --git a/outputs/checkpoints/checkpoint-180001/config.json b/outputs/checkpoints/checkpoint-186000/config.json similarity index 100% rename from outputs/checkpoints/checkpoint-180001/config.json rename to outputs/checkpoints/checkpoint-186000/config.json diff --git a/outputs/checkpoints/checkpoint-186000/data_collator.joblib b/outputs/checkpoints/checkpoint-186000/data_collator.joblib new file mode 100644 index 0000000000000000000000000000000000000000..4309dae74a6f7a811944b2d931b89db770fd4b70 --- /dev/null +++ b/outputs/checkpoints/checkpoint-186000/data_collator.joblib @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0321b1a9629e1be122045cd72470365a63c8496fec109fdeec34827f01ffbb9e +size 1471424 diff --git a/outputs/checkpoints/checkpoint-186000/flax_model.msgpack b/outputs/checkpoints/checkpoint-186000/flax_model.msgpack new file mode 100644 index 0000000000000000000000000000000000000000..06e8410245f53b24f15363c853c91bd1f34ccfbe --- /dev/null +++ b/outputs/checkpoints/checkpoint-186000/flax_model.msgpack @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3362fadd5306539e1775600df59349acf0fd260751129abaa45b997091d089ce +size 249750019 diff --git a/outputs/checkpoints/checkpoint-186000/optimizer_state.msgpack b/outputs/checkpoints/checkpoint-186000/optimizer_state.msgpack new file mode 100644 index 0000000000000000000000000000000000000000..7fe6fd30d9b9dde6a3d2b1e25a4c7b6b365081ba --- /dev/null +++ b/outputs/checkpoints/checkpoint-186000/optimizer_state.msgpack @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e08084e7879de2059bc4fd38ff876b0bd1e987a646397956dd4bef19bb2dd1bf +size 499500278 diff --git a/outputs/checkpoints/checkpoint-186000/training_args.joblib b/outputs/checkpoints/checkpoint-186000/training_args.joblib new file mode 100644 index 0000000000000000000000000000000000000000..5855f72bcc4215a2b17083c6e9cce26965c419b6 --- /dev/null +++ b/outputs/checkpoints/checkpoint-186000/training_args.joblib @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:79196d5f797a7527e60287e48f716fd4626c5d57186661d6a214a2027998c86d +size 1873 diff --git a/outputs/checkpoints/checkpoint-186000/training_state.json b/outputs/checkpoints/checkpoint-186000/training_state.json new file mode 100644 index 0000000000000000000000000000000000000000..4323ac7250fe0f27af5929414fc764d2882db84f --- /dev/null +++ b/outputs/checkpoints/checkpoint-186000/training_state.json @@ -0,0 +1 @@ +{"step": 186001} \ No newline at end of file diff --git a/outputs/data_collator.joblib b/outputs/data_collator.joblib index 3c029096ac0457bdb70428bf5341d77100a030c6..4309dae74a6f7a811944b2d931b89db770fd4b70 100644 --- a/outputs/data_collator.joblib +++ b/outputs/data_collator.joblib @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:e02a6e9cfa63cb321cac9402efd29841b652999fcbf787800ae050e747b161ee -size 1471394 +oid sha256:0321b1a9629e1be122045cd72470365a63c8496fec109fdeec34827f01ffbb9e +size 1471424 diff --git a/outputs/events.out.tfevents.1627122688.tablespoon.2185269.3.v2 
b/outputs/events.out.tfevents.1627122688.tablespoon.2185269.3.v2 new file mode 100644 index 0000000000000000000000000000000000000000..8f0eaf6975217372772f3a39393dfd5a5c87159b --- /dev/null +++ b/outputs/events.out.tfevents.1627122688.tablespoon.2185269.3.v2 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:96e9ae0a574bba9629f3ffc65a3b0714cef7fb3ee796573b1d132a3c80f25fc3 +size 40 diff --git a/outputs/events.out.tfevents.1627122817.tablespoon.2191003.3.v2 b/outputs/events.out.tfevents.1627122817.tablespoon.2191003.3.v2 new file mode 100644 index 0000000000000000000000000000000000000000..9527ac882c10e62ef45b00cecd2f5249e45722d9 --- /dev/null +++ b/outputs/events.out.tfevents.1627122817.tablespoon.2191003.3.v2 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8351632f70e7465eea0642599fff77709273a7fc9a50b607543f49754c6d694e +size 149322 diff --git a/outputs/events.out.tfevents.1627125745.tablespoon.2266135.3.v2 b/outputs/events.out.tfevents.1627125745.tablespoon.2266135.3.v2 new file mode 100644 index 0000000000000000000000000000000000000000..85b09b440fdc3ea6be0f86226d1f0d46d5689184 --- /dev/null +++ b/outputs/events.out.tfevents.1627125745.tablespoon.2266135.3.v2 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ca4fd9200f1731b7b75feb068806b3d69d179f8e5d467282752c6a8efac760c0 +size 149322 diff --git a/outputs/events.out.tfevents.1627128247.tablespoon.2330108.3.v2 b/outputs/events.out.tfevents.1627128247.tablespoon.2330108.3.v2 new file mode 100644 index 0000000000000000000000000000000000000000..463c101142153edf24da818a7fe1daaaaff80a0f --- /dev/null +++ b/outputs/events.out.tfevents.1627128247.tablespoon.2330108.3.v2 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8afd683918d2323382c5b27888561ecb6d879ebe3f6360a1c6ac02435d2f2122 +size 746450 diff --git a/outputs/flax_model.msgpack b/outputs/flax_model.msgpack index d7b3172a705c066aaf5eb63e89a07ecda3789e09..06e8410245f53b24f15363c853c91bd1f34ccfbe 100644 --- a/outputs/flax_model.msgpack +++ b/outputs/flax_model.msgpack @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:393c37966461709fe51a3b3f84befb7fa7e5030025856d171308efd40dbbc7da +oid sha256:3362fadd5306539e1775600df59349acf0fd260751129abaa45b997091d089ce size 249750019 diff --git a/outputs/optimizer_state.msgpack b/outputs/optimizer_state.msgpack index 9ecf7634d332666cefc80f7f45d1d7c2245deb9a..7fe6fd30d9b9dde6a3d2b1e25a4c7b6b365081ba 100644 --- a/outputs/optimizer_state.msgpack +++ b/outputs/optimizer_state.msgpack @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:3a33cad417a7e78eaafc1c041f93fd54ad9f63869d01e1351bac4abcd58e4eeb +oid sha256:e08084e7879de2059bc4fd38ff876b0bd1e987a646397956dd4bef19bb2dd1bf size 499500278 diff --git a/outputs/training_args.joblib b/outputs/training_args.joblib index 9ec9d51e87d8b26b07ba1d4300d451a9dfc4a8d8..5855f72bcc4215a2b17083c6e9cce26965c419b6 100644 --- a/outputs/training_args.joblib +++ b/outputs/training_args.joblib @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:8bc14fe16573d318dd510c7cfb42ebb7cc87b4dcf77e99247a2d1605cffd772b -size 1876 +oid sha256:79196d5f797a7527e60287e48f716fd4626c5d57186661d6a214a2027998c86d +size 1873 diff --git a/outputs/training_state.json b/outputs/training_state.json index 75348a914ed8eb10df17cc2276ac75eb02bd2c8c..4323ac7250fe0f27af5929414fc764d2882db84f 100644 --- a/outputs/training_state.json +++ b/outputs/training_state.json @@ -1 +1 @@ -{"step": 180001} \ No newline at end 
of file +{"step": 186001} \ No newline at end of file diff --git a/pytorch_model.bin b/pytorch_model.bin index 27e6d189fb42c2992169d43ce01c37b26935f53e..39c09b1e14189f7f63f524d63e1b12b7763e0df0 100644 --- a/pytorch_model.bin +++ b/pytorch_model.bin @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:9ec165e7aa9031ae2353284d56f1ebef021052449fdf013612050c2cbaa189f8 +oid sha256:15ab86b65e094a9317712403ea4263852018a6a06027f12b31d56c20302015a2 size 498858859 diff --git a/run_mlm_flax_stream.py b/run_mlm_flax_stream.py old mode 100755 new mode 100644 index dfd6e23ecfc254e9f6118c8924b12861139509e6..92581d65ea04e7a4e5434c0c31859c880be5e80c --- a/run_mlm_flax_stream.py +++ b/run_mlm_flax_stream.py @@ -25,6 +25,7 @@ import json import os import shutil import sys +import tempfile import time from collections import defaultdict from dataclasses import dataclass, field @@ -60,6 +61,8 @@ from transformers import ( TrainingArguments, is_tensorboard_available, set_seed, + FlaxRobertaForMaskedLM, + RobertaForMaskedLM, ) @@ -348,6 +351,24 @@ def save_checkpoint_files(state, data_collator, training_args, save_dir): json.dump({"step": unreplicated_state.step.item()}, f) +def restore_checkpoint(save_dir, state): + logger.info(f"Restoring checkpoint from {save_dir}") + with open(os.path.join(save_dir, "flax_model.msgpack"), "rb") as f: + params = from_bytes(state.params, f.read()) + + with open(os.path.join(save_dir, "optimizer_state.msgpack"), "rb") as f: + opt_state = from_bytes(state.opt_state, f.read()) + + args = joblib.load(os.path.join(save_dir, "training_args.joblib")) + data_collator = joblib.load(os.path.join(save_dir, "data_collator.joblib")) + + with open(os.path.join(save_dir, "training_state.json"), "r") as f: + training_state = json.load(f) + step = training_state["step"] + + return params, opt_state, step, args, data_collator + + def rotate_checkpoints(path, max_checkpoints=5): paths = sorted(Path(path).iterdir(), key=os.path.getmtime)[::-1] if len(paths) > max_checkpoints: @@ -358,6 +379,27 @@ def rotate_checkpoints(path, max_checkpoints=5): os.remove(path_to_delete) +def to_f32(t): + return jax.tree_map(lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, t) + + +def convert(output_dir, destination_dir="./"): + shutil.copyfile(Path(output_dir) / "flax_model.msgpack", Path(destination_dir) / "flax_model.msgpack") + shutil.copyfile(Path(output_dir) / "config.json", Path(destination_dir) / "config.json") + # Saving extra files from config.json and tokenizer.json files + tokenizer = AutoTokenizer.from_pretrained(destination_dir) + tokenizer.save_pretrained(destination_dir) + + # Temporary saving bfloat16 Flax model into float32 + tmp = tempfile.mkdtemp() + flax_model = FlaxRobertaForMaskedLM.from_pretrained(destination_dir) + flax_model.params = to_f32(flax_model.params) + flax_model.save_pretrained(tmp) + # Converting float32 Flax to PyTorch + model = RobertaForMaskedLM.from_pretrained(tmp, from_flax=True) + model.save_pretrained(destination_dir, save_config=False) + + if __name__ == "__main__": # See all possible arguments in src/transformers/training_args.py # or by passing the --help flag to this script. 
@@ -484,8 +526,6 @@ if __name__ == "__main__": has_tensorboard = is_tensorboard_available() if has_tensorboard and jax.process_index() == 0: try: - from flax.metrics.tensorboard import SummaryWriter - summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir)) # Enable Weight&Biases import wandb wandb.init( @@ -496,6 +536,8 @@ if __name__ == "__main__": wandb.config.update(training_args) wandb.config.update(model_args) wandb.config.update(data_args) + from flax.metrics.tensorboard import SummaryWriter + summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir)) except ImportError as ie: has_tensorboard = False logger.warning( @@ -569,6 +611,42 @@ if __name__ == "__main__": # Setup train state state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw) + saved_step = -1 + if model_args.model_name_or_path and "checkpoint" in model_args.model_name_or_path: + params, opt_state, saved_step, args, data_collator = restore_checkpoint(model_args.model_name_or_path, state) + # Create learning rate schedule + warmup_fn = optax.linear_schedule( + init_value=0.0, end_value=args.learning_rate, transition_steps=args.warmup_steps + ) + decay_fn = optax.linear_schedule( + init_value=args.learning_rate, + end_value=0, + transition_steps=data_args.num_train_steps - args.warmup_steps, + ) + linear_decay_lr_schedule_fn = optax.join_schedules( + schedules=[warmup_fn, decay_fn], boundaries=[args.warmup_steps] + ) + # create adam optimizer + adamw = optax.adamw( + learning_rate=linear_decay_lr_schedule_fn, + b1=training_args.adam_beta1, + b2=training_args.adam_beta2, + eps=training_args.adam_epsilon, + weight_decay=args.weight_decay, + mask=decay_mask_fn, + ) + state = train_state.TrainState( + step=saved_step, + apply_fn=model.__call__, + params=params, + tx=adamw, + opt_state=opt_state, + ) + # self.args = args + # data_collator = data_collator + # scheduler_fn = args.learning_rate + model.params = params + # Define gradient update step fn def train_step(state, batch, dropout_rng): @@ -636,8 +714,12 @@ if __name__ == "__main__": max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length) eval_samples = advance_iter_and_group_samples(training_iter, data_args.num_eval_samples, max_seq_length) + last_desc = "" steps = tqdm(range(num_train_steps), desc="Training...", position=0) for step in range(num_train_steps): + if step < saved_step: + steps.update(1) + continue # ======================== Training ================================ try: samples = advance_iter_and_group_samples(training_iter, train_batch_size, max_seq_length) @@ -692,7 +774,8 @@ if __name__ == "__main__": eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics) # Update progress bar - steps.desc = f"Step... ({step + 1}/{num_train_steps} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})" + steps.desc = f"Step... 
({step}/{num_train_steps} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})" + last_desc = steps.desc if has_tensorboard and jax.process_index() == 0: write_eval_metric(summary_writer, eval_metrics, step) @@ -700,23 +783,50 @@ if __name__ == "__main__": # save checkpoint after eval_steps if step % training_args.save_steps == 0 and step > 0 and jax.process_index() == 0: - logger.info(f"Saving checkpoint at {step + 1} steps") + logger.info(f"Saving checkpoint at {step} steps") params = jax.device_get(jax.tree_map(lambda x: x[0], state.params)) model.save_pretrained( training_args.output_dir, params=params, - push_to_hub=training_args.push_to_hub, - commit_message=f"Saving weights and logs of step {step + 1}", + push_to_hub=False, ) save_checkpoint_files(state, data_collator, training_args, training_args.output_dir) - checkpoints_dir = Path(training_args.output_dir) / "checkpoints" / f"checkpoint-{step + 1}" + checkpoints_dir = Path(training_args.output_dir) / "checkpoints" / f"checkpoint-{step}" checkpoints_dir.mkdir(parents=True, exist_ok=True) - model.save_pretrained(checkpoints_dir, params=params,) + model.save_pretrained(checkpoints_dir, params=params) save_checkpoint_files(state, data_collator, training_args, checkpoints_dir) rotate_checkpoints( Path(training_args.output_dir) / "checkpoints", max_checkpoints=training_args.save_total_limit ) + convert(training_args.output_dir, "./") + model.save_pretrained( + training_args.output_dir, + params=params, + push_to_hub=training_args.push_to_hub, + commit_message=last_desc, + ) # update tqdm bar steps.update(1) + + if jax.process_index() == 0: + logger.info(f"Saving checkpoint at {step} steps") + params = jax.device_get(jax.tree_map(lambda x: x[0], state.params)) + model.save_pretrained( + training_args.output_dir, + params=params, + push_to_hub=False, + ) + save_checkpoint_files(state, data_collator, training_args, training_args.output_dir) + checkpoints_dir = Path(training_args.output_dir) / "checkpoints" / f"checkpoint-{step}" + checkpoints_dir.mkdir(parents=True, exist_ok=True) + model.save_pretrained(checkpoints_dir, params=params) + save_checkpoint_files(state, data_collator, training_args, checkpoints_dir) + convert(training_args.output_dir, "./") + model.save_pretrained( + training_args.output_dir, + params=params, + push_to_hub=training_args.push_to_hub, + commit_message=last_desc or "Saving model after training", + ) diff --git a/run_stream.128.sh b/run_stream.128.sh new file mode 100755 index 0000000000000000000000000000000000000000..87e024ad9fd61b5e23e5ebd1a23cc53c1d2cc171 --- /dev/null +++ b/run_stream.128.sh @@ -0,0 +1,28 @@ +# From https://arxiv.org/pdf/1907.11692.pdf for base model +python -c "import jax; print('TPUs', jax.device_count())" +python ./run_mlm_flax_stream.py \ + --model_name_or_path="./outputs/checkpoints/checkpoint-181000/" \ + --output_dir="./outputs" \ + --model_type="roberta" \ + --config_name="./configs/base" \ + --tokenizer_name="./configs/base" \ + --dataset_name="./mc4" \ + --dataset_config_name="es" \ + --train_file="../mc4-es-train-50M-steps.jsonl" \ + --max_seq_length="128" \ + --pad_to_max_length \ + --per_device_train_batch_size="256" \ + --per_device_eval_batch_size="256" \ + --adam_beta1="0.9" \ + --adam_beta2="0.98" \ + --adam_epsilon="1e-6" \ + --learning_rate="6e-4" \ + --weight_decay="0.01" \ + --save_steps="1000" \ + --save_total_limit="5" \ + --warmup_steps="24000" \ + --overwrite_output_dir \ + --num_train_steps="250000" \ + --eval_steps="1000" \ + --dtype="bfloat16" \ + 
--logging_steps="500" 2>&1 | tee run_stream.128.log
diff --git a/run_stream.512.sh b/run_stream.512.sh
new file mode 100755
index 0000000000000000000000000000000000000000..87aacba70b52132576d269295912231499e1003a
--- /dev/null
+++ b/run_stream.512.sh
@@ -0,0 +1,27 @@
+# From https://arxiv.org/pdf/1907.11692.pdf for base model
+python -c "import jax; print('TPUs', jax.device_count())"
+python ./run_mlm_flax_stream.py \
+    --model_name_or_path="./outputs/checkpoints/checkpoint-24000" \
+    --output_dir="./outputs" \
+    --model_type="roberta" \
+    --config_name="./configs/base" \
+    --tokenizer_name="./configs/base" \
+    --dataset_name="bertin-project/mc4-es-sampled" \
+    --dataset_config_name="gaussian" \
+    --max_seq_length="512" \
+    --pad_to_max_length \
+    --per_device_train_batch_size="48" \
+    --per_device_eval_batch_size="48" \
+    --adam_beta1="0.9" \
+    --adam_beta2="0.98" \
+    --adam_epsilon="1e-6" \
+    --learning_rate="6e-4" \
+    --weight_decay="0.01" \
+    --save_steps="1000" \
+    --save_total_limit="5" \
+    --warmup_steps="500" \
+    --overwrite_output_dir \
+    --num_train_steps="50000" \
+    --eval_steps="1000" \
+    --dtype="bfloat16" \
+    --logging_steps="500" 2>&1 | tee run_stream_checkpoint.log
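
Implementation notes on the changes above, with minimal Python sketches.

The sampling logic in mc4/mc4.py filters Common Crawl documents by their perplexity under a KenLM language model. Only the tail of get_perplexity is visible in the diff (the 10.0 ** (-doc_log_score / doc_length) return), so the per-line scoring loop below is a reconstruction under that assumption, and "es.5gram.bin" is a hypothetical model path.

# Sketch of the KenLM-based document perplexity used by the sampler.
# KenLM is installed with: pip install https://github.com/kpu/kenlm/archive/master.zip
import kenlm

model = kenlm.Model("es.5gram.bin")  # hypothetical path to a Spanish n-gram model

def get_perplexity(doc):
    doc_log_score, doc_length = 0.0, 0
    for line in doc.split("\n"):
        log_score = model.score(line)   # log10 probability of the line
        length = len(line.split()) + 1  # +1 for the end-of-sentence token
        doc_log_score += log_score
        doc_length += length
    return 10.0 ** (-doc_log_score / doc_length)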
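
The Gaussian sampling rule keeps a document with probability factor * exp(-((perplexity - m) / m)^2 / width), where m is the median perplexity boundary and the spread parameter width (default 9/2) now arrives through the constructor kwargs that the builder stores (self.kwargs = kwargs) and forwards to should_keep_doc. A self-contained sketch with get_perplexity stubbed out; factor, the default width and the median value are taken from the diff.

# Self-contained sketch of the Gaussian document filter from mc4/mc4.py.
import numpy as np
from numpy.random import default_rng

rng = default_rng(42)

def get_perplexity(doc):
    # Stand-in for the KenLM document perplexity of the previous sketch.
    return 600_000.0 + 50_000.0 * len(doc)

def should_keep_doc_gaussian(doc, factor=0.78, boundaries=None, **kwargs):
    width = kwargs.get("width", 9 / 2)  # spread of the exponential curve
    perplexity = get_perplexity(doc)
    m = boundaries[1] if boundaries is not None else 662247.50212365
    exponential = np.exp((-1 / width) * ((perplexity - m) / m) ** 2)
    weighted_perplexity = factor * exponential
    return rng.uniform() < weighted_perplexity

print(should_keep_doc_gaussian("hola mundo"))

Because the leftover constructor kwargs are now kept and passed through to the sampling function, a caller can tune width per run without editing the sampling code.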
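
run_mlm_flax_stream.py can now resume from a local checkpoint directory: restore_checkpoint deserializes the msgpack/joblib/JSON artifacts, and the training script rebuilds the linear warmup plus linear decay schedule and the AdamW optimizer before constructing a TrainState at the saved step. A condensed sketch of both halves; the weight-decay mask (decay_mask_fn) used in the script is omitted, and hyperparameters are passed explicitly instead of being read from training_args/data_args.

import json
import os

import joblib
import optax
from flax.serialization import from_bytes
from flax.training import train_state


def restore_checkpoint(save_dir, state):
    # `state` is an existing TrainState used as a template so from_bytes knows
    # the pytree structure of the parameters and optimizer state.
    with open(os.path.join(save_dir, "flax_model.msgpack"), "rb") as f:
        params = from_bytes(state.params, f.read())
    with open(os.path.join(save_dir, "optimizer_state.msgpack"), "rb") as f:
        opt_state = from_bytes(state.opt_state, f.read())
    args = joblib.load(os.path.join(save_dir, "training_args.joblib"))
    data_collator = joblib.load(os.path.join(save_dir, "data_collator.joblib"))
    with open(os.path.join(save_dir, "training_state.json")) as f:
        step = json.load(f)["step"]
    return params, opt_state, step, args, data_collator


def resumed_train_state(model, params, opt_state, saved_step, learning_rate,
                        warmup_steps, num_train_steps, weight_decay,
                        b1=0.9, b2=0.98, eps=1e-6):
    # Rebuild the same linear warmup + linear decay schedule the run started
    # with, then construct the TrainState directly at the saved step.
    warmup_fn = optax.linear_schedule(
        init_value=0.0, end_value=learning_rate, transition_steps=warmup_steps)
    decay_fn = optax.linear_schedule(
        init_value=learning_rate, end_value=0.0,
        transition_steps=num_train_steps - warmup_steps)
    schedule_fn = optax.join_schedules(
        schedules=[warmup_fn, decay_fn], boundaries=[warmup_steps])
    adamw = optax.adamw(learning_rate=schedule_fn, b1=b1, b2=b2, eps=eps,
                        weight_decay=weight_decay)
    return train_state.TrainState(
        step=saved_step, apply_fn=model.__call__, params=params,
        tx=adamw, opt_state=opt_state)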
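
On resume the training loop does not seek the streaming dataset; it iterates the step counter and skips work until the restored step is reached, so batches for the skipped steps are not drawn again. A stripped-down view of that pattern, with next_samples and train_step standing in for the script's own helpers.

from tqdm import tqdm

def training_loop(num_train_steps, saved_step, next_samples, train_step, state):
    steps = tqdm(range(num_train_steps), desc="Training...", position=0)
    for step in range(num_train_steps):
        if step < saved_step:   # already covered by the restored checkpoint
            steps.update(1)
            continue
        batch = next_samples()  # only consumed once real training resumes
        state = train_step(state, batch)
        steps.update(1)
    return state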
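
After each save the new convert helper materializes a float32 PyTorch checkpoint from the bfloat16 Flax weights. The sketch below is a simplified variant: it drops the shutil copies of flax_model.msgpack and config.json that the script performs first, uses a context-managed temporary directory, and assumes tokenizer files are already present in the destination directory.

import tempfile

import jax
import jax.numpy as jnp
from transformers import AutoTokenizer, FlaxRobertaForMaskedLM, RobertaForMaskedLM


def to_f32(t):
    # Cast every bfloat16 leaf of the parameter pytree to float32.
    return jax.tree_map(
        lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, t)


def export_pytorch(flax_dir, destination_dir="./"):
    # The script re-saves the tokenizer from the destination directory.
    tokenizer = AutoTokenizer.from_pretrained(destination_dir)
    tokenizer.save_pretrained(destination_dir)

    flax_model = FlaxRobertaForMaskedLM.from_pretrained(flax_dir)
    flax_model.params = to_f32(flax_model.params)

    with tempfile.TemporaryDirectory() as tmp:
        flax_model.save_pretrained(tmp)  # float32 Flax weights on disk
        pt_model = RobertaForMaskedLM.from_pretrained(tmp, from_flax=True)
    pt_model.save_pretrained(destination_dir, save_config=False)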