Spaces:
Running
on
Zero
Running
on
Zero
# Copyright (c) 2022 Huawei Technologies Co., Ltd.
# Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
#
# The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This repository was forked from https://github.com/openai/guided-diffusion, which is under the MIT license
def get_schedule(t_T, t_0, n_sample, n_steplength, debug=0):
    """Build a time-step schedule with optional local up/down round trips.

    The schedule starts at ``t_T`` and walks down one step at a time until it
    reaches -1. After each downward step it performs ``n_sample - 1`` round
    trips: climb ``min(n_steplength, t_T - t)`` steps, then descend the same
    amount back to ``t``.

    Args:
        t_T: first entry of the schedule; also the upper bound that is checked.
        t_0: lower bound passed to the range check.
        n_sample: number of visits per step; values <= 1 disable round trips.
        n_steplength: height of each round trip.
        debug: when 2, plot the entries at indices 0..49 and -1..-49.

    Returns:
        List of ints ending in -1.

    Raises:
        RuntimeError: if ``n_steplength > 1`` while ``n_sample`` is not > 1.
        AssertionError: if the schedule fails ``_check_times``.
    """
    if n_steplength > 1 and not n_sample > 1:
        raise RuntimeError('n_steplength has no effect if n_sample=1')

    current = t_T
    schedule = [current]
    while current >= 0:
        current -= 1
        schedule.append(current)
        # Never climb above t_T.
        span = min(n_steplength, t_T - current)
        for _ in range(n_sample - 1):
            # Climb to current + span ...
            schedule.extend(range(current + 1, current + span + 1))
            # ... then come back down to current (current itself is unchanged).
            schedule.extend(range(current + span - 1, current - 1, -1))

    _check_times(schedule, t_0, t_T)

    if debug == 2:
        head = list(range(0, 50))
        tail = list(range(-1, -50, -1))
        for indices in (head, tail):
            _plot_times(x=indices, times=[schedule[i] for i in indices])

    return schedule
def _check_times(times, t_0, t_T): | |
# Check end | |
assert times[0] > times[1], (times[0], times[1]) | |
# Check beginning | |
assert times[-1] == -1, times[-1] | |
# Steplength = 1 | |
for t_last, t_cur in zip(times[:-1], times[1:]): | |
assert abs(t_last - t_cur) == 1, (t_last, t_cur) | |
# Value range | |
for t in times: | |
assert t >= t_0, (t, t_0) | |
assert t <= t_T, (t, t_T) | |
def _plot_times(x, times):
    """Debug helper: show a line plot of ``times`` against ``x``.

    Imports matplotlib lazily so the module does not depend on it unless
    plotting is actually requested.
    """
    from matplotlib import pyplot as pyplot_mod
    pyplot_mod.plot(x, times)
    pyplot_mod.show()
def _init_jump_counters(t_T, length, n_sample):
    """Return a fresh ``{time_step: remaining_jumps}`` map for one jump level.

    Keys are the time steps ``range(0, t_T - length, length)``; each starts
    with ``n_sample - 1`` remaining jumps.
    """
    return {j: n_sample - 1 for j in range(0, t_T - length, length)}


def get_schedule_jump(t_T, n_sample, jump_length, jump_n_sample,
                      jump2_length=1, jump2_n_sample=1,
                      jump3_length=1, jump3_n_sample=1,
                      start_resampling=100000000):
    """Build a time-step schedule with up to three nested levels of jumps.

    The walk starts at ``t_T`` and steps down one time step per iteration.
    Within each iteration the following happen, in this order:

    1. Bounces. Unless ``t >= t_T - 2`` or ``t > start_resampling``, the walk
       goes up one step and back down, ``n_sample - 1`` times.
    2. Level 3. If ``t`` has remaining level-3 jumps and
       ``t <= start_resampling - jump3_length``, one is consumed and the walk
       climbs ``jump3_length`` steps.
    3. Level 2. Same rule with ``jump2_*``. Each level-2 jump re-arms all
       level-3 counters.
    4. Level 1. Same rule with ``jump_*``. Each level-1 jump re-arms all
       level-2 and level-3 counters.

    Counters for a level with length ``L`` and sample count ``N`` sit at time
    steps ``range(0, t_T - L, L)``, each allowing ``N - 1`` jumps.

    Returns:
        List of ints. For ``t_T >= 1`` it starts at ``t_T - 1``. It always
        ends with -1, and consecutive entries differ by exactly 1.

    Raises:
        AssertionError: if the result fails ``_check_times``.
    """
    jumps = _init_jump_counters(t_T, jump_length, jump_n_sample)
    jumps2 = _init_jump_counters(t_T, jump2_length, jump2_n_sample)
    jumps3 = _init_jump_counters(t_T, jump3_length, jump3_n_sample)

    t = t_T
    ts = []

    while t >= 1:
        # Regular downward step.
        t = t - 1
        ts.append(t)

        # Single-step bounces; skipped for the first two steps below t_T.
        if t + 1 < t_T - 1 and t <= start_resampling:
            for _ in range(n_sample - 1):
                t = t + 1
                ts.append(t)
                if t >= 0:
                    t = t - 1
                    ts.append(t)

        # Innermost jump level.
        if jumps3.get(t, 0) > 0 and t <= start_resampling - jump3_length:
            jumps3[t] = jumps3[t] - 1
            for _ in range(jump3_length):
                t = t + 1
                ts.append(t)

        # Middle jump level; re-arms the innermost level.
        if jumps2.get(t, 0) > 0 and t <= start_resampling - jump2_length:
            jumps2[t] = jumps2[t] - 1
            for _ in range(jump2_length):
                t = t + 1
                ts.append(t)
            jumps3 = _init_jump_counters(t_T, jump3_length, jump3_n_sample)

        # Outer jump level; re-arms both inner levels.
        if jumps.get(t, 0) > 0 and t <= start_resampling - jump_length:
            jumps[t] = jumps[t] - 1
            for _ in range(jump_length):
                t = t + 1
                ts.append(t)
            jumps2 = _init_jump_counters(t_T, jump2_length, jump2_n_sample)
            jumps3 = _init_jump_counters(t_T, jump3_length, jump3_n_sample)

    ts.append(-1)

    _check_times(ts, -1, t_T)

    return ts
