# pokemon-testing/scheduler/scheduler_config.py
# Last commit fda6f38 by AltLuv: "End of training" (1.21 kB).
import jax.numpy as jnp
import jax
import torch
from dataclasses import dataclass
import sympy
import sympy as sp
from sympy import Matrix, Symbol
import math
from sde_redefined_param import SDEDimension
@dataclass
class SDEConfig:
    """Symbolic specification of a custom scalar SDE in the variable ``t``.

    Every setting is a plain class attribute. None of them carries a type
    annotation, so ``@dataclass`` generates no fields and ``__init__`` takes
    no arguments. The drift and diffusion are SymPy expressions built from
    ``variable`` and the symbolic parameter matrices defined below.
    """

    name = "Custom"

    # Independent variable t, declared nonnegative and real for SymPy.
    variable = Symbol("t", nonnegative=True, real=True)

    # Dimension tags (SDEDimension comes from sde_redefined_param); all
    # three are scalar in this configuration.
    drift_dimension = SDEDimension.SCALAR
    diffusion_dimension = SDEDimension.SCALAR
    diffusion_matrix_dimension = SDEDimension.SCALAR

    # TODO(KLAUS): handle the parameters being Ø (presumably empty).
    # Each parameter set is a 1x1 SymPy Matrix holding a single symbol.
    drift_parameters = Matrix([Symbol("f1")])
    diffusion_parameters = Matrix([Symbol("l1")])

    # Drift f(t) = -f1**2 * t**2.
    drift = -(drift_parameters[0] ** 2) * variable**2

    # Diffusion scale. A disabled variant multiplied this by
    # diffusion_parameters[0]**2; it is currently the constant 1.
    k = 1

    # Diffusion g(t): k*sin(pi*t/2) on t < 1, then the constant k on t >= 1.
    # Because sin(pi/2) == 1, the two pieces meet continuously at t = 1.
    diffusion = sp.Piecewise(
        (k * sp.sin(sp.pi * variable / 2), variable < 1),
        (k, variable >= 1),
    )

    # TODO(KLAUS): in SDE sampling, changing Q impacts how we sample
    # z ~ N(0, Q * delta_t).
    diffusion_matrix = 1

    # Variable range: starts at 0 and is capped at 1 (math.inf was the
    # noted alternative cap). min_sample_value presumably keeps sampled t
    # away from 0 — confirm against the sampler.
    initial_variable_value = 0
    max_variable_value = 1
    min_sample_value = 1e-6

    # Backend name; presumably what the symbolic expressions are converted
    # to — confirm against the consumer of this config.
    module = "jax"

    drift_integral_form = True
    diffusion_integral_form = True
    # Decomposition named for the diffusion integral; "ldl" is the noted
    # alternative.
    diffusion_integral_decomposition = "cholesky"
    # Presumably the prediction target; "x0" is the noted alternative.
    target = "epsilon"