FoundHand / diffusion /scheduler.py
Chaerin5's picture
init
49f816b
raw
history blame
6.27 kB
# Copyright (c) 2022 Huawei Technologies Co., Ltd.
# Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
#
# The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This repository was forked from https://github.com/openai/guided-diffusion, which is under the MIT license
def get_schedule(t_T, t_0, n_sample, n_steplength, debug=0):
    """Build a resampling time schedule that walks from ``t_T`` down to -1.

    The schedule starts with ``t_T`` and decreases by one per step. After every
    decrement it performs ``n_sample - 1`` excursions: it steps back up by
    ``n_steplength_cur`` and then down again by the same amount, where
    ``n_steplength_cur`` is ``n_steplength`` capped by the distance to ``t_T``.

    Args:
        t_T: Starting (largest) time step; also the first entry of the result.
        t_0: Lower bound used by ``_check_times`` when validating the result.
            The schedule always reaches -1, so this must be <= -1 for the
            validation to pass.
        n_sample: Number of visits per step; 1 disables the excursions.
        n_steplength: Height of each up/down excursion.
        debug: When 2, plot (up to) the first and last 50 schedule entries.

    Returns:
        List of integer time steps in which consecutive entries differ by
        exactly 1 and the last entry is -1.

    Raises:
        RuntimeError: If ``n_steplength > 1`` while ``n_sample <= 1``, because
            the step length would then never be used.
    """
    if n_steplength > 1:
        if not n_sample > 1:
            raise RuntimeError('n_steplength has no effect if n_sample=1')

    t = t_T
    times = [t]
    while t >= 0:
        t = t - 1
        times.append(t)
        # Close to t_T there is less room to step back up.
        n_steplength_cur = min(n_steplength, t_T - t)

        for _ in range(n_sample - 1):
            for _ in range(n_steplength_cur):
                t = t + 1
                times.append(t)
            for _ in range(n_steplength_cur):
                t = t - 1
                times.append(t)

    _check_times(times, t_0, t_T)

    if debug == 2:
        # Clamp the plotted window: short schedules (< 50 entries) used to
        # raise IndexError here. For longer ones the indices are unchanged.
        window = min(50, len(times))
        for x in [list(range(0, window)), list(range(-1, -window, -1))]:
            _plot_times(x=x, times=[times[i] for i in x])

    return times
def _check_times(times, t_0, t_T):
    """Validate a time schedule produced by one of the schedule builders.

    Checks that the schedule starts by stepping down, finishes at -1, moves
    by exactly one between consecutive entries, and stays within
    ``[t_0, t_T]``.

    Failures raise ``AssertionError`` explicitly rather than via ``assert``
    so the checks are not stripped when Python runs with ``-O``.

    Args:
        times: Sequence of integer time steps.
        t_0: Smallest allowed value.
        t_T: Largest allowed value.

    Raises:
        AssertionError: If any check fails. The exception argument carries
            the offending value(s), as the original ``assert`` messages did.
    """
    # Need at least one transition to inspect; previously this was IndexError.
    if len(times) < 2:
        raise AssertionError(
            f"schedule needs at least 2 entries, got {len(times)}")
    # Check end: diffusion time runs backwards, so the first move is downward.
    if not times[0] > times[1]:
        raise AssertionError((times[0], times[1]))
    # Check beginning: the walk must finish at -1.
    if times[-1] != -1:
        raise AssertionError(times[-1])
    # Steplength = 1
    for t_last, t_cur in zip(times, times[1:]):
        if abs(t_last - t_cur) != 1:
            raise AssertionError((t_last, t_cur))
    # Value range
    for t in times:
        if t < t_0:
            raise AssertionError((t, t_0))
        if t > t_T:
            raise AssertionError((t, t_T))
def _plot_times(x, times):
    """Show a line plot of ``times`` against ``x`` (interactive debug aid)."""
    from matplotlib import pyplot
    pyplot.plot(x, times)
    pyplot.show()
