updated the model and script to load local data
Files changed:
- pytorch_model.bin +1 -1
- run_clm_flax.py +6 -1
- run_pretraining.sh +3 -2
pytorch_model.bin CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:f67e392707d4b269ea616f717bb7451e79d8cc0235449e990209b12bb74aad45
 size 1444576537
run_clm_flax.py CHANGED
@@ -112,6 +112,9 @@ class DataTrainingArguments:
     dataset_config_name: Optional[str] = field(
         default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
     )
+    dataset_data_dir: Optional[str] = field(
+        default=None, metadata={"help": "The name of the data directory."}
+    )
     train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
     validation_file: Optional[str] = field(
         default=None,
@@ -296,19 +299,21 @@ def main():
     if data_args.dataset_name is not None:
         # Downloading and loading a dataset from the hub.
         dataset = load_dataset(
-            data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir, keep_in_memory=False
+            data_args.dataset_name, data_args.dataset_config_name, data_dir=data_args.dataset_data_dir, cache_dir=model_args.cache_dir, keep_in_memory=False
         )
 
         if "validation" not in dataset.keys():
             dataset["validation"] = load_dataset(
                 data_args.dataset_name,
                 data_args.dataset_config_name,
+                data_dir=data_args.dataset_data_dir,
                 split=f"train[:{data_args.validation_split_percentage}%]",
                 cache_dir=model_args.cache_dir,
             )
             dataset["train"] = load_dataset(
                 data_args.dataset_name,
                 data_args.dataset_config_name,
+                data_dir=data_args.dataset_data_dir,
                 split=f"train[{data_args.validation_split_percentage}%:]",
                 cache_dir=model_args.cache_dir,
             )
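In practice the change just threads one extra data_dir argument through to datasets.load_dataset. A minimal standalone sketch of the resulting call, assuming the same local dataset script and data directory that run_pretraining.sh points at (./datasets/id_collection and /data/collection) and an illustrative 5% validation split:

from datasets import load_dataset

# Load the local dataset; data_dir tells the dataset script where the raw files live.
dataset = load_dataset(
    "./datasets/id_collection",   # local dataset script / folder (from run_pretraining.sh)
    "id_collection",              # dataset config name
    data_dir="/data/collection",  # value passed via the new --dataset_data_dir flag
    keep_in_memory=False,
)

# If the dataset has no "validation" split, run_clm_flax.py carves one out of "train";
# the 5% here stands in for data_args.validation_split_percentage.
if "validation" not in dataset.keys():
    dataset["validation"] = load_dataset(
        "./datasets/id_collection",
        "id_collection",
        data_dir="/data/collection",
        split="train[:5%]",
    )
    dataset["train"] = load_dataset(
        "./datasets/id_collection",
        "id_collection",
        data_dir="/data/collection",
        split="train[5%:]",
    )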
run_pretraining.sh CHANGED
@@ -9,8 +9,9 @@ export WANDB_LOG_MODEL="true"
 --model_type="gpt2" \
 --config_name="${MODEL_DIR}" \
 --tokenizer_name="${MODEL_DIR}" \
---dataset_name="
---dataset_config_name="
+--dataset_name="./datasets/id_collection" \
+--dataset_config_name="id_collection" \
+--dataset_data_dir="/data/collection" \
 --do_train --do_eval \
 --block_size="512" \
 --per_device_train_batch_size="24" \
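For context, the new --dataset_data_dir flag reaches the training script through transformers' HfArgumentParser, which turns dataclass fields into CLI options. A rough sketch of just that piece, assuming a trimmed-down DataTrainingArguments (the real run_clm_flax.py also parses ModelArguments and TrainingArguments):

from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser

@dataclass
class DataTrainingArguments:
    dataset_name: Optional[str] = field(default=None)
    dataset_config_name: Optional[str] = field(default=None)
    # Field added in this commit; becomes the --dataset_data_dir CLI option.
    dataset_data_dir: Optional[str] = field(
        default=None, metadata={"help": "The name of the data directory."}
    )

parser = HfArgumentParser(DataTrainingArguments)
(data_args,) = parser.parse_args_into_dataclasses([
    "--dataset_name", "./datasets/id_collection",
    "--dataset_config_name", "id_collection",
    "--dataset_data_dir", "/data/collection",
])
print(data_args.dataset_data_dir)  # -> /data/collection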