File size: 1,390 Bytes
2d3e480 46439a4 2d3e480 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 |
import jax.numpy as jnp
import jax
import torch
from dataclasses import dataclass
import sympy
import sympy as sp
from sympy import Matrix, Symbol
import math
from sde_redefined_param import SDEDimension
@dataclass
class SDEParameterizedBaseLineConfig:
    """Baseline SDE configuration: zero drift, exponentially growing diffusion.

    The symbolic diffusion coefficient is

        g(t) = sigma_min * (sigma_max / sigma_min) ** t
               * sqrt(2 * |log(sigma_max / sigma_min)|)

    i.e. the variance-exploding (VE) SDE coefficient, where ``sigma_min`` and
    ``sigma_max`` are the absolute values of the two symbolic diffusion
    parameters.

    Every setting below is an *unannotated* class attribute, so ``@dataclass``
    generates no fields: ``__init__`` takes no arguments and all instances
    share these class-level values.
    """

    name = "Custom"
    # Symbolic time variable t, declared nonnegative and real.
    variable = Symbol("t", nonnegative=True, real=True)

    # Drift, diffusion and diffusion matrix are all scalar-valued.
    drift_dimension = SDEDimension.SCALAR
    diffusion_dimension = SDEDimension.SCALAR
    diffusion_matrix_dimension = SDEDimension.SCALAR

    # TODO(KLAUS): handle the case where a parameter set is empty (Ø).
    # 1x1 matrix holding the drift parameter f1 (not referenced by `drift` below).
    drift_parameters = Matrix([[Symbol("f1", real=True)]])
    # 1x2 row matrix [sigma_min, sigma_max] of symbolic diffusion parameters.
    diffusion_parameters = Matrix(
        [[Symbol("sigma_min", real=True), Symbol("sigma_max", real=True)]]
    )

    # Zero drift.
    drift = 0
    # Abs keeps the noise levels nonnegative whatever sign the raw symbols take.
    sigma_min = sympy.Abs(diffusion_parameters[0, 0])  # 0.002 numerically
    sigma_max = sympy.Abs(diffusion_parameters[0, 1])  # 80 numerically
    diffusion = (
        sigma_min
        * (sigma_max / sigma_min) ** variable
        * sympy.sqrt(2 * sympy.Abs(sympy.log(sigma_max / sigma_min)))
    )

    # TODO(KLAUS): in SDE sampling, changing Q (presumably this diffusion
    # matrix) changes how the increment z ~ N(0, Q * delta_t) is drawn.
    diffusion_matrix = 1

    # Presumably the range of t: from 0 up to 1 (alternative: math.inf).
    initial_variable_value = 0
    max_variable_value = 1
    min_sample_value = 0
    module = "jax"

    # Integral-form flags for drift and diffusion (both disabled here).
    drift_integral_form = False
    diffusion_integral_form = False
    # Matrix factorization selected for the diffusion integral (alternative: "ldl").
    diffusion_integral_decomposition = "cholesky"

    # Numeric values for the diffusion parameters, presumably ordered
    # [sigma_min, sigma_max] like `diffusion_parameters` -- confirm in the
    # consumer. NOTE(review): this dict and its tensor are shared by every
    # instance; copy before mutating.
    non_symbolic_parameters = {"diffusion": torch.tensor([0.002, 80.0])}
    # Presumably the quantity the model is trained to predict (alternative: "x0").
    target = "epsilon"
|