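"""Compare symbolic-regression models on randomly generated equations.

For each random target equation this script builds training and test datasets,
trains every model listed in models_to_test, records test errors and per-model
running times, and saves comparison plots and a text log under images/.
"""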
import os
import sys
import time

import matplotlib.pyplot as plt
import numpy as np

import DataUtils
import Settings as settings
from DataUtils import generate_random_eqn, multidim_dataset_from_eqn_list
from DataUtils import simple_eqn_to_str
from MLP_Model import MLP_Model
# from models.smtree_model import SMTree_Model
# from models.smftree_model import SMFTree_Model
# from models.smfftree_model import SMFFTree_Model
# from models.addoptree_model import AddOpTree_Model
# from models.crtree_model import CRTree_Model
# from models.sptree_model import SpTree_Model
# from models.lbl_model import LBLTree_Model
from SymbolicFunctionLearner import SFL
# from models.mcts_model import MCTS_Model
from gp_model import Genetic_Model
""" Test Hyperparameters """ | |
num_eqns_to_create = 150 | |
num_training = 1000 | |
num_valid = 1000 | |
num_test = 1000 | |
allowable_ops = settings.function_set.copy() | |
num_vars = 1 # 10 | |
num_max_levels = 3 # 4 | |
settings.n_tree_layers = num_max_levels | |
settings.num_features = num_vars | |
settings.show_output = False | |
settings.keep_logs = False | |
settings.mode = "sr" | |
show_found_eqns = True | |
use_constants_in_eqns = True | |
if not os.path.exists('images'): | |
os.makedirs('images') | |
def plot_all_models(eqn_str, eqn_number, all_models, this_train_x, this_train_y,
                    this_test_x, this_test_y):
    """Plot training points, ground truth, and each model's predictions for a 1-D equation."""
    plt.figure()
    plt.title('Compare models: {}'.format(eqn_str))
    min_y = np.min(this_test_y)
    max_y = np.max(this_test_y)
    y_range = max_y - min_y
    plt.scatter(this_train_x, this_train_y, color='xkcd:dark pink',
                marker='o', s=100, label='Training set')
    plt.scatter(this_test_x, this_test_y, color='gray', alpha=0.5, marker='.', label='Ground truth')
    for this_model in all_models:
        test_hat_y = this_model.predict(this_test_x)
        plt.scatter(this_test_x,
                    test_hat_y, alpha=0.7, marker='.', label='{}'.format(this_model.name))
    # Mark the boundaries of the training interval.
    for xc in settings.train_scope:
        plt.axvline(x=xc, color='k', linestyle='dashed', linewidth=2)
    plt.ylim([min_y - 0.5 * y_range, max_y + 0.5 * y_range])
    plt.legend()
    plt.savefig("images/eqn_{}.png".format(eqn_number))
    plt.close()
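
# Each class added to models_to_test is instantiated below and is expected to provide
# reset(), repeat_train(train_x, train_y, test_x=..., test_y=..., verbose=...) and the
# attributes name / short_name; predict() is additionally used by plot_all_models.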
models_to_test = []
models_to_test.append(Genetic_Model)
models_to_test.append(MLP_Model)
# models_to_test.append(Tree_Model)
models_to_test.append(SFL)
# models_to_test.append(SMTree_Model)
# models_to_test.append(SMFTree_Model)
# models_to_test.append(SMFFTree_Model)
# models_to_test.append(AddOpTree_Model)
# models_to_test.append(SpTree_Model)

# lists_of_error_scores[i] and lists_of_iter_times[i] hold the results for all_models[i].
lists_of_error_scores = []
lists_of_iter_times = []
all_models = []
for model_type in models_to_test:
    all_models.append(model_type())
    lists_of_error_scores.append([])
    lists_of_iter_times.append([])
# all_models.append(SpTree_Model(use_scopers=True))
# lists_of_error_scores.append([])
# lists_of_iter_times.append([])
seen_eqns = []
winning_entries = []
valid_err = 0

print("Starting program.")
print("Number of input features: {}".format(settings.num_features))
start_time = time.time()
for eqn_n in range(1, num_eqns_to_create + 1):
    print("----------------")

    # Create some random equation
    current_eqn_as_list = generate_random_eqn(allowable_ops, num_vars, num_max_levels,
                                              allow_constants=use_constants_in_eqns)
    current_eqn_as_str = simple_eqn_to_str(current_eqn_as_list)
    # should be a new equation
    while current_eqn_as_str in seen_eqns:
        current_eqn_as_list = generate_random_eqn(allowable_ops, num_vars, num_max_levels,
                                                  allow_constants=use_constants_in_eqns)
        current_eqn_as_str = simple_eqn_to_str(current_eqn_as_list)
    seen_eqns.append(current_eqn_as_str)

    print(current_eqn_as_list)
    print("Random equation {} of {}".format(eqn_n, num_eqns_to_create))
    print("True function: {}\n".format(current_eqn_as_str))
    with open("images/compare_test_output.txt", "a") as output_file:
        output_file.write("\n{}\nTrue equation:\n{}\n".format(eqn_n, current_eqn_as_str))

    # Create a dataset with that equation
    train_data_x, train_data_y = multidim_dataset_from_eqn_list(current_eqn_as_list, num_training, n_vars=num_vars)
    test_data_x, test_data_y = multidim_dataset_from_eqn_list(current_eqn_as_list, num_test, n_vars=num_vars,
                                                              min_x=settings.test_scope[0],
                                                              max_x=settings.test_scope[1])
    # train_data_x, train_data_y = create_dataset_from_eqn_list(current_eqn_as_list, num_vars, num_training,
    #                                                           settings.train_scope[0], settings.train_scope[1])
    # test_data_x, test_data_y = create_dataset_from_eqn_list(current_eqn_as_list, num_vars, num_test,
    #                                                         settings.test_scope[0], settings.test_scope[1])
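
    # The test set is drawn from settings.test_scope rather than the training range, so
    # (judging by the train_scope markers drawn in plot_all_models) the test error
    # presumably also measures extrapolation beyond the training interval.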
    for model_i, model in enumerate(all_models):
        model.reset()
        itertime_start = time.time()

        # if False:  # model_i >= len(all_models) - 2:
        #     model.repeat_train(train_data_x, train_data_y)
        # else:
        model_eqn, _, best_err = model.repeat_train(train_data_x, train_data_y,
                                                    test_x=test_data_x, test_y=test_data_y,
                                                    verbose=False)
        if show_found_eqns:
            print("{} function: {}".format(model.name, model_eqn)[:550])

        # Test model on that equation
        # test_err = model.test(test_data_x, test_data_y)
        # Floor the reported error at exp(-10) ~ 4.5e-5.
        test_err = max(np.exp(-10), best_err)  # data_utils.test_from_formula(model_eqn, test_data_x, test_data_y)
        print(" ---> {} Test Error: {:.5f}".format(model.short_name, test_err))
        lists_of_error_scores[model_i].append(test_err)
        lists_of_iter_times[model_i].append(time.time() - itertime_start)
        sys.stdout.flush()
        # y_gold_list = list(model.sess.run(model.y_gold,
        #                                   feed_dict={model.data_x: np.reshape(test_data_x[:4][:], [-1, num_vars]),
        #                                              model.data_y: np.reshape(test_data_y[:4][:], [-1, 1]),
        #                                              model.var_random_u: model.var_random_selector,
        #                                              model.op_random_u: model.op_random_selector}).reshape(1, -1)[0])
        # y_hat_list2 = data_utils.predict_from_formula(model.get_simple_formula(digits=4), test_data_x[:4][:])
        #
        # print('Performance on sample validation data:')
        # for feature_i in range(model.n_input_variables):
        #     print('x{}: '.format(feature_i + 1),
        #           ['{:7.4f}'.format(yyy[feature_i]) for yyy in test_data_x[:4][:]])
        # print("-----------------------------------------------------")
        # print('y_gold: ', ['{:7.4f}'.format(yyy) for yyy in y_gold_list])
        # print('y_hat2: ', ['{:7.4f}'.format(yyy) for yyy in y_hat_list2])
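
        # Record a "winning entry" whenever a model fits this equation with test
        # error below the hard-coded 0.035 threshold.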
        if test_err < 0.035:
            winning_entries.append("{} - {}".format(current_eqn_as_str, model.short_name))

        with open("images/compare_test_output.txt", "a") as output_file:
            output_file.write("{}: {}\n{}\n".format(model.short_name, test_err, model_eqn))

    print()
    DataUtils.plot_hist_of_errors(lists_of_error_scores, all_models, eqn_n)

    # todo: something is wrong with this function!!
    # todo: it plots the model as is (after last iter), not at best iter, of repeat_train
    # DataUtils.plot_all_models_predicted_actual(all_models, test_data_x, test_data_y,
    #                                            set_name="Eqn {}: {}".format(eqn_n, current_eqn_as_str),
    #                                            fig_name="eqn_{}".format(eqn_n))
    plt.figure()
    x_axis = [i + 1 for i in range(eqn_n)]
    for iter_times, this_model in zip(lists_of_iter_times, all_models):
        plt.plot(x_axis, iter_times, label=this_model.short_name)
    plt.xlabel("Iteration")
    plt.ylabel("Running time (seconds)")
    plt.legend()
    plt.savefig("images/time_curve.png")
    plt.close()

    # if num_vars == 1:
    #     plot_all_models(current_eqn_as_str, eqn_n, all_models,
    #                     train_data_x, train_data_y,
    #                     test_data_x, test_data_y)
running_time = time.time() - start_time
print("Done. Took {:.1f} seconds.\n".format(running_time))

# plt.figure()
# plt.hist([list_of_error_scores_GP, list_of_error_scores_MLP, list_of_error_scores_Tree], label=['GP', 'MLP', 'Tree'])
# plt.legend(loc='upper right')
# plt.show()
#
# plt.figure()
# plt.hist(list_of_error_scores_GP, edgecolor='k', linewidth=1.2)
# plt.show()
#
# plt.figure()
# plt.hist(list_of_error_scores_GP, cumulative=True)
# plt.show()