import jax.numpy as jnp | |
import jax | |
import torch | |
from dataclasses import dataclass | |
import sympy | |
import sympy as sp | |
from sympy import Matrix, Symbol | |
import math | |
from sde_redefined_param import SDEDimension | |
class SDEBaseLineConfig:
    """Baseline SDE configuration: zero drift, scalar exponential diffusion.

    Every setting is a class attribute, evaluated once when the class body
    runs. The diffusion coefficient is

        g(t) = sigma_min * (sigma_max / sigma_min) ** t
               * sqrt(2 * ln(sigma_max / sigma_min))

    which satisfies g(t) ** 2 == d/dt [sigma(t) ** 2] for
    sigma(t) = sigma_min * (sigma_max / sigma_min) ** t (the
    variance-exploding SDE form).

    Because ``diffusion`` is built at class-creation time, a subclass that
    overrides only ``sigma_min`` / ``sigma_max`` keeps this class's
    ``diffusion`` expression; it must redefine ``diffusion`` too.
    """

    name = "Custom"

    # Independent variable of the SDE, declared real and non-negative.
    variable = Symbol("t", nonnegative=True, real=True)

    # Dimension category of each coefficient (values from SDEDimension).
    drift_dimension = SDEDimension.SCALAR
    diffusion_dimension = SDEDimension.SCALAR
    diffusion_matrix_dimension = SDEDimension.SCALAR

    # Free symbolic parameters, each held in a 1x1 sympy Matrix. Neither f1 nor
    # l1 appears in the drift/diffusion expressions below, so they are
    # presumably placeholders until empty parameter sets are supported.
    # TODO(KLAUS): handle the parameters being the empty set.
    drift_parameters = Matrix([Symbol("f1")])
    diffusion_parameters = Matrix([Symbol("l1")])

    # Zero drift term.
    drift = 0

    # Noise-scale bounds that define the diffusion coefficient.
    sigma_min = 0.002
    sigma_max = 80
    diffusion = (
        sigma_min
        * (sigma_max / sigma_min) ** variable
        * sympy.sqrt(2 * sympy.log(sigma_max / sigma_min))
    )

    # Scalar diffusion matrix Q.
    # TODO(KLAUS): in SDE sampling, changing Q changes how we sample
    # z ~ N(0, Q * delta_t).
    diffusion_matrix = 1

    # Start and end values of the independent variable (an unbounded end,
    # math.inf, was noted as an alternative to 1).
    initial_variable_value = 0
    max_variable_value = 1

    # NOTE(review): presumably a lower bound applied during sampling — confirm.
    min_sample_value = 1e-6

    # Numeric backend name; presumably used to evaluate the symbolic
    # expressions (e.g. via lambdify) — confirm against the consumer.
    module = "jax"

    # Presumably flags for whether drift / diffusion are given in integral form.
    drift_integral_form = False
    diffusion_integral_form = False

    # Factorisation for the diffusion integral; "ldl" was noted as an alternative.
    diffusion_integral_decomposition = "cholesky"

    # Presumably the model's prediction target; "x0" was noted as an alternative.
    target = "epsilon"