File size: 6,041 Bytes
24d5d7b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
import dataclasses
from enum import Enum
from typing import Any, Dict, List, Optional, Union

import transformers


@dataclasses.dataclass
class LoraConfigSimplified:
    """
    Minimal set of Low-Rank Adaptation (LoRA) settings.

    The language model and the audio model each get their own instance.
    """

    # Rank of the low-rank matrices.
    r: int = 0
    # LoRA alpha hyper-parameter.
    # NOTE(review): presumably the usual LoRA scaling numerator; confirm where it is consumed.
    lora_alpha: float = 8
    # Names of the modules LoRA is attached to. The factory builds a new list
    # per instance so instances never share (and mutate) one default list.
    target_modules: Optional[List[str]] = dataclasses.field(
        default_factory=lambda: list(("k_proj", "q_proj", "linear_k", "linear_q"))
    )


class LossFunction(str, Enum):
    """
    Identifiers of the supported loss functions.

    Members are also `str` instances, so each one compares equal to its
    short string code.
    """

    # Cross-entropy loss.
    CrossEntropy = "ce"
    # KL-divergence loss.
    KL_Divergence = "kl"


@dataclasses.dataclass
class LossConfig:
    """
    Choice of loss function together with its temperature setting.
    """

    # Loss to use; cross-entropy unless overridden.
    loss_function: LossFunction = LossFunction.CrossEntropy
    # Temperature associated with the KL loss (going by the field name);
    # it is only stored here, not read by this class.
    kl_temperature: float = 2.0

    @property
    def requires_alt_fields(self) -> bool:
        """
        Tell whether the configured loss is KL divergence.

        NOTE(review): the "alt fields" themselves are defined elsewhere;
        presumably they are extra inputs the KL loss needs. Confirm against
        the caller.
        """
        uses_kl = self.loss_function == LossFunction.KL_Divergence
        return uses_kl


