import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, LogitsProcessorList

from .generation import AutoImageTokenGenerationProcessor

BOI_TOKEN = '<img>'
EOI_TOKEN = '</img>'
IMG_TOKEN = '<img_{:05d}>'
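# IMG_TOKEN.format(i) yields placeholders such as '<img_00000>'; the tokenizer is
# assumed to already contain these (and BOI_TOKEN / EOI_TOKEN) as added special tokens.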
def cosine_loss(rec, target):
    """Mean cosine distance between reconstructed and target embeddings."""
    target = target / target.norm(dim=-1, keepdim=True)
    rec = rec / rec.norm(dim=-1, keepdim=True)
    rec_loss = (1 - (target * rec).sum(-1)).mean()
    return rec_loss
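# Note: up to the eps clamping inside F.cosine_similarity, cosine_loss is
# equivalent to the one-liner
#   (1 - F.cosine_similarity(rec, target, dim=-1)).mean()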
class ContinuousLVLM(nn.Module):
    """LLM agent operating on continuous image embeddings: an input resampler maps
    image features into the LLM token space, and an output resampler maps LLM
    hidden states back to image-embedding space for reconstruction."""

    def __init__(self, llm, input_resampler, output_resampler, lm_loss_scale=1.0, rec_loss_scale=1.0) -> None:
super().__init__()
self.llm = llm
self.input_resampler = input_resampler
self.output_resampler = output_resampler
self.lm_loss_scale = lm_loss_scale
self.rec_loss_scale = rec_loss_scale
# input_resampler.requires_grad_(False)
# output_resampler.requires_grad_(False)
def forward(self, input_ids, attention_mask, labels, image_embeds, embeds_gen_mask, embeds_cmp_mask, ids_gen_mask,
ids_cmp_mask, return_recon_image_embeds=False):
input_embeds = self.llm.get_input_embeddings()(input_ids) # bz x seq_len x dim, 4 x 160 x 4096
bz, sq, dim = input_embeds.shape
if image_embeds is not None:
image_embeds_lm = self.input_resampler(image_embeds) # num_imgs_in_batch x nq x dim, 4 x 64 x 4096
has_image = True
        else:
            # No images in this batch: run the input resampler on dummy features so
            # its parameters still take part in the (zero-weighted) graph below.
            image_embeds = torch.randn(bz, self.output_resampler.num_queries,
                                       self.output_resampler.embed_dim).to(input_embeds.device,
                                                                           dtype=input_embeds.dtype)
            image_embeds_lm = self.input_resampler(image_embeds)
            has_image = False
has_image_input = has_image and embeds_cmp_mask.sum().item() > 0
has_image_output = has_image and embeds_gen_mask.sum().item() > 0
        if has_image_input:
            input_embeds[ids_cmp_mask] = image_embeds_lm[embeds_cmp_mask].view(-1, dim)  # e.g. 128 x 4096
        else:
            # Add the dummy resampler output with zero weight so the input
            # resampler's parameters count as used (e.g. under DDP) without
            # changing the input embeddings.
            min_bz = min(input_embeds.shape[0], image_embeds_lm.shape[0])
            num_queries = self.input_resampler.num_queries
            input_embeds[:min_bz, :num_queries, :] = (input_embeds[:min_bz, :num_queries, :]
                                                      + 0.0 * image_embeds_lm[:min_bz, :, :])
output_lm = self.llm(attention_mask=attention_mask,
inputs_embeds=input_embeds,
labels=labels,
output_hidden_states=True,
return_dict=True)
lm_loss = output_lm['loss']
last_hidden_state = output_lm.hidden_states[-1] # 4 x 160 x 4096
if has_image_output:
target_embeds = image_embeds[embeds_gen_mask] # num_imgs_gen_target x nq_in x dim_in, 2 x 256 x 4096
num_imgs_for_rec = target_embeds.shape[0]
output_image_embeds = last_hidden_state[ids_gen_mask].view(num_imgs_for_rec, -1,
dim) # 128 x 4096 -> 2 x 64 x 4096
recon_image_embeds = self.output_resampler(output_image_embeds) # 2 x 256 x 4096
rec_loss = cosine_loss(recon_image_embeds, target_embeds)
        else:
            # No generation targets: run the output resampler on dummy features and
            # zero the reconstruction loss, keeping its parameters in the graph.
            output_image_embeds = torch.randn(bz, self.input_resampler.num_queries,
                                              self.input_resampler.embed_dim).to(input_embeds.device,
                                                                                 dtype=input_embeds.dtype)
            recon_image_embeds = self.output_resampler(output_image_embeds)
            target_embeds = torch.randn(bz, self.output_resampler.num_queries,
                                        self.output_resampler.embed_dim).to(input_embeds.device,
                                                                            dtype=input_embeds.dtype)
            rec_loss = cosine_loss(recon_image_embeds, target_embeds) * 0.0
total_loss = self.lm_loss_scale * lm_loss + self.rec_loss_scale * rec_loss
if return_recon_image_embeds and has_image_output:
return {'total_loss': total_loss, 'lm_loss': lm_loss, 'rec_loss': rec_loss,
'recon_image_embeds': recon_image_embeds}
else:
return {'total_loss': total_loss, 'lm_loss': lm_loss, 'rec_loss': rec_loss}
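    # Example forward call (hypothetical shapes, matching the comments above):
    # with bz=4, seq_len=160, nq=64 LLM-side image tokens and 256 x 4096 source
    # image embeddings,
    #   out = model(input_ids,         # (4, 160)
    #               attention_mask,    # (4, 160)
    #               labels,            # (4, 160), -100 outside supervised spans
    #               image_embeds,      # (num_imgs, 256, 4096)
    #               embeds_gen_mask,   # (num_imgs,) bool: images to reconstruct
    #               embeds_cmp_mask,   # (num_imgs,) bool: images given as input
    #               ids_gen_mask,      # (4, 160) bool: generation placeholder slots
    #               ids_cmp_mask)      # (4, 160) bool: comprehension placeholder slots
    #   out['total_loss'].backward()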
def generate(self,
tokenizer,
prompt=None,
input_ids=None,
image_embeds=None,
embeds_cmp_mask=None,
ids_cmp_mask=None,
logits_processor=None,
num_img_gen_tokens=64,
temperature=0.7,
num_beams=1,
max_new_tokens=120,
top_p=0.5,
past_key_values=None,
# position_ids=None,
dtype=torch.float16,
device='cuda'):
        if logits_processor is None:
            logits_processor = LogitsProcessorList()
            # Once the model emits <img>, force the fixed run of <img_xxxxx>
            # placeholder tokens so image slots have a constant length.
            logits_processor.append(
                AutoImageTokenGenerationProcessor(tokenizer=tokenizer, num_img_gen_tokens=num_img_gen_tokens))
if prompt is not None:
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
if isinstance(input_ids, list):
input_ids = torch.tensor(input_ids)
input_ids = input_ids.to(device=device)
input_embeds = self.llm.get_input_embeddings()(input_ids)
bz, sq, dim = input_embeds.shape
if image_embeds is not None:
assert embeds_cmp_mask is not None and ids_cmp_mask is not None
with torch.no_grad():
image_embeds_lm = self.input_resampler(image_embeds)
input_embeds[ids_cmp_mask] = image_embeds_lm[embeds_cmp_mask].view(-1, dim)
        generation_config = {
            'temperature': temperature,
            'num_beams': num_beams,
            'max_new_tokens': max_new_tokens,
            'top_p': top_p,
            'do_sample': False  # greedy decoding; temperature/top_p are ignored when not sampling
        }
# generate_ids = self.llm.generate(input_ids=input_ids, **generation_config)
output = self.llm.generate(input_ids=input_ids,
inputs_embeds=input_embeds,
output_hidden_states=True,
return_dict_in_generate=True,
logits_processor=logits_processor,
past_key_values=past_key_values,
# position_ids=position_ids,
**generation_config)
        # self.llm.base_model.model.position_ids = self.llm.base_model.model.position_ids[:, :-2]
        # NOTE: assumes the wrapped LLM exposes the KV cache on itself; with a
        # stock transformers model this would be output.past_key_values instead.
        output_past_key_values = self.llm.past_key_values
generate_ids = output.sequences[0][input_ids.shape[1]:]
generate_id_list = generate_ids.tolist()
boi_token_id = tokenizer.encode(BOI_TOKEN, add_special_tokens=False)[0]
eoi_token_id = tokenizer.encode(EOI_TOKEN, add_special_tokens=False)[0]
        attn_weights = ()

        def merge_attn_weights(attn_weights):
            # Merge per-step attention maps into a single (q_total x kv_total) map:
            # each step attends to one more key than the last, so pad the merged
            # map with a NaN column before appending the new step's rows.
            merged_attn_weights = attn_weights[0]
            for attn_weight in attn_weights[1:]:
                merged_attn_weights = F.pad(merged_attn_weights, (0, 1), "constant", float('nan'))
                # Concatenate along the query dimension
                merged_attn_weights = torch.cat([merged_attn_weights, attn_weight], dim=-2)
            return merged_attn_weights
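        # Worked example (assuming attention maps laid out as (..., q_len, kv_len)):
        # a 5-token prompt plus 3 decode steps merges into an 8 x 8 map whose
        # rows 0-4 come from prefill (NaN in columns 5-7) and rows 5-7 come from
        # the three decode steps.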
        if output.attentions is not None:
            # for idx in [0, 1, 2, 9, 16, 23, 31]:
            for idx in range(len(output.attentions[0])):  # one merged map per decoder layer
                attn_weights += (
                    merge_attn_weights([output.attentions[j][idx] for j in range(len(output.attentions))]),)
        # Per-step last-layer hidden states, concatenated along the sequence dim.
        # With a multi-turn KV cache the prompt is not re-encoded, hence the two
        # indexing branches below.
        last_hidden_states = torch.cat([hidden_state[-1] for hidden_state in output.hidden_states], dim=1)
if past_key_values is None:
last_hidden_states = last_hidden_states[0, input_ids.shape[1]:, :]
eoi_indices = torch.where(generate_ids == eoi_token_id)[0].tolist()
else:
last_hidden_states = last_hidden_states[0, :, :]
hidden_len = last_hidden_states.shape[0]
eoi_indices = torch.where(output.sequences[0][-hidden_len:] == eoi_token_id)[0].tolist()
num_gen_imgs = 1 if len(eoi_indices) > 0 else 0
text_mask = torch.ones_like(generate_ids, dtype=torch.bool)
has_img_output = num_gen_imgs > 0
        if has_img_output:
            img_gen_feats = []
            # The num_img_gen_tokens hidden states right before the last </img>
            # form the generated image feature.
            img_gen_feats.append(last_hidden_states[eoi_indices[-1] - num_img_gen_tokens:eoi_indices[-1]])
            text_mask[eoi_indices[-1] - num_img_gen_tokens:eoi_indices[-1]] = False
# for eoi_idx in eoi_indices:
# img_gen_feats.append(last_hidden_states[eoi_idx - num_img_gen_tokens:eoi_idx])
# text_mask[eoi_idx - num_img_gen_tokens:eoi_idx] = False
img_gen_feats = torch.stack(img_gen_feats)
img_gen_feat = self.output_resampler(img_gen_feats)
else:
img_gen_feat = None
text_mask[generate_ids == boi_token_id] = False
# generate_ids = generate_ids[text_mask]
generate_text = tokenizer.decode(generate_ids, skip_special_tokens=False)
return {
'text': generate_text,
'generate_ids': generate_ids,
'has_img_output': has_img_output,
'img_gen_feat': img_gen_feat,
'num_gen_imgs': num_gen_imgs,
'attn_weights': attn_weights,
'past_key_values': output_past_key_values
}
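    # Example usage (hypothetical; assumes the tokenizer already contains the
    # <img> / </img> / <img_xxxxx> special tokens):
    #   out = agent.generate(tokenizer=tokenizer,
    #                        prompt='Continue the story. <img>',
    #                        image_embeds=image_embeds,
    #                        embeds_cmp_mask=embeds_cmp_mask,
    #                        ids_cmp_mask=ids_cmp_mask)
    #   if out['has_img_output']:
    #       img_feat = out['img_gen_feat']  # (1, nq_out, dim_out), for the image decoder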
    @classmethod
    def from_pretrained(cls, llm, input_resampler, output_resampler, pretrained_model_path=None, **kwargs):
        model = cls(llm=llm, input_resampler=input_resampler, output_resampler=output_resampler, **kwargs)
        if pretrained_model_path is not None:
            if 'TencentARC/SEED-Story' in pretrained_model_path:
                # Load from a subfolder of the Hugging Face Hub repository
                ckpt = AutoModel.from_pretrained(pretrained_model_path, subfolder="seed_story/george_sft")
                state_dict = ckpt.state_dict()
            else:
                # Load a local checkpoint
                state_dict = torch.load(pretrained_model_path, map_location='cpu')
            missing, unexpected = model.load_state_dict(state_dict, strict=False)
            print('Agent model, missing keys:', len(missing), 'unexpected keys:', len(unexpected))
        return model
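# Example construction (hypothetical module names; the resamplers are assumed to
# be Q-Former-style modules exposing num_queries and embed_dim):
#   llm = LlamaForCausalLM.from_pretrained(llm_path)
#   agent = ContinuousLVLM.from_pretrained(llm=llm,
#                                          input_resampler=input_resampler,
#                                          output_resampler=output_resampler,
#                                          pretrained_model_path=agent_ckpt_path)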
class SEEDLLaMAAlignGeneration(nn.Module):
    """Alignment stage: the LLM is frozen and only the output resampler is
    trained to reconstruct image embeddings from the LLM's hidden states."""

    def __init__(self, llm, output_resampler) -> None:
super().__init__()
self.llm = llm
self.output_resampler = output_resampler
# self.rec_loss_scale = rec_loss_scale
self.llm.requires_grad_(False)
def forward(self, input_ids, attention_mask, labels, image_embeds, embeds_gen_mask, embeds_cmp_mask, ids_gen_mask,
ids_cmp_mask):
input_embeds = self.llm.get_input_embeddings()(input_ids) # bz x seq_len x dim, 4 x 160 x 4096
bz, sq, dim = input_embeds.shape
output_lm = self.llm(attention_mask=attention_mask,
inputs_embeds=input_embeds,
labels=labels,
output_hidden_states=True,
return_dict=True)
last_hidden_state = output_lm.hidden_states[-1] # 4 x 160 x 4096
target_embeds = image_embeds[embeds_gen_mask] # num_imgs_gen_target x nq_in x dim_in, 2 x 256 x 4096
num_imgs_for_rec = target_embeds.shape[0]
output_image_embeds = last_hidden_state[ids_gen_mask].view(num_imgs_for_rec, -1,
dim) # 128 x 4096 -> 2 x 64 x 4096
recon_image_embeds = self.output_resampler(output_image_embeds) # 2 x 256 x 4096
rec_loss = cosine_loss(recon_image_embeds, target_embeds)
return {'total_loss': rec_loss, 'rec_loss': rec_loss}
@classmethod
def from_pretrained(cls, llm, output_resampler, pretrained_model_path=None, **kwargs):
model = cls(llm=llm, output_resampler=output_resampler, **kwargs)
if pretrained_model_path is not None:
ckpt = torch.load(pretrained_model_path, map_location='cpu')
missing, unexpected = model.load_state_dict(ckpt, strict=False)
                print('Agent model, missing keys:', len(missing), 'unexpected keys:', len(unexpected))
return model
def generate(self,
tokenizer,
input_ids=None,
temperature=0.7,
num_beams=1,
max_new_tokens=120,
num_img_gen_tokens=64,
top_p=0.5,
dtype=torch.float16,
device='cuda'):
input_ids = input_ids.to(device=device)
input_embeds = self.llm.get_input_embeddings()(input_ids) # bz x seq_len x dim, 4 x 160 x 4096
        generation_config = {
            'temperature': temperature,
            'num_beams': num_beams,
            'max_new_tokens': max_new_tokens,
            'top_p': top_p,
            'do_sample': False  # greedy decoding; temperature/top_p are ignored when not sampling
        }
output = self.llm.generate(input_ids=input_ids,
inputs_embeds=input_embeds,
output_hidden_states=True,
return_dict_in_generate=True,
**generation_config)
generate_ids = output.sequences[0][input_ids.shape[1]:]
generate_id_list = generate_ids.tolist()
# boi_token_id = tokenizer.encode(BOI_TOKEN, add_special_tokens=False)[0]
eoi_token_id = tokenizer.encode(EOI_TOKEN, add_special_tokens=False)[0]
# print('output ids: ', generate_ids, generate_ids.shape)
# last_hidden_states = output.hidden_states[-1]
last_hidden_states = torch.cat([hidden_state[-1] for hidden_state in output.hidden_states],
dim=1)[:1, input_ids.shape[1]:, :]
has_img_output = eoi_token_id in generate_id_list
if has_img_output:
# print(boi_token_id, generate_id_list, generate_id_list.index(boi_token_id))
# boi_idx = generate_id_list.index(boi_token_id)
eoi_idx = generate_id_list.index(eoi_token_id)
print(len(generate_id_list), generate_id_list, eoi_idx)
# print(generate_id_list[boi_idx + 1:boi_idx + 1 + num_img_gen_tokens])
# img_gen_feat = last_hidden_states[:, eoi_idx - num_img_gen_tokens:eoi_idx]
img_gen_feat = last_hidden_states[:, 0:eoi_idx]
print('img_gen_feat', img_gen_feat.shape, last_hidden_states.shape, num_img_gen_tokens)
img_gen_feat = self.output_resampler(img_gen_feat)
else:
img_gen_feat = None
generate_text = tokenizer.decode(generate_ids, skip_special_tokens=False)
# print('output keys: ', output.keys())
return {'text': generate_text, 'has_img_output': has_img_output, 'img_gen_feat': img_gen_feat}
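# Example usage (hypothetical): decode text and, when an </img> token appears,
# resample the hidden states before it into image features:
#   out = align_model.generate(tokenizer=tokenizer, input_ids=input_ids)
#   if out['has_img_output']:
#       img_feat = out['img_gen_feat']  # fed to the image decoder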