|
import torch |
|
from typing import Union, List, Optional, Dict, Any, Tuple |
|
from diffusers.models.unet_2d_condition import UNet2DConditionOutput |
|
|
|
from library.original_unet import SampleOutput |
|
|
|
|
|
def unet_forward_XTI(
    self,
    sample: torch.FloatTensor,
    timestep: Union[torch.Tensor, float, int],
    encoder_hidden_states: torch.Tensor,
    class_labels: Optional[torch.Tensor] = None,
    return_dict: bool = True,
) -> Union[Dict, Tuple]:
    r"""
    UNet forward pass that hands a different slice of `encoder_hidden_states` to each
    cross-attention block instead of one shared conditioning tensor.

    Indexing scheme used on `encoder_hidden_states` (first axis):
        * every down block with cross attention consumes the next 2 entries, starting at 0;
        * the mid block always receives entry 6;
        * every up block with cross attention consumes the next 3 entries, starting at 7.

    Args:
        sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
        timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
        encoder_hidden_states: indexable collection of per-layer encoder hidden states,
            sliced as described above.
            NOTE(review): each entry is presumably (batch, sequence_length, feature_dim) - confirm with caller.
        class_labels: accepted for signature compatibility; never read by this function.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a `SampleOutput` instead of a plain tuple.

    Returns:
        `SampleOutput` or `tuple`:
        `SampleOutput` if `return_dict` is True, otherwise a `tuple` whose only element is the sample tensor.
    """
    # If height or width is not a multiple of 2**num_upsamplers, the up blocks are told
    # explicitly which spatial size to produce (taken from the matching skip connection).
    overall_up_factor = 2**self.num_upsamplers
    needs_explicit_size = any(dim % overall_up_factor != 0 for dim in sample.shape[-2:])
    upsample_size = None

    # 1. time embedding
    normalized_timesteps = self.handle_unusual_timesteps(sample, timestep)
    projected = self.time_proj(normalized_timesteps)
    # Cast to the model dtype before feeding the embedding layer.
    projected = projected.to(dtype=self.dtype)
    emb = self.time_embedding(projected)

    # 2. input convolution
    sample = self.conv_in(sample)

    # 3. down path: collect skip connections, 2 conditioning entries per cross-attention block
    skip_connections = (sample,)
    layer_cursor = 0
    for down_block in self.down_blocks:
        if not down_block.has_cross_attention:
            sample, block_skips = down_block(hidden_states=sample, temb=emb)
        else:
            layer_slice = encoder_hidden_states[layer_cursor : layer_cursor + 2]
            sample, block_skips = down_block(
                hidden_states=sample,
                temb=emb,
                encoder_hidden_states=layer_slice,
            )
            layer_cursor += 2

        skip_connections = skip_connections + block_skips

    # 4. mid block: fixed conditioning entry
    sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states[6])

    # 5. up path: 3 conditioning entries per cross-attention block, starting after the mid entry
    layer_cursor = 7
    final_index = len(self.up_blocks) - 1
    for block_index, up_block in enumerate(self.up_blocks):
        # Pop as many skip connections as this block has resnets.
        n_skips = len(up_block.resnets)
        block_skips = skip_connections[-n_skips:]
        skip_connections = skip_connections[:-n_skips]

        # The final block keeps whatever size was set last; it is not recomputed there.
        if needs_explicit_size and block_index != final_index:
            upsample_size = skip_connections[-1].shape[2:]

        call_kwargs = {
            "hidden_states": sample,
            "temb": emb,
            "res_hidden_states_tuple": block_skips,
            "upsample_size": upsample_size,
        }
        if up_block.has_cross_attention:
            call_kwargs["encoder_hidden_states"] = encoder_hidden_states[layer_cursor : layer_cursor + 3]
            layer_cursor += 3

        sample = up_block(**call_kwargs)

    # 6. output head
    sample = self.conv_out(self.conv_act(self.conv_norm_out(sample)))

    if return_dict:
        return SampleOutput(sample=sample)

    return (sample,)