class UltravoxConfig(transformers.PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`UltravoxForConditionalGeneration`]. It is used to instantiate an
    Ultravox model according to the specified arguments, defining the model architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        audio_config (`Union[PretrainedConfig, dict]`, *optional*):
            Configuration of the audio backbone, either as a config object (stored as-is) or as a dict of keyword
            arguments. A dict is instantiated through `transformers.CONFIG_MAPPING` using its `"model_type"` entry
            (`"wav2vec2"` when absent). Ignored when `audio_model_id` is given.
        text_config (`Union[PretrainedConfig, dict]`, *optional*):
            Configuration of the text backbone, either as a config object (stored as-is) or as a dict of keyword
            arguments. A dict is instantiated through `transformers.CONFIG_MAPPING` using its `"model_type"` entry
            (`"llama"` when absent). Ignored when `text_model_id` is given.
        audio_model_id (`str`, *optional*):
            Identifier handed to `AutoConfig.from_pretrained` to load the audio config. Takes precedence over
            `audio_config`.
        text_model_id (`str`, *optional*):
            Identifier handed to `AutoConfig.from_pretrained` to load the text config. Takes precedence over
            `text_config`.
        ignore_index (`int`, *optional*, defaults to -100):
            The ignore index for the loss function.
        hidden_size (`int`, *optional*, defaults to 4096):
            Stored unchanged on the config. Presumably the hidden dimension of the multimodal projector; confirm
            against the model code.
        stack_factor (`int`, *optional*, defaults to 8):
            Audio downsampling factor for the multimodal projector.
        norm_init (`float`, *optional*, defaults to 0.4):
            The initialization value for the layer normalization.
        projector_act (`str`, *optional*, defaults to `"swiglu"`):
            The activation function used by the multimodal projector.
        text_model_lora_config (`Union[LoraConfigSimplified, dict]`, *optional*):
            The LoRA configuration for finetuning the text model. Stored as a plain dict; a default
            `LoraConfigSimplified()` is used when omitted.
        audio_model_lora_config (`Union[LoraConfigSimplified, dict]`, *optional*):
            The LoRA configuration for finetuning the audio model. Stored as a plain dict; a default
            `LoraConfigSimplified()` is used when omitted.
        **kwargs:
            Forwarded to `PretrainedConfig.__init__`.

    In addition, `vocab_size` and `initializer_range` are copied from the resolved text config.


    Example:

    ```python
    >>> from transformers import UltravoxForConditionalGeneration, Wav2Vec2Config, UltravoxConfig, LlamaConfig

    >>> # Initializing an audio encoder config
    >>> audio_config = Wav2Vec2Config()

    >>> # Initializing a Llama config
    >>> text_config = LlamaConfig()

    >>> # Initializing a default configuration
    >>> configuration = UltravoxConfig(audio_config, text_config)

    >>> # Initializing a completely untrained model from the configuration
    >>> model = UltravoxForConditionalGeneration(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config

    >>> # Initialize a model from pretrained checkpoints and random projector weights
    >>> config = UltravoxConfig(audio_model_id="facebook/wav2vec2-base-960h", text_model_id="meta-llama/Llama-2-7b-chat-hf")
    ```"""

    model_type = "ultravox"
    is_composition = False

    def __init__(
        self,
        audio_config: Optional[Union[Dict[str, Any], transformers.PretrainedConfig]] = None,
        text_config: Optional[Union[Dict[str, Any], transformers.PretrainedConfig]] = None,
        audio_model_id: Optional[str] = None,
        text_model_id: Optional[str] = None,
        ignore_index: int = -100,
        hidden_size: int = 4096,
        stack_factor: int = 8,
        norm_init: float = 0.4,
        projector_act: str = "swiglu",
        text_model_lora_config: Optional[
            Union[LoraConfigSimplified, Dict[str, Any]]
        ] = None,
        audio_model_lora_config: Optional[
            Union[LoraConfigSimplified, Dict[str, Any]]
        ] = None,
        **kwargs,
    ):
        self.ignore_index = ignore_index

        self.audio_model_id = audio_model_id
        self.text_model_id = text_model_id

        self.hidden_size = hidden_size
        self.stack_factor = stack_factor
        self.norm_init = norm_init
        self.projector_act = projector_act

        # A model id always wins over an explicitly supplied config.
        self.text_config: transformers.PretrainedConfig = self._resolve_sub_config(
            text_model_id, text_config, "llama"
        )
        self.audio_config: transformers.PretrainedConfig = self._resolve_sub_config(
            audio_model_id, audio_config, "wav2vec2"
        )

        # LoRA settings are kept as plain dicts on the config.
        self.text_model_lora_config = self._lora_config_as_dict(
            text_model_lora_config
        )
        self.audio_model_lora_config = self._lora_config_as_dict(
            audio_model_lora_config
        )

        # Mirror these text-backbone values at the top level of this config.
        self.vocab_size = self.text_config.vocab_size

        self.initializer_range = self.text_config.initializer_range

        super().__init__(**kwargs)

    @staticmethod
    def _resolve_sub_config(
        model_id: Optional[str],
        config: Optional[Union[Dict[str, Any], transformers.PretrainedConfig]],
        default_model_type: str,
    ) -> transformers.PretrainedConfig:
        """Return the backbone config from a model id, a config object, or a dict of kwargs."""
        if model_id is not None:
            return transformers.AutoConfig.from_pretrained(model_id)
        if isinstance(config, transformers.PretrainedConfig):
            # Already a config object (as in the class docstring example): use it directly
            # instead of treating it like a dict, which would raise.
            return config
        config = config or {}
        config_cls = transformers.CONFIG_MAPPING[
            config.get("model_type", default_model_type)
        ]
        return config_cls(**config)

    @staticmethod
    def _lora_config_as_dict(
        lora_config: Optional[Union[LoraConfigSimplified, Dict[str, Any]]],
    ) -> Dict[str, Any]:
        """Return the LoRA settings as a dict, falling back to the `LoraConfigSimplified` defaults."""
        if isinstance(lora_config, dict):
            return lora_config
        return dataclasses.asdict(lora_config or LoraConfigSimplified())

    def to_diff_dict(self) -> Dict[str, Any]:
        """
        Return the parent class's diff dict without sub-configs that a model id already identifies.

        `text_config` is dropped when `text_model_id` is set and `audio_config` is dropped when
        `audio_model_id` is set, since `__init__` reloads them from the id.
        """
        diff_dict = super().to_diff_dict()

        # remove text_config and audio_config if text_model_id and audio_model_id are present
        if self.text_model_id is not None:
            diff_dict.pop("text_config", None)
        if self.audio_model_id is not None:
            diff_dict.pop("audio_config", None)

        return diff_dict