kadirnar committed on
Commit
62a3875
1 Parent(s): 241978f

Create modeling_llamavision.py

Files changed (1)
  1. modeling_llamavision.py +148 -0
modeling_llamavision.py ADDED
@@ -0,0 +1,148 @@
import torch
import torch.nn as nn
from transformers import (
    PreTrainedModel,
    AutoModelForCausalLM,
    AutoModel,
    SiglipImageProcessor,
)
from .configuration_llamavision import LlamavisionConfig


class ProjectionModule(nn.Module):
    def __init__(self, mm_hidden_size=1152, hidden_size=4096):
        super(ProjectionModule, self).__init__()

        # Two-layer MLP that maps vision features (SigLIP, 1152-d by default)
        # into the text model's embedding space (4096-d by default).
        self.model = nn.Sequential(
            nn.Linear(mm_hidden_size, hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, hidden_size),
        )

    def forward(self, x):
        return self.model(x)


class Llamavision(PreTrainedModel):
    config_class = LlamavisionConfig

    def __init__(self, config):
        super().__init__(config)

        self.vision_model = AutoModel.from_config(self.config.vision_config)
        self.text_model = AutoModelForCausalLM.from_config(self.config.text_config)
        self.processor = SiglipImageProcessor()
        self.mm_projector = ProjectionModule(
            mm_hidden_size=config.vision_config.hidden_size,
            hidden_size=config.text_config.hidden_size,
        )

    @property
    def device(self):
        return self.text_model.device

    def encode_image(self, image):
        # Preprocess to the vision tower's 378x378 input resolution and return
        # the penultimate hidden state as patch features.
        image = image.convert("RGB")
        image = self.processor(
            images=image,
            return_tensors="pt",
            do_resize=True,
            size={"height": 378, "width": 378},
        )["pixel_values"].to(
            device=self.vision_model.device, dtype=self.vision_model.dtype
        )
        with torch.no_grad():
            return self.vision_model(image, output_hidden_states=True).hidden_states[-2]

    def input_embeds(self, prompt, image_embeds, tokenizer):
        # Build the input sequence as embeddings:
        # [BOS] + projected image patch embeddings + prompt token embeddings.
        def _tokenize(txt):
            return tokenizer(
                txt, return_tensors="pt", add_special_tokens=False
            ).input_ids.to(self.device)

        text_emb = self.text_model.get_input_embeddings()

        embeds = []

        tokenized_prompt = _tokenize(prompt)
        if (
            tokenizer.bos_token_id is not None
            and tokenized_prompt[0][0] != tokenizer.bos_token_id
        ):
            embeds.append(
                text_emb(torch.tensor([[tokenizer.bos_token_id]], device=self.device))
            )

        projected_image_embeds = self.mm_projector(image_embeds.to(self.device))
        embeds.append(projected_image_embeds)

        embeds.append(text_emb(tokenized_prompt))

        return torch.cat(embeds, dim=1)

    def get_input_embeddings(self):
        return self.text_model.get_input_embeddings()

    def generate(
        self,
        image_embeds,
        prompt,
        tokenizer,
        max_new_tokens=128,
        **kwargs,
    ):
        # Stop on EOS or the Llama 3 "<|eot_id|>" turn terminator.
        generate_config = {
            "eos_token_id": [
                tokenizer.eos_token_id,
                tokenizer.convert_tokens_to_ids("<|eot_id|>"),
            ],
            "bos_token_id": tokenizer.bos_token_id,
            "pad_token_id": tokenizer.pad_token_id,
            "max_new_tokens": max_new_tokens,
            **kwargs,
        }

        with torch.no_grad():
            inputs_embeds = self.input_embeds(prompt, image_embeds, tokenizer)

            attention_mask = torch.ones(
                inputs_embeds.shape[:2],
                dtype=torch.long,
                device=inputs_embeds.device,
            )

            output_ids = self.text_model.generate(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                **generate_config,
            )

        return tokenizer.batch_decode(output_ids, skip_special_tokens=True)

    def answer_question(self, image, question, tokenizer, **kwargs):
        image_embeds = self.encode_image(image)

        chat = [
            {
                "role": "system",
                "content": "You are a helpful AI assistant that can see images and answer questions about them.",
            },
            {"role": "user", "content": question},
        ]
        prompt = tokenizer.apply_chat_template(
            chat, tokenize=False, add_generation_prompt=True
        )

        # Generate the answer
        with torch.no_grad():
            output = self.generate(
                image_embeds=image_embeds,
                prompt=prompt,
                tokenizer=tokenizer,
                **kwargs,
            )[0]

        # Clean and return the answer
        cleaned_answer = output.strip()
        return cleaned_answer
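
A minimal usage sketch, assuming the repository wires this class up for AutoModelForCausalLM via trust_remote_code and ships the matching configuration_llamavision.py and tokenizer; the repo id and image path below are placeholders, not values taken from this commit:

# Usage sketch (assumptions: the hosting repo maps this class through auto_map /
# trust_remote_code and includes a tokenizer; "user/llamavision" and "example.jpg"
# are placeholders).
from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "user/llamavision"  # placeholder repo id
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id, trust_remote_code=True
).eval()

image = Image.open("example.jpg")  # placeholder image path
answer = model.answer_question(image, "What is shown in this image?", tokenizer)
print(answer)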