# Copyright (c) Guangsheng Bao.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import time
import numpy as np
import datasets
import torch
import random
import argparse
import os
import json

import custom_datasets
from model import load_tokenizer, load_model

def stats_str(data):
    # report average word counts; data is either a dict with parallel
    # 'original'/'sampled' lists, or a list of items that additionally carry
    # 'perturbed_original'/'perturbed_sampled' texts
    if isinstance(data, dict):
        mean_orig = np.mean([len(v.split()) for v in data['original']])
        mean_samp = np.mean([len(v.split()) for v in data['sampled']])
        return f'{mean_orig:.0f} words (original), {mean_samp:.0f} words (sampled).'
    else:
        mean_orig = np.mean([len(v['original'].split()) for v in data])
        mean_samp = np.mean([len(v['sampled'].split()) for v in data])
        mean_perturb_orig = np.mean([np.mean([len(p.split()) for p in v['perturbed_original']]) for v in data])
        mean_perturb_samp = np.mean([np.mean([len(p.split()) for p in v['perturbed_sampled']]) for v in data])
        return f'{mean_orig:.0f} words (original), {mean_samp:.0f} words (sampled), {mean_perturb_orig:.0f} words (perturb original), {mean_perturb_samp:.0f} words (perturb sampled).'

def save_data(output_file, args, data):
    # write args to file
    args_file = f"{output_file}.args.json"
    with open(args_file, "w") as fout:
        json.dump(args, fout, indent=4)
    print(f"Args written into {args_file}")

    # write the data to a json file in the save folder
    data_file = f"{output_file}.raw_data.json"
    with open(data_file, "w") as fout:
        json.dump(data, fout, indent=4)
    print(f"Raw data written into {data_file}: {stats_str(data)}")

def load_data(input_file):
    # load args from file
    args_file = f"{input_file}.args.json"
    with open(args_file, "r") as fin:
        args = json.load(fin)
    print(f"Args loaded from {args_file}")

    # load the data from file
    data_file = f"{input_file}.raw_data.json"
    with open(data_file, "r") as fin:
        data = json.load(fin)
    print(f"Raw data loaded from {data_file}: {stats_str(data)}")
    return args, data

def convert_data(input_file, output_file, max_words):
    # truncate a text to at most max_words words, keeping its line breaks
    # (see the illustration after this function)
    def _reduce(text):
        lines = []
        nwords = 0
        for line in text.split('\n'):
            if nwords >= max_words:
                break
            words = line.split()
            words = words[:max_words - nwords]
            lines.append(' '.join(words))
            nwords += len(words)
        return '\n'.join(lines)

    args, data = load_data(input_file)
    if isinstance(data, dict):
        data['original'] = [_reduce(x) for x in data['original']]
        data['sampled'] = [_reduce(x) for x in data['sampled']]
    else:
        for item in data:
            item['original'] = _reduce(item['original'])
            item['sampled'] = _reduce(item['sampled'])
            item['perturbed_original'] = [_reduce(x) for x in item['perturbed_original']]
            item['perturbed_sampled'] = [_reduce(x) for x in item['perturbed_sampled']]
    save_data(output_file, args, data)
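
# Illustration of _reduce (hypothetical values): with max_words=5, the text
# "a b c\nd e f g" is reduced to "a b c\nd e" -- words are counted across
# lines and the newline structure of the kept prefix is preserved.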

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--input_path', type=str, default="./exp_gpt3to4/data/")
    parser.add_argument('--output_path', type=str, default="./exp_maxlen150/data/")
    parser.add_argument('--max_words', type=int, default=150)
    args = parser.parse_args()

    import glob
    import os.path as path

    # convert every dataset found in the input path
    for file_name in glob.glob(f'{args.input_path}/*.raw_data.json'):
        print(file_name)
        file_name = path.basename(file_name).replace('.raw_data.json', '')
        convert_data(path.join(args.input_path, file_name), path.join(args.output_path, file_name), args.max_words)
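
# Example usage (illustrative; the script filename "convert_maxlen.py" is an
# assumption -- the paths and limit below are just the argparse defaults):
#   python convert_maxlen.py --input_path ./exp_gpt3to4/data/ \
#       --output_path ./exp_maxlen150/data/ --max_words 150
# This trims every original/sampled (and perturbed) text found in
# ./exp_gpt3to4/data/*.raw_data.json to at most 150 words and writes the
# converted data, together with the copied args, into ./exp_maxlen150/data/.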