def get_schedule_jump_paper():
    """Build a single-level jump schedule with fixed, hard-coded settings.

    Uses 250 time steps and a jump length of 10. At each of the time steps
    0, 10, ..., 230 the walk jumps back up 10 steps at most 9 times
    (``jump_n_sample - 1`` with ``jump_n_sample = 10``).

    Returns:
        List of ints from 249 down to -1 with consecutive entries differing
        by exactly 1.

    Raises:
        AssertionError: if the result fails ``_check_times``.
    """
    total_steps = 250
    length = 10
    repeats = 10

    remaining = {
        step: repeats - 1 for step in range(0, total_steps - length, length)
    }

    schedule = []
    current = total_steps
    while current >= 1:
        current -= 1
        schedule.append(current)
        if remaining.get(current, 0) <= 0:
            continue
        # Consume one jump at this time step and climb back up.
        remaining[current] -= 1
        for _ in range(length):
            current += 1
            schedule.append(current)

    schedule.append(-1)
    _check_times(schedule, -1, total_steps)
    return schedule
def get_schedule_jump_test(to_supplement=False, *,
                           out_path="./schedule.png",
                           supplement_path=(
                               "/cluster/home/alugmayr/gdiff/paper/"
                               "supplement/figures/jump_sched.pdf")):
    """Plot the default jump schedule, save it to disk and print the path.

    Args:
        to_supplement: if true, also save the figure to ``supplement_path``.
        out_path: destination of the main image (default unchanged).
        supplement_path: destination used when ``to_supplement`` is true.
            The default is the author's original, machine-specific path.

    Note:
        The ``plt.rc`` calls change matplotlib's global font settings for
        the rest of the process, as before.
    """
    ts = get_schedule_jump(t_T=250, n_sample=1,
                           jump_length=10, jump_n_sample=10,
                           jump2_length=1, jump2_n_sample=1,
                           jump3_length=1, jump3_n_sample=1,
                           start_resampling=250)

    import matplotlib.pyplot as plt

    SMALL_SIZE = 8*3
    MEDIUM_SIZE = 10*3
    BIGGER_SIZE = 12*3

    plt.rc('font', size=SMALL_SIZE)  # controls default text sizes
    plt.rc('axes', titlesize=SMALL_SIZE)  # fontsize of the axes title
    plt.rc('axes', labelsize=MEDIUM_SIZE)  # fontsize of the x and y labels
    plt.rc('xtick', labelsize=SMALL_SIZE)  # fontsize of the tick labels
    plt.rc('ytick', labelsize=SMALL_SIZE)  # fontsize of the tick labels
    plt.rc('legend', fontsize=SMALL_SIZE)  # legend fontsize
    plt.rc('figure', titlesize=BIGGER_SIZE)  # fontsize of the figure title

    # Draw on a fresh figure (instead of plt.gcf()) so repeated calls do not
    # overplot earlier curves into the same saved image.
    fig, ax = plt.subplots()
    ax.plot(ts)
    fig.set_size_inches(20, 10)
    ax.set_xlabel('Number of Transitions')
    ax.set_ylabel('Diffusion time $t$')
    fig.tight_layout()

    try:
        if to_supplement:
            fig.savefig(supplement_path)
        fig.savefig(out_path)
    finally:
        # Release the figure even if saving fails, so figures don't accumulate.
        plt.close(fig)
    print(out_path)
def main():
    """Script entry point: render the default jump schedule to an image."""
    get_schedule_jump_test(to_supplement=False)


if __name__ == "__main__":
    main()