jadechoghari committed
Commit 4e05b75
1 Parent(s): 6698379

Create modeling.py

Files changed (1)
  1. modeling.py +165 -0
modeling.py ADDED
@@ -0,0 +1,165 @@
+ # Copyright 2023 Haotian Liu
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+
+ from typing import List, Optional, Tuple, Union
+
+ import torch
+ import torch.nn as nn
+
+ from transformers import AutoConfig, AutoModelForCausalLM, \
+     LlamaConfig, LlamaModel, LlamaForCausalLM
+
+ from transformers.modeling_outputs import CausalLMOutputWithPast
+ from transformers.generation.utils import GenerateOutput
+
+ # TODO: add this path to the HF repo
+ from .ferret_arch import FerretMetaModel, FerretMetaForCausalLM
+
+
+ class FerretConfig(LlamaConfig):
+     model_type = "ferret_llama"
+
+
+ class FerretLlamaModel(FerretMetaModel, LlamaModel):
+     config_class = FerretConfig
+
+     def __init__(self, config: LlamaConfig):
+         super(FerretLlamaModel, self).__init__(config)
+
+
+ class FerretLlamaForCausalLM(LlamaForCausalLM, FerretMetaForCausalLM):
+     config_class = FerretConfig
+
+     def __init__(self, config):
+         # Skip LlamaForCausalLM.__init__ so the backbone is built as a
+         # FerretLlamaModel (which carries the multimodal modules) rather
+         # than a plain LlamaModel.
+         super(LlamaForCausalLM, self).__init__(config)
+         self.model = FerretLlamaModel(config)
+         self.pretraining_tp = config.pretraining_tp
+         self.vocab_size = config.vocab_size
+         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def get_model(self):
+         return self.model
+
+     def forward(
+         self,
+         input_ids: torch.LongTensor = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[List[torch.FloatTensor]] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         labels: Optional[torch.LongTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         images: Optional[torch.FloatTensor] = None,
+         image_sizes: Optional[List[List[int]]] = None,
+         region_masks: Optional[List[torch.Tensor]] = None,
+         return_dict: Optional[bool] = None,
+         cache_position: Optional[torch.LongTensor] = None,
+     ) -> Union[Tuple, CausalLMOutputWithPast]:
+
+         if inputs_embeds is None:
+             # Splice image features (and any referred-region masks) into the
+             # token stream, rewriting ids/masks/labels to match.
+             (
+                 input_ids,
+                 position_ids,
+                 attention_mask,
+                 past_key_values,
+                 inputs_embeds,
+                 labels
+             ) = self.prepare_inputs_labels_for_multimodal(
+                 input_ids,
+                 position_ids,
+                 attention_mask,
+                 past_key_values,
+                 labels,
+                 images,
+                 image_sizes=image_sizes,
+                 region_masks=region_masks,
+             )
+
+         return super().forward(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             past_key_values=past_key_values,
+             inputs_embeds=inputs_embeds,
+             labels=labels,
+             use_cache=use_cache,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+             cache_position=cache_position,
+         )
+
+     @torch.no_grad()
+     def generate(
+         self,
+         inputs: Optional[torch.Tensor] = None,
+         images: Optional[torch.Tensor] = None,
+         image_sizes: Optional[torch.Tensor] = None,
+         region_masks: Optional[List[torch.Tensor]] = None,
+         **kwargs,
+     ) -> Union[GenerateOutput, torch.LongTensor]:
+         position_ids = kwargs.pop("position_ids", None)
+         attention_mask = kwargs.pop("attention_mask", None)
+         if "inputs_embeds" in kwargs:
+             raise NotImplementedError("`inputs_embeds` is not supported")
+
+         if images is not None:
+             # Multimodal prompt: build inputs_embeds with image features
+             # merged in; past_key_values and labels are not needed here.
+             (
+                 inputs,
+                 position_ids,
+                 attention_mask,
+                 _,
+                 inputs_embeds,
+                 _
+             ) = self.prepare_inputs_labels_for_multimodal(
+                 inputs,
+                 position_ids,
+                 attention_mask,
+                 None,
+                 None,
+                 images,
+                 image_sizes=image_sizes,
+                 region_masks=region_masks,
+             )
+         else:
+             # Text-only prompt: embed the tokens directly.
+             inputs_embeds = self.get_model().embed_tokens(inputs)
+
+         return super().generate(
+             position_ids=position_ids,
+             attention_mask=attention_mask,
+             inputs_embeds=inputs_embeds,
+             **kwargs
+         )
+
+     def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
+                                       inputs_embeds=None, **kwargs):
+         # Thread `images`/`image_sizes` through each decoding step; the base
+         # implementation would otherwise drop them from the forward kwargs.
+         images = kwargs.pop("images", None)
+         image_sizes = kwargs.pop("image_sizes", None)
+         inputs = super().prepare_inputs_for_generation(
+             input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
+         )
+         if images is not None:
+             inputs['images'] = images
+         if image_sizes is not None:
+             inputs['image_sizes'] = image_sizes
+         return inputs
+
+
+ # Expose the model through the Auto* factories under model_type "ferret_llama".
+ AutoConfig.register("ferret_llama", FerretConfig)
+ AutoModelForCausalLM.register(FerretConfig, FerretLlamaForCausalLM)