import torch


def unet_add_concat_conds(unet, new_channels=4):
    # Widen conv_in so the UNet accepts `new_channels` extra conditioning
    # channels concatenated onto its 4-channel latent input.
    with torch.no_grad():
        new_conv_in = torch.nn.Conv2d(
            4 + new_channels,
            unet.conv_in.out_channels,
            unet.conv_in.kernel_size,
            unet.conv_in.stride,
            unet.conv_in.padding,
        )
        # Match the device/dtype of the existing layer so the copy below also
        # works on models already moved to GPU or cast to half precision.
        new_conv_in = new_conv_in.to(unet.conv_in.weight)
        # Zero-init so the patched UNet initially behaves exactly like the
        # original one, then copy the pretrained weights into the first four
        # input channels and reuse the original bias parameter.
        new_conv_in.weight.zero_()
        new_conv_in.weight[:, :4, :, :].copy_(unet.conv_in.weight)
        new_conv_in.bias = unet.conv_in.bias
        unet.conv_in = new_conv_in

    # Keep a handle to the original forward so the hook can delegate to it.
    unet_original_forward = unet.forward

    def hooked_unet_forward(sample, timestep, encoder_hidden_states, **kwargs):
        # Copy cross_attention_kwargs so the caller's dict is not mutated,
        # then pull the conditioning tensor out before the real forward runs.
        cross_attention_kwargs = dict(kwargs['cross_attention_kwargs'])
        c_concat = cross_attention_kwargs.pop('concat_conds')
        kwargs['cross_attention_kwargs'] = cross_attention_kwargs

        # Repeat the condition along the batch dim (e.g. when the latents are
        # doubled for classifier-free guidance), match the sample's
        # device/dtype, and concatenate along the channel dim.
        c_concat = torch.cat([c_concat] * (sample.shape[0] // c_concat.shape[0]), dim=0).to(sample)
        new_sample = torch.cat([sample, c_concat], dim=1)
        return unet_original_forward(new_sample, timestep, encoder_hidden_states, **kwargs)

    # Route all UNet calls through the hook.
    unet.forward = hooked_unet_forward
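

if __name__ == '__main__':
    # Minimal smoke test: a sketch, not part of the original hook. It assumes
    # diffusers is installed; the tiny UNet2DConditionModel config below is an
    # illustrative stand-in for any UNet with a 4-channel conv_in.
    from diffusers import UNet2DConditionModel

    unet = UNet2DConditionModel(
        sample_size=8,
        in_channels=4,
        out_channels=4,
        layers_per_block=1,
        block_out_channels=(32, 64),
        down_block_types=('CrossAttnDownBlock2D', 'DownBlock2D'),
        up_block_types=('UpBlock2D', 'CrossAttnUpBlock2D'),
        cross_attention_dim=32,
    )
    unet_add_concat_conds(unet, new_channels=4)

    sample = torch.randn(2, 4, 8, 8)  # CFG-style doubled latent batch
    concat = torch.randn(1, 4, 8, 8)  # condition; the hook repeats it to batch 2
    text = torch.randn(2, 77, 32)     # dummy encoder hidden states
    out = unet(sample, 0, text, cross_attention_kwargs={'concat_conds': concat})
    print(out.sample.shape)  # expected: torch.Size([2, 4, 8, 8])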