teowu commited on
Commit
474d82c
·
1 Parent(s): e17bbc3

Upload modeling_llama2.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_llama2.py +486 -0
modeling_llama2.py ADDED
@@ -0,0 +1,486 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import warnings
3
+ from functools import partial
4
+ from typing import List, Optional, Tuple, Union
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+ import torch.utils.checkpoint
9
+ from torch import nn
10
+
11
+ import transformers
12
+ from transformers.models.llama.modeling_llama import *
13
+ from transformers.configuration_utils import PretrainedConfig
14
+ from transformers.utils import logging
15
+
16
+ from .modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
17
+ from .configuration_mplug_owl2 import LlamaConfig
18
+
19
+ class MultiwayNetwork(nn.Module):
20
+
21
+ def __init__(self, module_provider, num_multiway=2):
22
+ super(MultiwayNetwork, self).__init__()
23
+
24
+ self.multiway = torch.nn.ModuleList([module_provider() for _ in range(num_multiway)])
25
+
26
+ def forward(self, hidden_states, multiway_indices):
27
+
28
+ if len(self.multiway) == 1:
29
+ return self.multiway[0](hidden_states)
30
+
31
+ output_hidden_states = torch.empty_like(hidden_states)
32
+
33
+ for idx, subway in enumerate(self.multiway):
34
+ local_indices = multiway_indices.eq(idx).nonzero(as_tuple=True)
35
+ hidden = hidden_states[local_indices].unsqueeze(1).contiguous()
36
+ if hidden.numel():
37
+ output = subway(hidden)
38
+ if isinstance(output, tuple):
39
+ output = output[0]
40
+ output = output.squeeze(1)
41
+ output_hidden_states[local_indices] = output
42
+
43
+ return output_hidden_states.contiguous()
44
+
45
+
46
+ class LlamaAttention(nn.Module):
47
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
48
+
49
+ def __init__(self, config: LlamaConfig):
50
+ super().__init__()
51
+ self.config = config
52
+ self.hidden_size = config.hidden_size
53
+ self.num_heads = config.num_attention_heads
54
+ self.head_dim = self.hidden_size // self.num_heads
55
+ self.num_key_value_heads = config.num_key_value_heads
56
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
57
+ self.max_position_embeddings = config.max_position_embeddings
58
+ self.rope_theta = config.rope_theta
59
+
60
+ if (self.head_dim * self.num_heads) != self.hidden_size:
61
+ raise ValueError(
62
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
63
+ f" and `num_heads`: {self.num_heads})."
64
+ )
65
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
66
+ self.k_proj = MultiwayNetwork(module_provider=partial(
67
+ nn.Linear, in_features=self.hidden_size, out_features=self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
68
+ )
69
+ self.v_proj = MultiwayNetwork(module_provider=partial(
70
+ nn.Linear, in_features=self.hidden_size, out_features=self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
71
+ )
72
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
73
+ self._init_rope()
74
+
75
+ def _init_rope(self):
76
+ if self.config.rope_scaling is None:
77
+ self.rotary_emb = LlamaRotaryEmbedding(
78
+ self.head_dim,
79
+ max_position_embeddings=self.max_position_embeddings,
80
+ base=self.rope_theta,
81
+ )
82
+ else:
83
+ scaling_type = self.config.rope_scaling["type"]
84
+ scaling_factor = self.config.rope_scaling["factor"]
85
+ if scaling_type == "linear":
86
+ self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
87
+ self.head_dim,
88
+ max_position_embeddings=self.max_position_embeddings,
89
+ scaling_factor=scaling_factor,
90
+ base=self.rope_theta,
91
+ )
92
+ elif scaling_type == "dynamic":
93
+ self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
94
+ self.head_dim,
95
+ max_position_embeddings=self.max_position_embeddings,
96
+ scaling_factor=scaling_factor,
97
+ base=self.rope_theta,
98
+ )
99
+ else:
100
+ raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
101
+
102
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
103
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
104
+
105
+ def forward(
106
+ self,
107
+ hidden_states: torch.Tensor,
108
+ modality_indicators: torch.Tensor,
109
+ attention_mask: Optional[torch.Tensor] = None,
110
+ position_ids: Optional[torch.LongTensor] = None,
111
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
112
+ output_attentions: bool = False,
113
+ use_cache: bool = False,
114
+ padding_mask: Optional[torch.LongTensor] = None,
115
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
116
+ bsz, q_len, _ = hidden_states.size()
117
+
118
+ query_states = self.q_proj(hidden_states, )
119
+ key_states = self.k_proj(hidden_states, modality_indicators)
120
+ value_states = self.v_proj(hidden_states, modality_indicators)
121
+
122
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
123
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
124
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
125
+
126
+ kv_seq_len = key_states.shape[-2]
127
+ if past_key_value is not None:
128
+ kv_seq_len += past_key_value[0].shape[-2]
129
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
130
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
131
+
132
+ if past_key_value is not None:
133
+ # reuse k, v, self_attention
134
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
135
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
136
+
137
+ past_key_value = (key_states, value_states) if use_cache else None
138
+
139
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
140
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
141
+
142
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
143
+
144
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
145
+ raise ValueError(
146
+ f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
147
+ f" {attn_weights.size()}"
148
+ )
149
+
150
+ if attention_mask is not None:
151
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
152
+ raise ValueError(
153
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
154
+ )
155
+ attn_weights = attn_weights + attention_mask
156
+
157
+ # upcast attention to fp32
158
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
159
+ attn_output = torch.matmul(attn_weights, value_states)
160
+
161
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
162
+ raise ValueError(
163
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
164
+ f" {attn_output.size()}"
165
+ )
166
+
167
+ attn_output = attn_output.transpose(1, 2).contiguous()
168
+
169
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
170
+
171
+ attn_output = self.o_proj(attn_output)
172
+
173
+ if not output_attentions:
174
+ attn_weights = None
175
+
176
+ return attn_output, attn_weights, past_key_value
177
+
178
+
179
+
180
+ class LlamaDecoderLayer(nn.Module):
181
+ def __init__(self, config: LlamaConfig):
182
+ super().__init__()
183
+ self.hidden_size = config.hidden_size
184
+ self.self_attn = LlamaAttention(config=config)
185
+ self.mlp = LlamaMLP(config)
186
+ self.input_layernorm = MultiwayNetwork(module_provider=partial(
187
+ LlamaRMSNorm, hidden_size=config.hidden_size, eps=config.rms_norm_eps
188
+ ))
189
+ self.post_attention_layernorm = MultiwayNetwork(module_provider=partial(
190
+ LlamaRMSNorm, hidden_size=config.hidden_size, eps=config.rms_norm_eps
191
+ ))
192
+
193
+ def forward(
194
+ self,
195
+ hidden_states: torch.Tensor,
196
+ modality_indicators: torch.Tensor = None,
197
+ attention_mask: Optional[torch.Tensor] = None,
198
+ position_ids: Optional[torch.LongTensor] = None,
199
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
200
+ output_attentions: Optional[bool] = False,
201
+ use_cache: Optional[bool] = False,
202
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
203
+ """
204
+ Args:
205
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
206
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
207
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
208
+ output_attentions (`bool`, *optional*):
209
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
210
+ returned tensors for more detail.
211
+ use_cache (`bool`, *optional*):
212
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
213
+ (see `past_key_values`).
214
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
215
+ """
216
+
217
+ residual = hidden_states
218
+
219
+ hidden_states = self.input_layernorm(hidden_states, modality_indicators)
220
+
221
+ # Self Attention
222
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
223
+ hidden_states=hidden_states,
224
+ modality_indicators=modality_indicators,
225
+ attention_mask=attention_mask,
226
+ position_ids=position_ids,
227
+ past_key_value=past_key_value,
228
+ output_attentions=output_attentions,
229
+ use_cache=use_cache,
230
+ )
231
+ hidden_states = residual + hidden_states
232
+
233
+ # Fully Connected
234
+ residual = hidden_states
235
+ hidden_states = self.post_attention_layernorm(hidden_states, modality_indicators)
236
+ hidden_states = self.mlp(hidden_states)
237
+ hidden_states = residual + hidden_states
238
+
239
+ outputs = (hidden_states,)
240
+
241
+ if output_attentions:
242
+ outputs += (self_attn_weights,)
243
+
244
+ if use_cache:
245
+ outputs += (present_key_value,)
246
+
247
+ return outputs
248
+
249
+
250
+ def model_forward(
251
+ self,
252
+ input_ids: torch.LongTensor = None,
253
+ modality_indicators: torch.Tensor = None,
254
+ attention_mask: Optional[torch.Tensor] = None,
255
+ position_ids: Optional[torch.LongTensor] = None,
256
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
257
+ inputs_embeds: Optional[torch.FloatTensor] = None,
258
+ use_cache: Optional[bool] = None,
259
+ output_attentions: Optional[bool] = None,
260
+ output_hidden_states: Optional[bool] = None,
261
+ return_dict: Optional[bool] = None,
262
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
263
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
264
+ output_hidden_states = (
265
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
266
+ )
267
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
268
+
269
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
270
+
271
+ # retrieve input_ids and inputs_embeds
272
+ if input_ids is not None and inputs_embeds is not None:
273
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
274
+ elif input_ids is not None:
275
+ batch_size, seq_length = input_ids.shape
276
+ elif inputs_embeds is not None:
277
+ batch_size, seq_length, _ = inputs_embeds.shape
278
+ else:
279
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
280
+
281
+ seq_length_with_past = seq_length
282
+ past_key_values_length = 0
283
+
284
+ if past_key_values is not None:
285
+ past_key_values_length = past_key_values[0][0].shape[2]
286
+ seq_length_with_past = seq_length_with_past + past_key_values_length
287
+
288
+ if position_ids is None:
289
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
290
+ position_ids = torch.arange(
291
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
292
+ )
293
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
294
+ else:
295
+ position_ids = position_ids.view(-1, seq_length).long()
296
+
297
+ if inputs_embeds is None:
298
+ inputs_embeds = self.embed_tokens(input_ids)
299
+ # embed positions
300
+ if attention_mask is None:
301
+ attention_mask = torch.ones(
302
+ (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
303
+ )
304
+ attention_mask = self._prepare_decoder_attention_mask(
305
+ attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
306
+ )
307
+
308
+ hidden_states = inputs_embeds
309
+
310
+ if self.gradient_checkpointing and self.training:
311
+ if use_cache:
312
+ logger.warning_once(
313
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
314
+ )
315
+ use_cache = False
316
+
317
+ # decoder layers
318
+ all_hidden_states = () if output_hidden_states else None
319
+ all_self_attns = () if output_attentions else None
320
+ next_decoder_cache = () if use_cache else None
321
+
322
+ for idx, decoder_layer in enumerate(self.layers):
323
+ if output_hidden_states:
324
+ all_hidden_states += (hidden_states,)
325
+
326
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
327
+
328
+ if self.gradient_checkpointing and self.training:
329
+
330
+ def create_custom_forward(module):
331
+ def custom_forward(*inputs):
332
+ # None for past_key_value
333
+ return module(*inputs, past_key_value, output_attentions)
334
+
335
+ return custom_forward
336
+
337
+ layer_outputs = torch.utils.checkpoint.checkpoint(
338
+ create_custom_forward(decoder_layer),
339
+ hidden_states,
340
+ modality_indicators,
341
+ attention_mask,
342
+ position_ids,
343
+ )
344
+ else:
345
+ layer_outputs = decoder_layer(
346
+ hidden_states,
347
+ modality_indicators=modality_indicators,
348
+ attention_mask=attention_mask,
349
+ position_ids=position_ids,
350
+ past_key_value=past_key_value,
351
+ output_attentions=output_attentions,
352
+ use_cache=use_cache,
353
+ )
354
+
355
+ hidden_states = layer_outputs[0]
356
+
357
+ if use_cache:
358
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
359
+
360
+ if output_attentions:
361
+ all_self_attns += (layer_outputs[1],)
362
+
363
+ hidden_states = self.norm(hidden_states)
364
+
365
+ # add hidden states from the last decoder layer
366
+ if output_hidden_states:
367
+ all_hidden_states += (hidden_states,)
368
+
369
+ next_cache = next_decoder_cache if use_cache else None
370
+ if not return_dict:
371
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
372
+ return BaseModelOutputWithPast(
373
+ last_hidden_state=hidden_states,
374
+ past_key_values=next_cache,
375
+ hidden_states=all_hidden_states,
376
+ attentions=all_self_attns,
377
+ )
378
+
379
+
380
+ def causal_model_forward(
381
+ self,
382
+ input_ids: torch.LongTensor = None,
383
+ modality_indicators: torch.Tensor = None,
384
+ attention_mask: Optional[torch.Tensor] = None,
385
+ position_ids: Optional[torch.LongTensor] = None,
386
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
387
+ inputs_embeds: Optional[torch.FloatTensor] = None,
388
+ labels: Optional[torch.LongTensor] = None,
389
+ use_cache: Optional[bool] = None,
390
+ output_attentions: Optional[bool] = None,
391
+ output_hidden_states: Optional[bool] = None,
392
+ return_dict: Optional[bool] = None,
393
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
394
+ r"""
395
+ Args:
396
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
397
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
398
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
399
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
400
+
401
+ Returns:
402
+
403
+ Example:
404
+
405
+ ```python
406
+ >>> from transformers import AutoTokenizer, LlamaForCausalLM
407
+
408
+ >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
409
+ >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
410
+
411
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
412
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
413
+
414
+ >>> # Generate
415
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
416
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
417
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
418
+ ```"""
419
+
420
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
421
+ output_hidden_states = (
422
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
423
+ )
424
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
425
+
426
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
427
+ outputs = self.model(
428
+ input_ids=input_ids,
429
+ modality_indicators=modality_indicators,
430
+ attention_mask=attention_mask,
431
+ position_ids=position_ids,
432
+ past_key_values=past_key_values,
433
+ inputs_embeds=inputs_embeds,
434
+ use_cache=use_cache,
435
+ output_attentions=output_attentions,
436
+ output_hidden_states=output_hidden_states,
437
+ return_dict=return_dict,
438
+ )
439
+
440
+ hidden_states = outputs[0]
441
+ if self.config.pretraining_tp > 1:
442
+ lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
443
+ logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
444
+ logits = torch.cat(logits, dim=-1)
445
+ else:
446
+ logits = self.lm_head(hidden_states)
447
+ logits = logits.float()
448
+
449
+ loss = None
450
+ if labels is not None:
451
+ # Shift so that tokens < n predict n
452
+ shift_logits = logits[..., :-1, :].contiguous()
453
+ shift_labels = labels[..., 1:].contiguous()
454
+ # Flatten the tokens
455
+ loss_fct = CrossEntropyLoss()
456
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
457
+ shift_labels = shift_labels.view(-1)
458
+ # Enable model parallelism
459
+ shift_labels = shift_labels.to(shift_logits.device)
460
+ loss = loss_fct(shift_logits, shift_labels)
461
+
462
+ if not return_dict:
463
+ output = (logits,) + outputs[1:]
464
+ return (loss,) + output if loss is not None else output
465
+
466
+ return CausalLMOutputWithPast(
467
+ loss=loss,
468
+ logits=logits,
469
+ past_key_values=outputs.past_key_values,
470
+ hidden_states=outputs.hidden_states,
471
+ attentions=outputs.attentions,
472
+ )
473
+
474
+ def replace_llama_modality_adaptive():
475
+ transformers.models.llama.configuration_llama.LlamaConfig = LlamaConfig
476
+ transformers.models.llama.modeling_llama.LlamaAttention = LlamaAttention
477
+ transformers.models.llama.modeling_llama.LlamaDecoderLayer = LlamaDecoderLayer
478
+ transformers.models.llama.modeling_llama.LlamaModel.forward = model_forward
479
+ transformers.models.llama.modeling_llama.LlamaForCausalLM.forward = causal_model_forward
480
+
481
+
482
+ if __name__ == "__main__":
483
+ replace_llama_modality_adaptive()
484
+ config = transformers.LlamaConfig.from_pretrained('/cpfs01/shared/public/test/vicuna-7b-v1.5/')
485
+ model = transformers.LlamaForCausalLM(config)
486
+ print(model)