from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
from diffusers.models.embeddings import get_1d_rotary_pos_embed
from diffusers.utils import logging
from transformers import (
    T5EncoderModel,
    T5TokenizerFast,
)
from transformers import (
    CLIPTextModel,
    CLIPTextModelWithProjection,
    CLIPTokenizer,
)
import numpy as np
import torch.distributed as dist
import math
import os

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


def get_t5_prompt_embeds(
    tokenizer: T5TokenizerFast,
    text_encoder: T5EncoderModel,
    prompt: Union[str, List[str]] = None,
    num_images_per_prompt: int = 1,
    max_sequence_length: int = 128,
    device: Optional[torch.device] = None,
):
    device = device or text_encoder.device
    prompt = [prompt] if isinstance(prompt, str) else prompt
    batch_size = len(prompt)

    # Tokenize without padding; shorter prompts are zero-padded in embedding
    # space below rather than padded with pad tokens here.
    text_inputs = tokenizer(
        prompt,
        # padding="max_length",
        max_length=max_sequence_length,
        truncation=True,
        add_special_tokens=True,
        return_tensors="pt",
    )
    text_input_ids = text_inputs.input_ids
    untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

    if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
        removed_text = tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
        logger.warning(
            "The following part of your input was truncated because `max_sequence_length` is set to "
            f"{max_sequence_length} tokens: {removed_text}"
        )

    prompt_embeds = text_encoder(text_input_ids.to(device))[0]

    # Concat zeros to max_sequence so every prompt yields a fixed-length
    # (batch, max_sequence_length, dim) tensor.
    b, seq_len, dim = prompt_embeds.shape
    if seq_len < max_sequence_length:
        padding = torch.zeros(
            (b, max_sequence_length - seq_len, dim),
            dtype=prompt_embeds.dtype,
            device=prompt_embeds.device,
        )
        prompt_embeds = torch.cat([prompt_embeds, padding], dim=1)

    # Duplicate the embeddings for each image generated per prompt. The tail of
    # this function was garbled in the source; it is reconstructed here
    # following the standard diffusers pattern.
    seq_len = prompt_embeds.shape[1]
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

    return prompt_embeds


# NOTE: the source excerpt jumps from the function above straight into the body
# of a rotary-position-embedding `forward`, so some intervening code is lost
# (the unused CLIP imports above suggest a CLIP prompt-embedding helper also
# existed there; it is not recoverable from this excerpt). The class wrapper
# below is a reconstruction: the surviving forward body matches `FluxPosEmbed`
# from `diffusers.models.embeddings`, so that class definition is assumed.
class FluxPosEmbed(nn.Module):
    def __init__(self, theta: int, axes_dim: List[int]):
        super().__init__()
        self.theta = theta
        self.axes_dim = axes_dim

    def forward(self, ids: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        n_axes = ids.shape[-1]
        cos_out = []
        sin_out = []
        pos = ids.float()
        # MPS has no float64 support, so fall back to float32 there.
        is_mps = ids.device.type == "mps"
        freqs_dtype = torch.float32 if is_mps else torch.float64
        # Compute 1D rotary frequencies per position axis, then concatenate
        # them along the feature dimension.
        for i in range(n_axes):
            cos, sin = get_1d_rotary_pos_embed(
                self.axes_dim[i],
                pos[:, i],
                theta=self.theta,
                repeat_interleave_real=True,
                use_real=True,
                freqs_dtype=freqs_dtype,
            )
            cos_out.append(cos)
            sin_out.append(sin)
        freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
        freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
        return freqs_cos, freqs_sin
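

if __name__ == "__main__":
    # Minimal usage sketch (added for illustration). The T5 checkpoint name and
    # the theta/axes_dim values are assumptions in the style of FLUX configs,
    # not values taken from this file.
    tokenizer = T5TokenizerFast.from_pretrained("google/t5-v1_1-xxl")
    text_encoder = T5EncoderModel.from_pretrained("google/t5-v1_1-xxl")

    # Fixed-length (batch * num_images_per_prompt, 128, hidden_dim) embeddings,
    # zero-padded along the sequence axis.
    prompt_embeds = get_t5_prompt_embeds(
        tokenizer,
        text_encoder,
        prompt="a photo of an astronaut riding a horse",
        num_images_per_prompt=1,
        max_sequence_length=128,
    )
    print(prompt_embeds.shape)

    # Rotary frequencies for 3-axis position ids (e.g. temporal/text axis,
    # height, width); sum(axes_dim) should equal the per-head attention dim.
    pos_embed = FluxPosEmbed(theta=10000, axes_dim=[16, 56, 56])
    ids = torch.zeros(64, 3)  # one (axis0, y, x) id triple per latent token
    freqs_cos, freqs_sin = pos_embed(ids)
    print(freqs_cos.shape, freqs_sin.shape)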