def _init_jump_counts(t_T, length, n_sample):
    """Map each start in ``range(0, t_T - length, length)`` to ``n_sample - 1`` remaining jumps."""
    return {j: n_sample - 1 for j in range(0, t_T - length, length)}


def get_schedule_jump(t_T, n_sample, jump_length, jump_n_sample,
                      jump2_length=1, jump2_n_sample=1,
                      jump3_length=1, jump3_n_sample=1,
                      start_resampling=100000000):
    """Build a jump schedule (up to three nested levels) from ``t_T`` to -1.

    The walk starts at ``t_T - 1`` and decreases by one per iteration. After
    each decrement it may move back up:

    * Plain resampling: when ``t < t_T - 2`` and ``t <= start_resampling``,
      step up and straight back down ``n_sample - 1`` times.
    * Jumps (checked in the order level 3, level 2, level 1): each level keeps
      a counter for every time step in ``range(0, t_T - length, length)``,
      initialised to ``<level>_n_sample - 1``. When ``t`` is such a key, its
      counter is positive and ``t <= start_resampling - length``, the counter
      is decremented and the walk steps up ``length`` times.
      A level-2 jump re-arms all level-3 counters. A level-1 jump re-arms the
      level-2 and level-3 counters.

    Args:
        t_T: Starting time step (``t_T`` itself is not part of the result).
        n_sample: Visits per step for plain resampling; 1 disables it.
        jump_length: Distance of a level-1 jump (and spacing of its starts).
        jump_n_sample: Level-1 jumps per start is ``jump_n_sample - 1``.
        jump2_length: Same as ``jump_length`` for level 2.
        jump2_n_sample: Same as ``jump_n_sample`` for level 2.
        jump3_length: Same as ``jump_length`` for level 3.
        jump3_n_sample: Same as ``jump_n_sample`` for level 3.
        start_resampling: Resampling and jumps only happen at or below this
            time step; the large default leaves them enabled for any smaller
            ``t_T``.

    Returns:
        List of integer time steps ending with -1, validated with
        ``_check_times(ts, -1, t_T)``.

    Raises:
        ValueError: From ``range`` if any of the jump lengths is 0.
        AssertionError: Raised by ``_check_times`` if the schedule is invalid.
    """
    jumps = _init_jump_counts(t_T, jump_length, jump_n_sample)
    jumps2 = _init_jump_counts(t_T, jump2_length, jump2_n_sample)
    jumps3 = _init_jump_counts(t_T, jump3_length, jump3_n_sample)

    t = t_T
    ts = []

    while t >= 1:
        t = t - 1
        ts.append(t)

        if t + 1 < t_T - 1 and t <= start_resampling:
            for _ in range(n_sample - 1):
                t = t + 1
                ts.append(t)
                # t was just raised from a value >= 0, so this always holds;
                # kept as a guard against going below the final step.
                if t >= 0:
                    t = t - 1
                    ts.append(t)

        if jumps3.get(t, 0) > 0 and t <= start_resampling - jump3_length:
            jumps3[t] = jumps3[t] - 1
            for _ in range(jump3_length):
                t = t + 1
                ts.append(t)

        if jumps2.get(t, 0) > 0 and t <= start_resampling - jump2_length:
            jumps2[t] = jumps2[t] - 1
            for _ in range(jump2_length):
                t = t + 1
                ts.append(t)
            # A level-2 jump re-arms every level-3 counter.
            jumps3 = _init_jump_counts(t_T, jump3_length, jump3_n_sample)

        if jumps.get(t, 0) > 0 and t <= start_resampling - jump_length:
            jumps[t] = jumps[t] - 1
            for _ in range(jump_length):
                t = t + 1
                ts.append(t)
            # A level-1 jump re-arms the nested level-2 and level-3 counters.
            jumps2 = _init_jump_counts(t_T, jump2_length, jump2_n_sample)
            jumps3 = _init_jump_counts(t_T, jump3_length, jump3_n_sample)

    ts.append(-1)

    _check_times(ts, -1, t_T)

    return ts