|
|
|
|
|
def downblock_forward_XTI(
    self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None
):
    """Run a cross-attention down block, giving each attention layer its own conditioning.

    Each (resnet, attention) pair is applied in order; the attention layer at position
    ``i`` receives ``encoder_hidden_states[i]``. The hidden state after every pair, and
    after the downsamplers (if any), is collected as a skip connection.

    Args:
        hidden_states: input passed to the first resnet.
        temb: time embedding forwarded to every resnet.
        encoder_hidden_states: indexable collection with one entry per attention layer.
            It is indexed directly, so it must be supplied whenever the block has
            resnet/attention pairs despite the ``None`` default.
        attention_mask: accepted for signature compatibility; not used here.
        cross_attention_kwargs: accepted for signature compatibility; not used here.

    Returns:
        A ``(hidden_states, output_states)`` tuple, where ``output_states`` is the tuple
        of collected skip connections.
    """

    def create_custom_forward(module, return_dict=None):
        # Wrap ``module`` so it can be called with positional inputs only, optionally
        # pinning ``return_dict``. Defined once instead of once per loop iteration.
        def custom_forward(*inputs):
            if return_dict is not None:
                return module(*inputs, return_dict=return_dict)
            return module(*inputs)

        return custom_forward

    use_checkpointing = self.training and self.gradient_checkpointing
    if use_checkpointing:
        # Import the submodule explicitly: a bare `import torch` does not guarantee that
        # `torch.utils.checkpoint` has been loaded.
        from torch.utils.checkpoint import checkpoint

    output_states = ()

    for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)):
        if use_checkpointing:
            hidden_states = checkpoint(create_custom_forward(resnet), hidden_states, temb)
            # With return_dict=False the attention module's result is indexed with [0].
            hidden_states = checkpoint(
                create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states[i]
            )[0]
        else:
            hidden_states = resnet(hidden_states, temb)
            hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states[i]).sample

        output_states += (hidden_states,)

    if self.downsamplers is not None:
        for downsampler in self.downsamplers:
            hidden_states = downsampler(hidden_states)

        # Only one skip connection is recorded for the whole downsampler stack.
        output_states += (hidden_states,)

    return hidden_states, output_states
|
|
|
|
|
def upblock_forward_XTI(
    self,
    hidden_states,
    res_hidden_states_tuple,
    temb=None,
    encoder_hidden_states=None,
    upsample_size=None,
):
    """Run a cross-attention up block, giving each attention layer its own conditioning.

    For each (resnet, attention) pair, the last remaining skip connection is popped from
    ``res_hidden_states_tuple`` and concatenated to the hidden state along ``dim=1``
    before the resnet runs. The attention layer at position ``i`` receives
    ``encoder_hidden_states[i]``.

    Args:
        hidden_states: input tensor for the block.
        res_hidden_states_tuple: skip connections, consumed from the end.
        temb: time embedding forwarded to every resnet.
        encoder_hidden_states: indexable collection with one entry per attention layer.
            It is indexed directly, so it must be supplied whenever the block has
            resnet/attention pairs despite the ``None`` default.
        upsample_size: forwarded unchanged as the second argument of every upsampler.

    Returns:
        The hidden state after all pairs and, if present, all upsamplers.
    """

    def create_custom_forward(module, return_dict=None):
        # Wrap ``module`` so it can be called with positional inputs only, optionally
        # pinning ``return_dict``. Defined once instead of once per loop iteration.
        def custom_forward(*inputs):
            if return_dict is not None:
                return module(*inputs, return_dict=return_dict)
            return module(*inputs)

        return custom_forward

    use_checkpointing = self.training and self.gradient_checkpointing
    if use_checkpointing:
        # Import the submodule explicitly: a bare `import torch` does not guarantee that
        # `torch.utils.checkpoint` has been loaded.
        from torch.utils.checkpoint import checkpoint

    for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)):
        # Pop the most recent skip connection and join it on dim=1.
        res_hidden_states = res_hidden_states_tuple[-1]
        res_hidden_states_tuple = res_hidden_states_tuple[:-1]
        hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

        if use_checkpointing:
            hidden_states = checkpoint(create_custom_forward(resnet), hidden_states, temb)
            # With return_dict=False the attention module's result is indexed with [0].
            hidden_states = checkpoint(
                create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states[i]
            )[0]
        else:
            hidden_states = resnet(hidden_states, temb)
            hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states[i]).sample

    if self.upsamplers is not None:
        for upsampler in self.upsamplers:
            hidden_states = upsampler(hidden_states, upsample_size)

    return hidden_states
|
|