""""""""""""""""""""""""""""""""" | |
Do not run or modify this file. | |
For running: DiffEqnSolver.py | |
For modifying: settings.py | |
""""""""""""""""""""""""""""""""" | |
import matplotlib.pyplot as plt
import numpy as np
import sympy
import tensorflow as tf
from matplotlib import cm
from sympy import expand, sympify, lambdify
import Settings
def safe_abs(x):
    # Smooth, differentiable approximation of |x|.
    return np.sqrt(x * x + Settings.eps)
def safe_div(x, y):
    # Division with the denominator bounded away from zero.
    return np.sign(y) * x / safe_abs(y)
def tf_diff_abs(x):
    # TensorFlow version of safe_abs.
    return tf.sqrt(tf.square(x) + Settings.eps)
def tf_diff_sqrt(x):
    return tf.sqrt(tf_diff_abs(x))
def tf_diff_log(x):
    return tf.math.log(tf_diff_abs(x))
def our_tanh(x, factor=1000):
    # Soft clipping: roughly linear for |x| much smaller than factor, saturates near +/-factor.
    return factor * tf.tanh(x / factor)
def spike(x):
    # Narrow bump centered at x = 0.
    return 1.0 / (1 + 200 * tf.square(x))
    # return tf.math.exp(-10 * tf.square(x))
def true_function(input_x):
    return predict_from_formula(Settings.true_eqn, input_x)
def is_float(value):
    try:
        float(value)
        return True
    except ValueError:
        return False
# Function to generate a random equation as an operator/input list
# Variables are numbered 1 ... n, and 0 does not appear
# Constants appear as [float], e.g. [3.14]
def generate_random_eqn_two_list(op_list, n_vars, n_levels, allow_constants=True):
    eqn_ops = list(np.random.choice(op_list, size=int(2 ** n_levels) - 1, replace=True))
    if allow_constants:
        eqn_vars = list(np.random.choice(range(1, max(int(n_vars * 1.6), n_vars + 2)),
                                         size=int(2 ** n_levels), replace=True))
        for i in range(len(eqn_vars)):
            if eqn_vars[i] >= n_vars + 1:
                eqn_vars[i] = [np.random.uniform(Settings.test_scope[0], Settings.test_scope[1])]
    else:
        eqn_vars = list(np.random.choice(range(1, 1 + n_vars), size=int(2 ** n_levels), replace=True))
    return [eqn_ops, eqn_vars]
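# Example (illustrative) of the two-list format: with op_list = ['add', 'mul', 'sin'],
# n_vars = 2 and n_levels = 2, a possible return value has the form
#     [['add', 'mul', 'sin'], [1, 2, [3.14], 1]]
# i.e. 2**n_levels - 1 operators and 2**n_levels leaves, where an integer k denotes
# variable x_k and a one-element list holds a sampled constant. The operator list is
# interpreted in pre-order (root first) by the evaluators below.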
# Function to generate random equation as operator/input list and weight/bias list
# Variables are numbered 1 ... n, and 0 does not appear
# Constants appear in weight and bias lists.
# const_ratio determines how many weights are not 1, and how many biases are not 0
def generate_random_eqn(op_list, n_vars, n_levels, allow_constants=True, const_ratio=0.8):
    eqn_ops = list(np.random.choice(op_list, size=int(2 ** n_levels) - 1, replace=True))
    eqn_vars = list(np.random.choice(range(1, (n_vars + 1)), size=int(2 ** n_levels), replace=True))
    max_bound = max(np.abs(Settings.test_scope[0]), np.abs(Settings.test_scope[1]))
    eqn_weights = list(np.random.uniform(-1 * max_bound, max_bound, size=len(eqn_vars)))
    eqn_biases = list(np.random.uniform(-1 * max_bound, max_bound, size=len(eqn_vars)))
    if not allow_constants:
        const_ratio = 0.0
    random_const_chooser_w = np.random.uniform(0, 1, len(eqn_weights))
    random_const_chooser_b = np.random.uniform(0, 1, len(eqn_biases))
    for i in range(len(eqn_weights)):
        if random_const_chooser_w[i] >= const_ratio:
            eqn_weights[i] = 1
        if random_const_chooser_b[i] >= const_ratio:
            eqn_biases[i] = 0
    return [eqn_ops, eqn_vars, eqn_weights, eqn_biases]
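# Example (illustrative) of the four-list format: leaf i contributes
# eqn_weights[i] * x_{eqn_vars[i]} + eqn_biases[i], and unary operators use only
# their left child, so
#     [['add', 'mul', 'sin'], [1, 2, 1, 1], [1.0, 1.0, 2.0, 1.0], [0.0, 0.0, 0.0, 0.0]]
# encodes x1 * x2 + sin(2.0 * x1).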
# Function to create a multidim input data set (x values only)
def generate_data(n_points, n_vars=Settings.num_features,
                  n_input_dims=Settings.num_dims_per_feature,
                  min_x=Settings.train_scope[0], max_x=Settings.train_scope[1],
                  avoid_zero=False):
    if not avoid_zero:
        x_data = [np.random.uniform(min_x, max_x, size=[n_input_dims, n_vars]) for _ in range(n_points)]
    else:
        x_data = []
        for _ in range(n_points):
            candidate = np.random.uniform(min_x, max_x, size=[n_input_dims, n_vars])
            while np.linalg.norm(candidate) < 0.1:
                candidate = np.random.uniform(min_x, max_x, size=[n_input_dims, n_vars])
            x_data.append(candidate)
    return np.array(x_data)
# Function to create a data set given an operator/input list
def create_dataset_from_eqn_list(eqn_as_list, n_vars, n_points, min_x, max_x):
    x_data = [list(np.random.uniform(min_x, max_x, n_vars)) for _ in range(n_points)]
    y_data = [evaluate_eqn_list_on_datum(eqn_as_list, x_data_i) + np.random.normal(0, 0.05) for x_data_i in x_data]
    return [np.array(x_data), np.array(y_data)]
# Function to create a multidim data set given an operator/input list
def multidim_dataset_from_eqn_list(eqn_as_list, n_points,
                                   n_vars=Settings.num_features,
                                   n_input_dims=Settings.num_dims_per_feature,
                                   n_output_dims=Settings.n_dims_in_output,
                                   min_x=Settings.train_scope[0], max_x=Settings.train_scope[1],
                                   avoid_zero=False):
    if not avoid_zero:
        x_data = [np.random.uniform(min_x, max_x, size=[n_input_dims, n_vars]) for _ in range(n_points)]
    else:
        x_data = []
        for _ in range(n_points):
            candidate = np.random.uniform(min_x, max_x, size=[n_input_dims, n_vars])
            while np.linalg.norm(candidate) < 0.1:
                candidate = np.random.uniform(min_x, max_x, size=[n_input_dims, n_vars])
            x_data.append(candidate)
    y_data = [evaluate_eqn_list_on_multidim_datum(eqn_as_list, x_data_i)  # + np.random.normal(0, 0.05)
              for x_data_i in x_data]
    if n_output_dims == 1:
        y_data = [np.mean(old_y) for old_y in y_data]
    return [np.array(x_data), np.reshape(np.array(y_data), [n_points, n_output_dims, 1])]
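# Usage sketch (illustrative; relies on the defaults taken from Settings):
#     eqn = generate_random_eqn(['add', 'mul'], Settings.num_features, n_levels=2)
#     x, y = multidim_dataset_from_eqn_list(eqn, n_points=100)
#     # x.shape == (100, Settings.num_dims_per_feature, Settings.num_features)
#     # y.shape == (100, Settings.n_dims_in_output, 1)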
def make_y_multi_safe(old_y, n_dims_per_input_var=1, n_dims_in_output=1):
    if isinstance(old_y, list):
        new_y = np.array(old_y)
        new_y = new_y.reshape([-1, n_dims_in_output, 1])
    else:
        new_y = old_y.copy()
    if len(new_y.shape) == 1:
        assert (n_dims_in_output == 1)
        new_y = [[[y_value] for _ in range(n_dims_per_input_var)] for y_value in new_y]
        new_y = np.array(new_y)
    elif len(new_y.shape) == 2:
        assert (n_dims_in_output == 1)
        new_y = [[y_value for _ in range(n_dims_per_input_var)] for y_value in new_y]
        new_y = np.array(new_y)
    elif new_y.shape[1] < n_dims_per_input_var:
        assert (n_dims_in_output == 1)
        new_y = [[y_value[0] for _ in range(n_dims_per_input_var)] for y_value in new_y]
        new_y = np.array(new_y)
    return new_y
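# For example (illustrative): a 1-D y array of shape (N,) is returned with shape
# (N, n_dims_per_input_var, 1), i.e. each target is repeated across the input dimensions.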
# Function to evaluate equation (in two-list format) on a data point
def evaluate_eqn_list_on_datum_two_list(eqn_as_list, input_x):
    eqn_ops = eqn_as_list[0]
    eqn_vars = eqn_as_list[1]
    current_op = eqn_ops[0]
    if len(eqn_ops) == 1:
        if type(eqn_vars[0]) is list:
            left_side = eqn_vars[0][0]
        else:
            left_side = input_x[eqn_vars[0] - 1]
        if type(eqn_vars[1]) is list:
            right_side = eqn_vars[1][0]
        else:
            right_side = input_x[eqn_vars[1] - 1]
    else:
        split_point = int((len(eqn_ops) + 1) / 2)
        left_ops = eqn_ops[1:split_point]
        right_ops = eqn_ops[split_point:]
        left_vars = eqn_vars[:split_point]
        right_vars = eqn_vars[split_point:]
        left_side = evaluate_eqn_list_on_datum_two_list([left_ops, left_vars], input_x)
        right_side = evaluate_eqn_list_on_datum_two_list([right_ops, right_vars], input_x)
    if current_op == 'id':
        return left_side
    if current_op == 'sqrt':
        return np.sqrt(np.abs(left_side))
    if current_op == 'log':
        return np.log(np.sqrt(left_side * left_side + 1e-10))
    if current_op == 'sin':
        return np.sin(left_side)
    if current_op == 'exp':
        return np.exp(left_side)
    if current_op == 'add':
        return left_side + right_side
    if current_op == 'mul':
        return left_side * right_side
    if current_op == 'sub':
        return left_side - right_side
    if current_op == 'div':
        return safe_div(left_side, right_side)
    return None
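# Worked example (illustrative):
#     eqn = [['add', 'mul', 'sin'], [1, 2, [3.14], 1]]
#     evaluate_eqn_list_on_datum_two_list(eqn, [2.0, 3.0])
# evaluates (x1 * x2) + sin(3.14) at x1 = 2.0, x2 = 3.0, i.e. about 6.0016.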
# Function to evaluate equation (in list format) on a data point
def evaluate_eqn_list_on_datum(eqn_as_list, input_x):
    eqn_ops = eqn_as_list[0]
    eqn_vars = eqn_as_list[1]
    eqn_weights = eqn_as_list[2]
    eqn_biases = eqn_as_list[3]
    current_op = eqn_ops[0]
    if len(eqn_ops) == 1:
        left_side = eqn_weights[0] * input_x[eqn_vars[0] - 1] + eqn_biases[0]
        right_side = eqn_weights[1] * input_x[eqn_vars[1] - 1] + eqn_biases[1]
    else:
        split_point = int((len(eqn_ops) + 1) / 2)
        left_ops = eqn_ops[1:split_point]
        right_ops = eqn_ops[split_point:]
        left_vars = eqn_vars[:split_point]
        right_vars = eqn_vars[split_point:]
        left_weights = eqn_weights[:split_point]
        right_weights = eqn_weights[split_point:]
        left_biases = eqn_biases[:split_point]
        right_biases = eqn_biases[split_point:]
        left_side = evaluate_eqn_list_on_datum([left_ops, left_vars, left_weights, left_biases], input_x)
        right_side = evaluate_eqn_list_on_datum([right_ops, right_vars, right_weights, right_biases], input_x)
    if current_op == 'id':
        return left_side
    if current_op == 'sqrt':
        return np.sqrt(np.abs(left_side))
    if current_op == 'log':
        return np.log(np.sqrt(left_side * left_side + 1e-10))
    if current_op == 'sin':
        return np.sin(left_side)
    if current_op == 'exp':
        return np.exp(left_side)
    if current_op == 'add':
        return left_side + right_side
    if current_op == 'mul':
        return left_side * right_side
    if current_op == 'sub':
        return left_side - right_side
    if current_op == 'div':
        return safe_div(left_side, right_side)
    return None
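# Worked example (illustrative): with the four-list equation
#     eqn = [['add', 'mul', 'sin'], [1, 2, 1, 1], [1.0, 1.0, 2.0, 1.0], [0.0, 0.0, 0.0, 0.0]]
# evaluate_eqn_list_on_datum(eqn, [2.0, 3.0]) computes
# (1.0*2.0 + 0.0) * (1.0*3.0 + 0.0) + sin(2.0*2.0 + 0.0) = 6.0 + sin(4.0), about 5.243.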
# Function to evaluate equation (in two-list format) on a multidim data point
def evaluate_eqn_list_on_multidim_datum_two_list(eqn_as_list, input_x):
    eqn_ops = eqn_as_list[0]
    eqn_vars = eqn_as_list[1]
    current_op = eqn_ops[0]
    if len(eqn_ops) == 1:
        if type(eqn_vars[0]) is list:
            left_side = eqn_vars[0][0]
        else:
            left_side = input_x[:, eqn_vars[0] - 1]
        if type(eqn_vars[1]) is list:
            right_side = eqn_vars[1][0]
        else:
            right_side = input_x[:, eqn_vars[1] - 1]
    else:
        split_point = int((len(eqn_ops) + 1) / 2)
        left_ops = eqn_ops[1:split_point]
        right_ops = eqn_ops[split_point:]
        left_vars = eqn_vars[:split_point]
        right_vars = eqn_vars[split_point:]
        left_side = evaluate_eqn_list_on_multidim_datum_two_list([left_ops, left_vars], input_x)
        right_side = evaluate_eqn_list_on_multidim_datum_two_list([right_ops, right_vars], input_x)
    if current_op == 'id':
        return left_side
    if current_op == 'sqrt':
        return np.sqrt(np.abs(left_side))
    if current_op == 'log':
        return np.log(np.sqrt(left_side * left_side + 1e-10))
    if current_op == 'sin':
        return np.sin(left_side)
    if current_op == 'exp':
        return np.exp(left_side)
    if current_op == 'add':
        return left_side + right_side
    if current_op == 'mul':
        return left_side * right_side
    if current_op == 'sub':
        return left_side - right_side
    if current_op == 'div':
        return safe_div(left_side, right_side)
    return None
# Function to evaluate equation (in list format) on a multidim data point
def evaluate_eqn_list_on_multidim_datum(eqn_as_list, input_x):
    eqn_ops = eqn_as_list[0]
    eqn_vars = eqn_as_list[1]
    eqn_weights = eqn_as_list[2]
    eqn_biases = eqn_as_list[3]
    current_op = eqn_ops[0]
    if len(eqn_ops) == 1:
        left_side = eqn_weights[0] * input_x[:, eqn_vars[0] - 1] + eqn_biases[0]
        right_side = eqn_weights[1] * input_x[:, eqn_vars[1] - 1] + eqn_biases[1]
    else:
        split_point = int((len(eqn_ops) + 1) / 2)
        left_ops = eqn_ops[1:split_point]
        right_ops = eqn_ops[split_point:]
        left_vars = eqn_vars[:split_point]
        right_vars = eqn_vars[split_point:]
        left_weights = eqn_weights[:split_point]
        right_weights = eqn_weights[split_point:]
        left_biases = eqn_biases[:split_point]
        right_biases = eqn_biases[split_point:]
        left_side = evaluate_eqn_list_on_multidim_datum([left_ops, left_vars, left_weights, left_biases], input_x)
        right_side = evaluate_eqn_list_on_multidim_datum([right_ops, right_vars, right_weights, right_biases], input_x)
    if current_op == 'id':
        return left_side
    if current_op == 'sqrt':
        return np.sqrt(np.abs(left_side))
    if current_op == 'log':
        return np.log(np.sqrt(left_side * left_side + 1e-10))
    if current_op == 'sin':
        return np.sin(left_side)
    if current_op == 'exp':
        return np.exp(left_side)
    if current_op == 'add':
        return left_side + right_side
    if current_op == 'mul':
        return left_side * right_side
    if current_op == 'sub':
        return left_side - right_side
    if current_op == 'div':
        return safe_div(left_side, right_side)
    return None
def choices_to_init_weight_matrix(choice_list, all_choices):
    init_weight_matrix = np.zeros(shape=[len(choice_list), len(all_choices)])
    for row in range(len(choice_list)):
        if choice_list[row] in all_choices:
            init_weight_matrix[row][all_choices.index(choice_list[row])] = Settings.init_weight_value
        elif isinstance(choice_list[row], str) and "not" in choice_list[row] and choice_list[row].index("not") == 0 and \
                choice_list[row][len("not"):] in all_choices:
            init_weight_matrix[row][all_choices.index(choice_list[row][len("not"):])] = -100
    return init_weight_matrix.T
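# Example (illustrative): with all_choices = ['add', 'mul', 'sin'] and
# choice_list = ['mul', 'notsin'], the returned matrix has shape (3, 2);
# the entry in row 'mul', column 0 is Settings.init_weight_value and the entry in
# row 'sin', column 1 is -100 (a "not<op>" entry strongly suppresses that operator).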
def predict_from_formula(formula_str, x_values):
    if Settings.num_features == 1:
        x_variables = [["x"]]
    else:
        x_variables = [["x{}".format(var_i + 1) for var_i in range(Settings.num_features)]]
    f = lambdify(x_variables, formula_str, 'numpy')
    if isinstance(x_values, list):
        return [f(x_values[row_i]) for row_i in range(len(x_values))]
    elif len(x_values.shape) == 2:
        return [f(x_values[row_i, :]) for row_i in range(x_values.shape[0])]
    else:
        return [f(x_values[row_i, :, :].reshape([-1, 1])) for row_i in range(x_values.shape[0])]
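# Example (illustrative; assumes Settings.num_features == 2):
#     predict_from_formula("x1*x2", np.array([[1.0, 2.0], [3.0, 4.0]]))
# evaluates the formula row by row and returns [2.0, 12.0].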
def leaves_up_from_dfs_order(orig_list, num_layers):
    if num_layers <= 1:
        return orig_list
    index_list = [[i, i+1] for i in range(0, int(2**(num_layers-1)), 2)]
    # print("start with: {}".format(index_list))
    last_value = index_list[-1][-1]
    while len(index_list) > 1:
        new_index_list = []
        for j in range(int(len(index_list) / 2)):
            last_value += 1
            new_list = [last_value]
            new_list.extend(index_list[2*j])
            last_value += 1
            new_list.append(last_value)
            new_list.extend(index_list[2*j+1])
            new_index_list.append(new_list)
        index_list = new_index_list
        # print("  index list is now: {}".format(index_list))
    final_index_list = [last_value + 1]
    final_index_list.extend(index_list[0])
    ret_val = [i for i in range(len(final_index_list))]
    for i in range(len(final_index_list)):
        ret_val[final_index_list[i]] = orig_list[i]
    return ret_val  # [orig_list[i] for i in final_index_list]
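# Example (illustrative): leaves_up_from_dfs_order(['add', 'mul', 'sin'], 2)
# returns ['mul', 'sin', 'add'], i.e. the pre-order (root-first) operator list is
# reordered level by level from the deepest layer up, with the root last.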
def simplify_formula(formula_to_simplify, digits=None):
    if len("{}".format(formula_to_simplify)) > 1500:
        return "{}".format(expand(formula_to_simplify))
    orig_form_str = sympify(formula_to_simplify)
    if len("{}".format(orig_form_str)) > 1000:
        return "{}".format(expand(orig_form_str))
    if len("{}".format(orig_form_str)) < 700:
        # orig_form_str = simplify(expand(orig_form_str))
        orig_form_str = expand(orig_form_str)
    rounded = orig_form_str
    for a in sympy.preorder_traversal(orig_form_str):
        if isinstance(a, sympy.Float):
            if digits is not None:
                if np.abs(a) < 10**(-1*digits):
                    rounded = rounded.subs(a, 0)
                else:
                    rounded = rounded.subs(a, round(a, digits))
            elif np.abs(a) < Settings.big_eps:
                rounded = rounded.subs(a, 0)
    return "{}".format(rounded)
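# Example (illustrative): simplify_formula("2.0000001*x1 + 0.0000001", digits=3)
# expands the expression, rounds 2.0000001 to 2.0, replaces the 1e-7 term
# (smaller than 10**-3) with 0, and returns roughly "2.0*x1".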
def eqn_to_str_two_list(eqn_as_list, var_y_index=9999, unary_use_both=False):
    eqn_ops = eqn_as_list[0]
    eqn_vars = eqn_as_list[1]
    current_op = eqn_ops[0]
    # print("eqn_to_str:")
    # print(eqn_ops)
    # print(eqn_vars)
    if len(eqn_ops) == 1:
        if type(eqn_vars[0]) is list:
            left_side = "{:.3f}".format(eqn_vars[0][0])
        elif eqn_vars[0] == var_y_index:
            left_side = "y"
        else:
            left_side = "x{}".format(eqn_vars[0])
        if type(eqn_vars[1]) is list:
            right_side = "{:.3f}".format(eqn_vars[1][0])
        elif eqn_vars[1] == var_y_index:
            right_side = "y"
        else:
            right_side = "x{}".format(eqn_vars[1])
    else:
        split_point = int((len(eqn_ops) + 1) / 2)
        left_ops = eqn_ops[1:split_point]
        right_ops = eqn_ops[split_point:]
        left_vars = eqn_vars[:split_point]
        right_vars = eqn_vars[split_point:]
        left_side = eqn_to_str_two_list([left_ops, left_vars], var_y_index=var_y_index)
        right_side = eqn_to_str_two_list([right_ops, right_vars], var_y_index=var_y_index)
    left_is_float = False
    right_is_float = False
    left_value = np.nan
    right_value = np.nan
    if is_float(left_side):
        left_value = float(left_side)
        left_is_float = True
    if is_float(right_side):
        right_value = float(right_side)
        right_is_float = True
    if current_op == 'id':
        return left_side
    if current_op == 'sqrt':
        if left_is_float:
            return "{:.3f}".format(np.sqrt(np.abs(left_value)))
        return "sqrt({})".format(left_side)
    if current_op == 'log':
        if left_is_float:
            return "{:.3f}".format(np.log(safe_abs(left_value)))
        return "log({})".format(left_side)
    if current_op == 'sin':
        if left_is_float:
            return "{:.3f}".format(np.sin(left_value))
        return "sin({})".format(left_side)
    if current_op == 'exp':
        if left_is_float:
            return "{:.3f}".format(np.exp(left_value))
        return "exp({})".format(left_side)
    if current_op == 'add':
        if left_is_float and right_is_float:
            return "{:.3f}".format(left_value + right_value)
        return "({} + {})".format(left_side, right_side)
    if current_op == 'mul':
        if left_is_float and right_is_float:
            return "{:.3f}".format(left_value * right_value)
        return "({} * {})".format(left_side, right_side)
    if current_op == 'sub':
        if left_is_float and right_is_float:
            return "{:.3f}".format(left_value - right_value)
        return "({} - {})".format(left_side, right_side)
    if current_op == 'div':
        if left_is_float and right_is_float:
            return "{:.3f}".format(safe_div(left_value, right_value))
        return "({} / {})".format(left_side, right_side)
    return None
def eqn_to_str(eqn_as_list, var_y_index=9999, unary_use_both=False):
    eqn_ops = eqn_as_list[0]
    eqn_vars = eqn_as_list[1]
    eqn_weights = eqn_as_list[2]
    eqn_biases = eqn_as_list[3]
    current_op = eqn_ops[0]
    # print("eqn_to_str:")
    # print(eqn_ops)
    # print(eqn_vars)
    if len(eqn_ops) == 1:
        if eqn_vars[0] == var_y_index:
            left_side = "y"
        else:
            left_side = "({} * x{} + {})".format(eqn_weights[0], eqn_vars[0], eqn_biases[0])
        if eqn_vars[1] == var_y_index:
            right_side = "y"
        else:
            right_side = "({} * x{} + {})".format(eqn_weights[1], eqn_vars[1], eqn_biases[1])
    else:
        split_point = int((len(eqn_ops) + 1) / 2)
        left_ops = eqn_ops[1:split_point]
        right_ops = eqn_ops[split_point:]
        left_vars = eqn_vars[:split_point]
        right_vars = eqn_vars[split_point:]
        left_weights = eqn_weights[:split_point]
        right_weights = eqn_weights[split_point:]
        left_biases = eqn_biases[:split_point]
        right_biases = eqn_biases[split_point:]
        left_side = eqn_to_str([left_ops, left_vars, left_weights, left_biases], var_y_index=var_y_index)
        right_side = eqn_to_str([right_ops, right_vars, right_weights, right_biases], var_y_index=var_y_index)
    left_is_float = False
    right_is_float = False
    left_value = np.nan
    right_value = np.nan
    if is_float(left_side):
        left_value = float(left_side)
        left_is_float = True
    if is_float(right_side):
        right_value = float(right_side)
        right_is_float = True
    if current_op == 'id':
        return left_side
    if current_op == 'sqrt':
        if left_is_float:
            return "{:.3f}".format(np.sqrt(np.abs(left_value)))
        return "sqrt({})".format(left_side)
    if current_op == 'log':
        if left_is_float:
            return "{:.3f}".format(np.log(safe_abs(left_value)))
        return "log({})".format(left_side)
    if current_op == 'sin':
        if left_is_float:
            return "{:.3f}".format(np.sin(left_value))
        return "sin({})".format(left_side)
    if current_op == 'exp':
        if left_is_float:
            return "{:.3f}".format(np.exp(left_value))
        return "exp({})".format(left_side)
    if current_op == 'add':
        if left_is_float and right_is_float:
            return "{:.3f}".format(left_value + right_value)
        return "({} + {})".format(left_side, right_side)
    if current_op == 'mul':
        if left_is_float and right_is_float:
            return "{:.3f}".format(left_value * right_value)
        return "({} * {})".format(left_side, right_side)
    if current_op == 'sub':
        if left_is_float and right_is_float:
            return "{:.3f}".format(left_value - right_value)
        return "({} - {})".format(left_side, right_side)
    if current_op == 'div':
        if left_is_float and right_is_float:
            return "{:.3f}".format(safe_div(left_value, right_value))
        return "({} / {})".format(left_side, right_side)
    return None
def simple_eqn_to_str(eqn_as_list, var_y_index=9999):
    return simplify_formula(eqn_to_str(eqn_as_list, var_y_index=var_y_index))
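# Example (illustrative): for the four-list equation shown above,
#     [['add', 'mul', 'sin'], [1, 2, 1, 1], [1.0, 1.0, 2.0, 1.0], [0.0, 0.0, 0.0, 0.0]]
# eqn_to_str produces "(((1.0 * x1 + 0.0) * (1.0 * x2 + 0.0)) + sin((2.0 * x1 + 0.0)))"
# and simple_eqn_to_str simplifies it to roughly "1.0*x1*x2 + sin(2.0*x1)".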
def get_samples(n_train, n_batch, train_x, train_y):
    both_samples = np.random.choice(n_train, size=2*n_batch, replace=False)
    sample = both_samples[:n_batch]
    valid_sample = both_samples[n_batch:]
    mini_batch_train_data_x = train_x[sample][:][:]
    mini_batch_train_data_y = train_y[sample][:][:]
    mini_valid_sample_x = train_x[valid_sample][:][:]
    mini_valid_sample_y = train_y[valid_sample][:][:]
    return mini_batch_train_data_x, mini_batch_train_data_y, mini_valid_sample_x, mini_valid_sample_y
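# Usage sketch (illustrative): draw a training mini-batch and a validation mini-batch;
# replace=False above guarantees the two index sets do not overlap.
#     x_batch, y_batch, x_valid, y_valid = get_samples(train_x.shape[0], 32, train_x, train_y)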
###############################################
#
# Plotting functions
#
###############################################
def plot_1d_curve(train_x, train_y_true, train_y_pred, test_x, test_y_true, test_y_pred,
                  title="", file_suffix="", show_ground_truth=True):
    plt.figure()
    plt.title(title)
    if show_ground_truth:
        plt.scatter(train_x, train_y_true, color='gray', alpha=0.5, marker='.', label='Ground truth')
        if test_x is not None:
            plt.scatter(test_x, test_y_true, color='gray', alpha=0.5, marker='.')
    if test_x is not None:
        plt.scatter(test_x, test_y_pred, color='red', alpha=0.7, marker='.', label='Model (test set)')
    plt.scatter(train_x, train_y_pred, color='blue', alpha=0.7, marker='.', label='Model (train set)')
    for xc in Settings.train_scope:
        plt.axvline(x=xc, color='k', linestyle='dashed', linewidth=2)
    plt.xlabel("x")
    plt.ylabel("y")
    plt.legend()
    plt.savefig("images/true_vs_pred_curve{}.png".format(file_suffix))
    plt.close()
def plot_2d_curve(x_1, x_2, y, g,
                  title=""):
    fig = plt.figure(figsize=(11, 5))
    # ax = fig.gca(projection='3d')
    ax = fig.add_subplot(1, 2, 1, projection='3d')
    plt.title("Learned function")
    surf = ax.plot_surface(x_1, x_2, y, cmap=cm.coolwarm,
                           linewidth=0, antialiased=False, label="Ground truth")
    plt.xlabel("x")
    plt.ylabel("y")
    fig.colorbar(surf, shrink=0.5, aspect=5)
    ax = fig.add_subplot(1, 2, 2, projection='3d')
    plt.title("Residual (g)")
    surf = ax.plot_surface(x_1, x_2, g, cmap=cm.coolwarm,
                           linewidth=0, antialiased=False, label="Ground truth")
    plt.xlabel("x")
    plt.ylabel("y")
    fig.colorbar(surf, shrink=0.5, aspect=5)
    # plt.show()
    # plt.legend()
    plt.savefig("images/pred_g_2d.png")
    plt.close()
def plot_predicted_vs_actual(pred_y_train, true_y_train,
                             pred_y_test=None, true_y_test=None,
                             model_name="Model", set_name="", show=False):
    plt.figure()
    if set_name != "":
        set_name = "({})".format(set_name)
    plt.title('{}: Predicted vs. Actual {}'.format(model_name, set_name))
    plt.scatter(true_y_train, true_y_train, color='gray', alpha=0.5, marker='.', label='Ground truth')
    if pred_y_test is not None:
        plt.scatter(true_y_test, true_y_test, color='gray', alpha=0.5, marker='.')
    if pred_y_test is not None:
        plt.scatter(true_y_test, pred_y_test, color="red", alpha=0.6, marker='.', label="{}: Test".format(model_name))
    plt.scatter(true_y_train, pred_y_train, color="blue", alpha=0.7, marker='.', label="{}: Train".format(model_name))
    plt.ylabel("Observed")
    plt.xlabel("Expected")
    plt.legend()
    plt.savefig("images/single_predicted_vs_actual.png")
    if show:
        plt.show()
    plt.close()
# acc_logs: list of accuracy logs to be plotted.
def plot_accuracy_over_time(iters_log, acc_logs, error_types):
    if len(iters_log) < 2:
        return
    for i in range(len(acc_logs)):
        plt.plot(iters_log[1:], acc_logs[i][1:], label=error_types[i])
    plt.xlabel("Iteration")
    plt.ylabel("Error Scores")
    plt.yscale('log')
    plt.legend()
    plt.savefig("images/accuracy_log.png")
    plt.close()
def plot_hist_of_errors(lists_of_error_scores, all_models, num_trials):
    plt.figure()
    plt.hist([np.log(errors_i) for errors_i in lists_of_error_scores],
             label=[model.short_name for model in all_models])
    plt.legend()
    plt.title("Comparing errors of all methods over {} equations".format(num_trials))
    plt.xlabel("Log of error")
    plt.ylabel("Frequency")
    plt.savefig("images/hist_of_errors.png")
    plt.close()