def get_schedule_jump_paper(t_T=250, jump_length=10, jump_n_sample=10):
    """Build a single-level jump schedule from ``t_T`` down to -1.

    The walk decreases by one per iteration starting at ``t_T - 1``. Whenever
    the current step is a key of ``range(0, t_T - jump_length, jump_length)``
    with jumps remaining, it steps back up ``jump_length`` times; each key
    allows ``jump_n_sample - 1`` such jumps.

    The defaults equal the values that were previously hard-coded, so calling
    this with no arguments returns exactly the same schedule as before.

    Args:
        t_T: Starting time step (not itself part of the result).
        jump_length: Jump distance and spacing of the jump starts.
        jump_n_sample: Number of jumps per start is ``jump_n_sample - 1``.

    Returns:
        List of integer time steps ending with -1, validated with
        ``_check_times(ts, -1, t_T)``.
    """
    jumps = {j: jump_n_sample - 1
             for j in range(0, t_T - jump_length, jump_length)}

    t = t_T
    ts = []

    while t >= 1:
        t = t - 1
        ts.append(t)

        if jumps.get(t, 0) > 0:
            jumps[t] = jumps[t] - 1
            for _ in range(jump_length):
                t = t + 1
                ts.append(t)

    ts.append(-1)

    _check_times(ts, -1, t_T)

    return ts
def get_schedule_jump_test(
        to_supplement=False,
        out_path="./schedule.png",
        supplement_path="/cluster/home/alugmayr/gdiff/paper/supplement/figures/jump_sched.pdf"):
    """Plot the default jump schedule and save it to disk.

    Builds a schedule with ``get_schedule_jump`` (t_T=250, jump_length=10,
    jump_n_sample=10, all other resampling disabled), plots diffusion time
    against transition index on a fresh figure, and saves it.

    Args:
        to_supplement: Also save a copy to ``supplement_path``.
        out_path: Destination of the main image; printed after saving.
        supplement_path: Destination used when ``to_supplement`` is True.
            The default is an absolute path from the original author's
            machine and will usually need overriding.
    """
    ts = get_schedule_jump(t_T=250, n_sample=1,
                           jump_length=10, jump_n_sample=10,
                           jump2_length=1, jump2_n_sample=1,
                           jump3_length=1, jump3_n_sample=1,
                           start_resampling=250)

    import matplotlib.pyplot as plt

    SMALL_SIZE = 8 * 3
    MEDIUM_SIZE = 10 * 3
    BIGGER_SIZE = 12 * 3
    font_sizes = {
        'font.size': SMALL_SIZE,          # default text size
        'axes.titlesize': SMALL_SIZE,     # axes title
        'axes.labelsize': MEDIUM_SIZE,    # x and y labels
        'xtick.labelsize': SMALL_SIZE,    # x tick labels
        'ytick.labelsize': SMALL_SIZE,    # y tick labels
        'legend.fontsize': SMALL_SIZE,    # legend
        'figure.titlesize': BIGGER_SIZE,  # figure title
    }

    # rc_context scopes the font settings to this plot instead of permanently
    # mutating matplotlib's global rcParams. Saving happens inside the context
    # because some artists (e.g. tick labels) are created lazily at draw time.
    with plt.rc_context(font_sizes):
        fig, ax = plt.subplots(figsize=(20, 10))
        ax.plot(ts)
        ax.set_xlabel('Number of Transitions')
        ax.set_ylabel('Diffusion time $t$')
        fig.tight_layout()

        if to_supplement:
            fig.savefig(supplement_path)
        fig.savefig(out_path)

    # pyplot keeps figures alive until closed; release this one.
    plt.close(fig)
    print(out_path)
def main():
    """Script entry point: build, plot and save the default jump schedule."""
    get_schedule_jump_test(to_supplement=False)
if __name__ == "__main__":
    # Produce the schedule plot only when executed as a script, not on import.
    main()