import random
import clip
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch import Tensor
from typing import List, Dict, Optional, Union

from mogen.models.utils.misc import zero_module

from ..builder import SUBMODULES, build_attention
from .motion_transformer import MotionTransformer


class FFN(nn.Module):
    """
    Feed-forward network (FFN) used in the transformer layers. 
    It consists of two linear layers with a GELU activation in between.

    Args:
        latent_dim (int): Input dimension of the FFN.
        ffn_dim (int): Hidden dimension of the FFN.
        dropout (float): Dropout rate applied after activation.
    """

    def __init__(self, latent_dim: int, ffn_dim: int, dropout: float):
        super().__init__()
        self.linear1 = nn.Linear(latent_dim, ffn_dim)
        self.linear2 = zero_module(nn.Linear(ffn_dim, latent_dim))
        self.activation = nn.GELU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: Tensor, **kwargs) -> Tensor:
        """
        Forward pass for the FFN.

        Args:
            x (Tensor): Input tensor of shape (B, T, D).

        Returns:
            Tensor: Output tensor after the FFN, of shape (B, T, D).
        """
        y = self.linear2(self.dropout(self.activation(self.linear1(x))))
        y = x + y
        return y
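
    # Example (illustrative sketch, not part of the original module): FFN is a
    # residual block, so the output shape matches the input shape, and because
    # linear2 is wrapped in zero_module the block initially acts as identity.
    # The sizes below are arbitrary placeholders.
    #
    #   ffn = FFN(latent_dim=512, ffn_dim=1024, dropout=0.1)
    #   x = torch.randn(2, 196, 512)   # (B, T, D)
    #   y = ffn(x)                     # same shape: (2, 196, 512)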


class EncoderLayer(nn.Module):
    """
    Encoder layer consisting of self-attention and feed-forward network.

    Args:
        sa_block_cfg (Optional[dict]): Configuration for the self-attention block.
        ca_block_cfg (Optional[dict]): Configuration for the cross-attention block. Accepted for interface compatibility but not used by this layer.
        ffn_cfg (dict): Configuration for the feed-forward network.
    """

    def __init__(self,
                 sa_block_cfg: Optional[dict] = None,
                 ca_block_cfg: Optional[dict] = None,
                 ffn_cfg: Optional[dict] = None):
        super().__init__()
        self.sa_block = build_attention(sa_block_cfg)
        self.ffn = FFN(**ffn_cfg)

    def forward(self, **kwargs) -> Tensor:
        """
        Forward pass for the encoder layer.

        Args:
            kwargs: Dictionary containing the input tensor (x) and other related parameters.

        Returns:
            Tensor: Output tensor after the encoder layer.
        """
        if self.sa_block is not None:
            x = self.sa_block(**kwargs)
            kwargs.update({'x': x})
        if self.ffn is not None:
            x = self.ffn(**kwargs)
        return x
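
    # Example (illustrative sketch, not part of the original module): the layer
    # passes all keyword arguments first to the self-attention block and then
    # to the FFN, so inputs are supplied by name. The `sa_block_cfg` below is a
    # hypothetical config; the actual type name and keys depend on the
    # attention modules registered with `build_attention`.
    #
    #   layer = EncoderLayer(
    #       sa_block_cfg=dict(type='EfficientSelfAttention',
    #                         latent_dim=512, num_heads=8, dropout=0.1),
    #       ffn_cfg=dict(latent_dim=512, ffn_dim=1024, dropout=0.1))
    #   y = layer(x=x, src_mask=src_mask.unsqueeze(-1))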


class RetrievalDatabase(nn.Module):
    """
    Retrieval database for retrieving motions and text features based on given captions.

    Args:
        num_retrieval (int): Number of retrievals for each caption.
        topk (int): Number of top results to consider.
        retrieval_file (str): Path to the retrieval file containing text, motion, and length data.
        latent_dim (Optional[int]): Dimension of the latent space.
        output_dim (Optional[int]): Output dimension of the retrieved features.
        num_layers (Optional[int]): Number of layers in the text encoder.
        num_motion_layers (Optional[int]): Number of layers in the motion encoder.
        kinematic_coef (Optional[float]): Coefficient for scaling kinematic similarity.
        max_seq_len (Optional[int]): Maximum sequence length.
        num_heads (Optional[int]): Number of attention heads.
        ff_size (Optional[int]): Feed-forward size for the transformer layers.
        stride (Optional[int]): Stride for downsampling motion data.
        sa_block_cfg (Optional[dict]): Configuration for the self-attention block.
        ffn_cfg (Optional[dict]): Configuration for the feed-forward network.
        dropout (Optional[float]): Dropout rate.
    """

    def __init__(self,
                 num_retrieval: int,
                 topk: int,
                 retrieval_file: str,
                 latent_dim: Optional[int] = 512,
                 output_dim: Optional[int] = 512,
                 num_layers: Optional[int] = 2,
                 num_motion_layers: Optional[int] = 4,
                 kinematic_coef: Optional[float] = 0.1,
                 max_seq_len: Optional[int] = 196,
                 num_heads: Optional[int] = 8,
                 ff_size: Optional[int] = 1024,
                 stride: Optional[int] = 4,
                 sa_block_cfg: Optional[dict] = None,
                 ffn_cfg: Optional[dict] = None,
                 dropout: Optional[float] = 0):
        super().__init__()
        self.num_retrieval = num_retrieval
        self.topk = topk
        self.latent_dim = latent_dim
        self.stride = stride
        self.kinematic_coef = kinematic_coef
        self.num_layers = num_layers
        self.num_motion_layers = num_motion_layers
        self.max_seq_len = max_seq_len

        # Load data from the retrieval file
        data = np.load(retrieval_file)
        self.text_features = torch.Tensor(data['text_features'])
        self.captions = data['captions']
        self.motions = data['motions']
        self.m_lengths = data['m_lengths']
        self.clip_seq_features = data['clip_seq_features']
        self.train_indexes = data.get('train_indexes', None)
        self.test_indexes = data.get('test_indexes', None)

        self.output_dim = output_dim
        self.motion_proj = nn.Linear(self.motions.shape[-1], self.latent_dim)
        self.motion_pos_embedding = nn.Parameter(
            torch.randn(max_seq_len, self.latent_dim))
        self.motion_encoder_blocks = nn.ModuleList()

        # Build motion encoder blocks
        for i in range(num_motion_layers):
            self.motion_encoder_blocks.append(
                EncoderLayer(sa_block_cfg=sa_block_cfg, ffn_cfg=ffn_cfg))

        # Transformer for encoding text
        TransEncoderLayer = nn.TransformerEncoderLayer(d_model=self.latent_dim,
                                                       nhead=num_heads,
                                                       dim_feedforward=ff_size,
                                                       dropout=dropout,
                                                       activation="gelu")
        self.text_encoder = nn.TransformerEncoder(TransEncoderLayer,
                                                  num_layers=num_layers)
        self.results = {}

    def extract_text_feature(self, text: str, clip_model: nn.Module, device: torch.device) -> Tensor:
        """
        Extract text features from CLIP model.

        Args:
            text (str): Input text caption.
            clip_model (nn.Module): CLIP model for encoding the text.
            device (torch.device): Device for computation.

        Returns:
            Tensor: Extracted text features of shape (1, 512).
        """
        text = clip.tokenize([text], truncate=True).to(device)
        with torch.no_grad():
            text_features = clip_model.encode_text(text)
        return text_features

    def encode_text(self, text: List[str], clip_model: nn.Module, device: torch.device) -> Tensor:
        """
        Encode text with the CLIP text encoder, keeping per-token features.

        Args:
            text (List[str]): List of input text captions.
            clip_model (nn.Module): CLIP model used for encoding the text.
            device (torch.device): Device for computation.

        Returns:
            Tensor: Encoded text features of shape (B, T, D).
        """
        with torch.no_grad():
            text = clip.tokenize(text, truncate=True).to(device)
            x = clip_model.token_embedding(text).type(clip_model.dtype)

            x = x + clip_model.positional_embedding.type(clip_model.dtype)
            x = x.permute(1, 0, 2)  # NLD -> LND
            x = clip_model.transformer(x)
            x = clip_model.ln_final(x).type(clip_model.dtype)

        # B, T, D
        xf_out = x.permute(1, 0, 2)
        return xf_out

    def retrieve(self, caption: str, length: int, clip_model: nn.Module, device: torch.device, idx: Optional[int] = None) -> List[int]:
        """
        Retrieve motions and text features based on a given caption.

        Args:
            caption (str): Input text caption.
            length (int): Length of the corresponding motion sequence.
            clip_model (nn.Module): CLIP model for encoding the text.
            device (torch.device): Device for computation.
            idx (Optional[int]): Index for retrieval (if provided).

        Returns:
            List[int]: List of indexes for the retrieved motions.
        """
        value = hash(caption)
        if value in self.results:
            return self.results[value]
        text_feature = self.extract_text_feature(caption, clip_model, device)

        m_lengths = torch.LongTensor(self.m_lengths).to(device)
        # Relative length difference, normalized by the longer of the candidate
        # length and the query length.
        rel_length = torch.abs(m_lengths - length) / \
            torch.clamp(m_lengths, min=length)
        semantic_score = F.cosine_similarity(self.text_features.to(device),
                                             text_feature)
        kinematic_score = torch.exp(-rel_length * self.kinematic_coef)
        score = semantic_score * kinematic_score
        indexes = torch.argsort(score, descending=True)
        data = []
        cnt = 0
        for retr_idx in indexes:
            m_length = self.m_lengths[retr_idx]
            # During training, candidates whose length equals the query length
            # are skipped.
            if not self.training or m_length != length:
                cnt += 1
                data.append(retr_idx.item())
                if cnt == self.num_retrieval:
                    self.results[value] = data
                    return data
        raise RuntimeError(
            f'Failed to retrieve {self.num_retrieval} motions for the caption.')
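
    # Worked example of the retrieval score (illustrative numbers): with
    # kinematic_coef = 0.1, a query length of 120 and a candidate of length
    # 180, the relative length difference is |180 - 120| / max(180, 120)
    # = 60 / 180 ≈ 0.33, so kinematic_score = exp(-0.1 * 0.33) ≈ 0.967.
    # The ranking score is the CLIP cosine similarity multiplied by this
    # kinematic term, so candidates close in both semantics and length rank
    # highest.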

    def generate_src_mask(self, T: int, length: List[int]) -> Tensor:
        """
        Generate source mask for the motion sequences based on the motion lengths.

        Args:
            T (int): Maximum sequence length.
            length (List[int]): List of motion lengths for each sample.

        Returns:
            Tensor: A binary mask tensor of shape (B, T), where `B` is the batch size, 
            and `T` is the maximum sequence length. Mask values are 1 for valid positions 
            and 0 for padded positions.
        """
        B = len(length)
        src_mask = torch.ones(B, T)
        for i in range(B):
            src_mask[i, length[i]:] = 0
        return src_mask
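
    # Example (illustrative): generate_src_mask(5, [3, 5]) returns
    #   [[1., 1., 1., 0., 0.],
    #    [1., 1., 1., 1., 1.]]
    # i.e. positions at index length[i] and beyond are masked out for sample i.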

    def forward(self, captions: List[str], lengths: List[int], clip_model: nn.Module, device: torch.device, idx: Optional[List[int]] = None) -> Dict[str, Tensor]:
        """
        Forward pass for retrieving motion sequences and text features.

        Args:
            captions (List[str]): List of input text captions.
            lengths (List[int]): List of corresponding motion lengths.
            clip_model (nn.Module): CLIP model for encoding the text.
            device (torch.device): Device for computation.
            idx (Optional[List[int]]): Optional list of indices for retrieval.

        Returns:
            Dict[str, Tensor]: Dictionary containing retrieved text and motion features.
            - re_text: Retrieved text features of shape (B, num_retrieval, 1, D)
              (only the final token of each text sequence is kept).
            - re_motion: Retrieved motion features of shape
              (B, num_retrieval, T // stride, D).
            - re_mask: Source mask for the retrieved motion of shape
              (B, num_retrieval, T // stride).
            - raw_motion: Raw motion features of shape (B, T, motion_dim).
            - raw_motion_length: Motion sequence lengths (before any stride).
            - raw_motion_mask: Raw binary mask for valid motion positions of shape (B, T).
        """
        B = len(captions)
        all_indexes = []
        for b_ix in range(B):
            length = int(lengths[b_ix])
            if idx is None:
                batch_indexes = self.retrieve(captions[b_ix], length, clip_model, device)
            else:
                batch_indexes = self.retrieve(captions[b_ix], length, clip_model, device, idx[b_ix])
            all_indexes.extend(batch_indexes)

        all_indexes = np.array(all_indexes)
        all_motions = torch.Tensor(self.motions[all_indexes]).to(device)
        all_m_lengths = torch.Tensor(self.m_lengths[all_indexes]).long()

        # Generate masks and positional encodings
        T = all_motions.shape[1]
        src_mask = self.generate_src_mask(T, all_m_lengths).to(device)
        raw_src_mask = src_mask.clone()
        re_motion = self.motion_proj(all_motions) + self.motion_pos_embedding.unsqueeze(0)

        for module in self.motion_encoder_blocks:
            re_motion = module(x=re_motion, src_mask=src_mask.unsqueeze(-1))

        re_motion = re_motion.view(B, self.num_retrieval, T, -1).contiguous()
        re_motion = re_motion[:, :, ::self.stride, :].contiguous()  # Apply stride
        src_mask = src_mask[:, ::self.stride].contiguous()
        src_mask = src_mask.view(B, self.num_retrieval, -1).contiguous()

        # Process text sequences
        T = 77  # CLIP's max token length
        all_text_seq_features = torch.Tensor(self.clip_seq_features[all_indexes]).to(device)
        all_text_seq_features = all_text_seq_features.permute(1, 0, 2)
        re_text = self.text_encoder(all_text_seq_features)
        re_text = re_text.permute(1, 0, 2)
        re_text = re_text.view(B, self.num_retrieval, T, -1).contiguous()
        re_text = re_text[:, :, -1:, :].contiguous()  # Use the last token only for each sequence

        re_dict = {
            're_text': re_text,
            're_motion': re_motion,
            're_mask': src_mask,
            'raw_motion': all_motions,
            'raw_motion_length': all_m_lengths,
            'raw_motion_mask': raw_src_mask
        }
        return re_dict
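
    # Example (illustrative sketch, not part of the original module): typical
    # use inside the transformer, where `clip_model` is a loaded CLIP model and
    # 'data/retrieval.npz' is a hypothetical precomputed database file holding
    # 'text_features', 'captions', 'motions', 'm_lengths' and
    # 'clip_seq_features' arrays. The sa_block_cfg/ffn_cfg dicts are
    # hypothetical and depend on the registered attention modules.
    #
    #   database = RetrievalDatabase(
    #       num_retrieval=2, topk=2, retrieval_file='data/retrieval.npz',
    #       sa_block_cfg=dict(type='EfficientSelfAttention',
    #                         latent_dim=512, num_heads=8, dropout=0.1),
    #       ffn_cfg=dict(latent_dim=512, ffn_dim=1024, dropout=0.1))
    #   re_dict = database(captions=['a person walks forward'], lengths=[120],
    #                      clip_model=clip_model, device=torch.device('cuda'))
    #   re_dict['re_motion'].shape  # (1, num_retrieval, max_seq_len // stride, latent_dim)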


@SUBMODULES.register_module()
class ReMoDiffuseTransformer(MotionTransformer):
    """
    Transformer model for motion retrieval and diffusion.

    Args:
        retrieval_cfg (dict): Configuration for the retrieval database.
        scale_func_cfg (dict): Configuration for scaling functions.
        kwargs: Additional arguments for the base MotionTransformer.
    """

    def __init__(self, retrieval_cfg: dict, scale_func_cfg: dict, **kwargs):
        super().__init__(**kwargs)
        self.database = RetrievalDatabase(**retrieval_cfg)
        self.scale_func_cfg = scale_func_cfg

    def scale_func(self, timestep: int) -> Dict[str, float]:
        """
        Scale function for adjusting the guidance between text and retrieval.

        Args:
            timestep (int): Current diffusion timestep.

        Returns:
            Dict[str, float]: Scaling coefficients for different guidance types.
            - both_coef: Coefficient for both text and retrieval guidance.
            - text_coef: Coefficient for text-only guidance.
            - retr_coef: Coefficient for retrieval-only guidance.
            - none_coef: Coefficient for no guidance.
        """
        coarse_scale = self.scale_func_cfg['coarse_scale']
        w = (1 - (1000 - timestep) / 1000) * coarse_scale + 1
        if timestep > 100:
            if random.randint(0, 1) == 0:
                output = {
                    'both_coef': w,
                    'text_coef': 0,
                    'retr_coef': 1 - w,
                    'none_coef': 0
                }
            else:
                output = {
                    'both_coef': 0,
                    'text_coef': w,
                    'retr_coef': 0,
                    'none_coef': 1 - w
                }
        else:
            both_coef = self.scale_func_cfg['both_coef']
            text_coef = self.scale_func_cfg['text_coef']
            retr_coef = self.scale_func_cfg['retr_coef']
            none_coef = 1 - both_coef - text_coef - retr_coef
            output = {
                'both_coef': both_coef,
                'text_coef': text_coef,
                'retr_coef': retr_coef,
                'none_coef': none_coef
            }
        return output
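
    # Worked example (illustrative): with coarse_scale = 6.5 and timestep = 900,
    # w = (1 - (1000 - 900) / 1000) * 6.5 + 1 = 0.9 * 6.5 + 1 = 6.85. For
    # timesteps above 100 the method randomly alternates between
    # {both: w, retr: 1 - w} and {text: w, none: 1 - w}, so the coefficients
    # always sum to 1; at or below timestep 100 the fixed coefficients from
    # scale_func_cfg are used, with none_coef chosen to keep the sum at 1.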

    def get_precompute_condition(self, 
                                 text: Optional[List[str]] = None,
                                 motion_length: Optional[Tensor] = None,
                                 xf_out: Optional[Tensor] = None,
                                 re_dict: Optional[Dict] = None,
                                 device: Optional[torch.device] = None,
                                 sample_idx: Optional[Tensor] = None,
                                 clip_feat: Optional[Tensor] = None,
                                 **kwargs) -> Dict[str, Union[Tensor, Dict]]:
        """
        Precompute conditions for both text and retrieval-guided diffusion.

        Args:
            text (Optional[List[str]]): Input text captions for guidance.
            motion_length (Optional[Tensor]): Lengths of the motion sequences.
            xf_out (Optional[Tensor]): Encoded text feature (if precomputed).
            re_dict (Optional[Dict]): Dictionary of retrieval results (if precomputed).
            device (Optional[torch.device]): Device to perform computation on.
            sample_idx (Optional[Tensor]): Sample indices for retrieval.
            clip_feat (Optional[Tensor]): Precomputed CLIP features (if used).

        Returns:
            Dict[str, Union[Tensor, Dict]]: Dictionary containing encoded features and retrieval results.
        """
        if xf_out is None:
            xf_out = self.encode_text(text, clip_feat, device)
        output = {'xf_out': xf_out}
        if re_dict is None:
            re_dict = self.database(text, motion_length, self.clip, device, idx=sample_idx)
        output['re_dict'] = re_dict
        return output

    def post_process(self, motion: Tensor) -> Tensor:
        """
        Post-process the generated motion by normalizing or un-normalizing it.

        Args:
            motion (Tensor): Generated motion data.

        Returns:
            Tensor: Post-processed motion data.
        """
        if self.post_process_cfg is not None:
            if self.post_process_cfg.get("unnormalized_infer", False):
                mean = torch.from_numpy(np.load(self.post_process_cfg['mean_path'])).type_as(motion)
                std = torch.from_numpy(np.load(self.post_process_cfg['std_path'])).type_as(motion)
                motion = motion * std + mean
        return motion
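
    # Example (illustrative, with hypothetical paths): a config such as
    #   post_process_cfg = dict(unnormalized_infer=True,
    #                           mean_path='data/mean.npy',
    #                           std_path='data/std.npy')
    # makes post_process map generated motions back to the original data scale
    # via motion * std + mean.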

    def forward_train(self,
                      h: Tensor,
                      src_mask: Tensor,
                      emb: Tensor,
                      xf_out: Optional[Tensor] = None,
                      re_dict: Optional[Dict] = None,
                      **kwargs) -> Tensor:
        """
        Forward training pass for motion retrieval and diffusion model.

        Args:
            h (Tensor): Input motion features of shape (B, T, D).
            src_mask (Tensor): Mask for the motion data of shape (B, T, 1).
            emb (Tensor): Embedding tensor for timesteps.
            xf_out (Optional[Tensor]): Precomputed text features.
            re_dict (Optional[Dict]): Dictionary of retrieval features.

        Returns:
            Tensor: Output motion data of shape (B, T, D).
        """
        B, T = h.shape[0], h.shape[1]
        cond_type = torch.randint(0, 100, size=(B, 1, 1)).to(h.device)
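        # Each sample draws a random integer in [0, 100); the downstream
        # attention blocks are expected to read the ones digit as the text
        # switch and the tens digit as the retrieval switch (cf. the 99/1/10/0
        # constants in forward_test), which randomly drops either condition
        # during training.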
        for module in self.temporal_decoder_blocks:
            h = module(x=h,
                       xf=xf_out,
                       emb=emb,
                       src_mask=src_mask,
                       cond_type=cond_type,
                       re_dict=re_dict)

        output = self.out(h).view(B, T, -1).contiguous()
        return output

    def forward_test(self,
                     h: Tensor,
                     src_mask: Tensor,
                     emb: Tensor,
                     xf_out: Optional[Tensor] = None,
                     re_dict: Optional[Dict] = None,
                     timesteps: Optional[Tensor] = None,
                     **kwargs) -> Tensor:
        """
        Forward testing pass for motion retrieval and diffusion model. This method handles
        multiple conditional types such as both text and retrieval-based guidance.

        Args:
            h (Tensor): Input motion features of shape (B, T, D).
            src_mask (Tensor): Mask for the motion data of shape (B, T, 1).
            emb (Tensor): Embedding tensor for timesteps.
            xf_out (Optional[Tensor]): Precomputed text features.
            re_dict (Optional[Dict]): Dictionary of retrieval features.
            timesteps (Optional[Tensor]): Tensor containing current timesteps in the diffusion process.

        Returns:
            Tensor: Output motion data after applying multiple conditional types, of shape (B, T, D).
        """
        B, T = h.shape[0], h.shape[1]
        
        # Define condition types for different guidance types
        both_cond_type = torch.zeros(B, 1, 1).to(h.device) + 99
        text_cond_type = torch.zeros(B, 1, 1).to(h.device) + 1
        retr_cond_type = torch.zeros(B, 1, 1).to(h.device) + 10
        none_cond_type = torch.zeros(B, 1, 1).to(h.device)

        # Concatenate all conditional types and repeat inputs for different guidance modes
        all_cond_type = torch.cat((both_cond_type, text_cond_type, retr_cond_type, none_cond_type), dim=0)
        h = h.repeat(4, 1, 1)
        xf_out = xf_out.repeat(4, 1, 1)
        emb = emb.repeat(4, 1)
        src_mask = src_mask.repeat(4, 1, 1)

        # Repeat retrieval features if necessary
        if re_dict['re_motion'].shape[0] != h.shape[0]:
            re_dict['re_motion'] = re_dict['re_motion'].repeat(4, 1, 1, 1)
            re_dict['re_text'] = re_dict['re_text'].repeat(4, 1, 1, 1)
            re_dict['re_mask'] = re_dict['re_mask'].repeat(4, 1, 1)

        # Pass through the temporal decoder blocks
        for module in self.temporal_decoder_blocks:
            h = module(x=h, xf=xf_out, emb=emb, src_mask=src_mask, cond_type=all_cond_type, re_dict=re_dict)

        # Retrieve output features and handle different guidance coefficients
        out = self.out(h).view(4 * B, T, -1).contiguous()
        out_both = out[:B].contiguous()
        out_text = out[B:2 * B].contiguous()
        out_retr = out[2 * B:3 * B].contiguous()
        out_none = out[3 * B:].contiguous()

        # Apply scaling coefficients based on the timestep
        coef_cfg = self.scale_func(int(timesteps[0]))
        both_coef = coef_cfg['both_coef']
        text_coef = coef_cfg['text_coef']
        retr_coef = coef_cfg['retr_coef']
        none_coef = coef_cfg['none_coef']

        # Compute the final output by blending the different guidance outputs
        output = out_both * both_coef
        output += out_text * text_coef
        output += out_retr * retr_coef
        output += out_none * none_coef

        return output
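
    # Note: the cond_type constants above (99 / 1 / 10 / 0) select the guidance
    # mode consumed by the attention blocks defined elsewhere in mogen: both
    # text and retrieval conditioning, text only, retrieval only, or neither.
    # The four decoder passes are then blended with the scale_func
    # coefficients:
    #   output = both_coef * out_both + text_coef * out_text
    #            + retr_coef * out_retr + none_coef * out_none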