Spaces:
Runtime error
Runtime error
from PIL import Image | |
import torch | |
from transformers import StoppingCriteria, StoppingCriteriaList | |
from enum import auto, Enum | |
import numpy as np | |
from decord import VideoReader, cpu | |
import torchvision.transforms as T | |
from models.video_transformers import ( | |
GroupNormalize, GroupScale, GroupCenterCrop, | |
Stack, ToTorchFormatTensor | |
) | |
from torchvision.transforms.functional import InterpolationMode | |
from transformers import LlamaTokenizer, LlamaConfig | |
# Module-wide default device: CUDA (no explicit index) when a GPU is available, otherwise CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
class SeparatorStyle(Enum):
    """Enumerates the available separator styles.

    The explicit values 1 and 2 are exactly what ``auto()`` assigns to the
    first two members of a plain ``Enum``.
    """

    SINGLE = 1
    TWO = 2
def get_prompt(conv):
    """Flatten a conversation into one prompt string.

    The result starts with ``conv.system`` followed by ``conv.sep``. Each
    ``(role, message)`` pair in ``conv.messages`` then contributes
    ``"role: message<sep>"`` when the message is truthy. An empty or ``None``
    message contributes a bare ``"role:"``, which leaves an open slot for the
    model to continue.
    """
    pieces = [conv.system + conv.sep]
    for speaker, text in conv.messages:
        if not text:
            pieces.append(speaker + ":")
            continue
        pieces.append(speaker + ": " + text + conv.sep)
    return "".join(pieces)
class StoppingCriteriaSub(StoppingCriteria):
    """Stop generation once the first sequence in the batch ends with a stop pattern.

    Args:
        stops: Iterable of 1-D token-id tensors. Generation halts when the tail
            of ``input_ids[0]`` equals any of them. Defaults to no patterns.
        encounters: Kept for backward compatibility. It is stored but not used
            by the matching logic.
    """

    def __init__(self, stops=None, encounters=1):
        super().__init__()
        # A None default avoids one mutable list being shared by every instance.
        self.stops = [] if stops is None else stops
        self.encounters = encounters

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True if ``input_ids[0]`` currently ends with any stop pattern."""
        # Only the first sequence of the batch is inspected.
        generated = input_ids[0]
        seq_len = generated.shape[-1]
        for stop in self.stops:
            n = len(stop)
            # An empty pattern, or one longer than the sequence, cannot match.
            # Without this guard the slice below is shorter than `stop`, and
            # broadcasting would compare the wrong tokens.
            if n == 0 or n > seq_len:
                continue
            if torch.all(stop == generated[-n:]).item():
                return True
        return False
class Chat:
    """Multimodal chat driver: encodes images/videos, builds prompts, generates replies.

    ``model`` must provide the members used below: ``llama_model`` (with
    ``generate`` and ``model.embed_tokens``), ``llama_tokenizer`` and
    ``encode_img``. ``conv`` objects must provide ``roles``, ``messages``
    (a list of ``[role, text]`` pairs) plus whatever ``get_prompt`` reads.
    """

    def __init__(self, model, device='cuda:0'):
        self.device = device
        self.model = model
        # '###' can be encoded in two different ways by the tokenizer; stop on either.
        stop_words_ids = [
            torch.tensor([835]).to(self.device),
            torch.tensor([2277, 29937]).to(self.device),
        ]
        self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])

    def ask(self, text, conv):
        """Append a user turn (``conv.roles[0]``) holding ``text`` and return ``conv``."""
        conv.messages.append([conv.roles[0], text + '\n'])
        return conv

    def answer(self, conv, img_list, max_new_tokens=200, num_beams=1, min_length=1, top_p=0.9,
               repetition_penalty=1.0, length_penalty=1, temperature=1.0):
        """Generate the assistant's reply to the current conversation.

        An empty assistant turn is appended to ``conv`` and filled in with the
        decoded reply. Generation always samples (``do_sample=True``); the
        remaining keyword arguments are forwarded to ``llama_model.generate``.

        Returns:
            ``(output_text, output_token_ids_as_numpy, conv)``.
        """
        conv.messages.append([conv.roles[1], None])
        with torch.no_grad():
            embs = self.get_context_emb(conv, img_list)
            outputs = self.model.llama_model.generate(
                inputs_embeds=embs,
                max_new_tokens=max_new_tokens,
                stopping_criteria=self.stopping_criteria,
                num_beams=num_beams,
                do_sample=True,
                min_length=min_length,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                length_penalty=length_penalty,
                temperature=temperature,
            )
        output_token = outputs[0]
        # The model may emit an unknown token <unk> (id 0) and/or a start token
        # <s> (id 1) first; drop them. The length checks keep an empty or
        # fully-stripped output from raising IndexError.
        if len(output_token) > 0 and output_token[0] == 0:
            output_token = output_token[1:]
        if len(output_token) > 0 and output_token[0] == 1:
            output_token = output_token[1:]
        output_text = self.model.llama_tokenizer.decode(output_token, add_special_tokens=False)
        output_text = output_text.split('###')[0]  # remove the stop sign '###'
        output_text = output_text.split('Assistant:')[-1].strip()
        conv.messages[-1][1] = output_text
        return output_text, output_token.cpu().numpy(), conv

    def get_index(self, num_frames, num_segments):
        """Return ``num_segments`` frame indices spread evenly over ``num_frames`` frames.

        Every index is shifted by half a segment, so the samples sit near the
        segment centres rather than at their starts.
        """
        seg_size = float(num_frames - 1) / num_segments
        start = int(seg_size / 2)
        offsets = np.array([
            start + int(np.round(seg_size * idx)) for idx in range(num_segments)
        ])
        return offsets

    def load_video(self, video_path, num_segments=8, return_msg=False):
        """Decode ``num_segments`` evenly spaced frames from a video and preprocess them.

        Returns the transformed frame stack. ``upload_video`` unpacks its shape
        as ``(T*3, H, W)``. When ``return_msg`` is true, the method returns a
        ``(frames, msg)`` tuple, where ``msg`` lists the sampled timestamps in
        seconds.
        """
        vr = VideoReader(video_path, ctx=cpu(0))
        num_frames = len(vr)
        frame_indices = self.get_index(num_frames, num_segments)

        # Normalization constants used by OpenAI CLIP preprocessing.
        input_mean = [0.48145466, 0.4578275, 0.40821073]
        input_std = [0.26862954, 0.26130258, 0.27577711]
        transform = T.Compose([
            GroupScale(int(224), interpolation=InterpolationMode.BICUBIC),
            GroupCenterCrop(224),
            Stack(),
            ToTorchFormatTensor(),
            GroupNormalize(input_mean, input_std),
        ])

        # Decode only the frames that are actually sampled.
        images_group = [Image.fromarray(vr[frame_index].asnumpy()) for frame_index in frame_indices]
        torch_imgs_224 = transform(images_group)

        if not return_msg:
            return torch_imgs_224
        fps = float(vr.get_avg_fps())
        sec = ", ".join([str(round(f / fps, 1)) for f in frame_indices])
        # " " should be added in the start and end
        msg = f"The video contains {len(frame_indices)} frames sampled at {sec} seconds."
        return torch_imgs_224, msg

    def upload_video(self, image, conv, img_list, num_segments):
        """Encode a video file and register it in the conversation.

        The embedding is appended to ``img_list``. A user turn holding a
        ``<VideoHere>`` placeholder plus the frame-timing message is appended
        to ``conv``.

        Raises:
            NotImplementedError: if ``image`` is not a file path string.
        """
        if not isinstance(image, str):
            raise NotImplementedError("upload_video only supports video file paths")
        vid_chat, msg = self.load_video(image, num_segments=num_segments, return_msg=True)
        TC, H, W = vid_chat.shape
        # Split the stacked channel axis back into (frames, RGB) and add a batch axis.
        image = vid_chat.reshape(1, TC // 3, 3, H, W).to(self.device)
        with torch.no_grad():
            print("Input video shape:", vid_chat.shape)
            image_emb, _ = self.model.encode_img(image)
        img_list.append(image_emb)
        conv.messages.append([
            conv.roles[0],
            f"<Video><VideoHere></Video> {msg}\n",
        ])
        return "Received.", img_list, conv

    def upload_img(self, image, conv, img_list):
        """Encode a single image and register it in the conversation.

        ``image`` may be anything the torchvision transforms below accept
        (e.g. a PIL image) or a file path, which is opened and converted to
        RGB. The embedding is appended to ``img_list``, and a user turn holding
        an ``<ImageHere>`` placeholder is appended to ``conv``.
        """
        img = image
        if isinstance(img, str):
            with Image.open(img) as opened:
                img = opened.convert('RGB')
        transform = T.Compose([
            T.Resize((224, 224), interpolation=InterpolationMode.BICUBIC),
            T.ToTensor(),
            T.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
        ])
        with torch.no_grad():
            # Add batch and time axes so the image matches the (1, T, 3, H, W)
            # layout used for videos. Move it to the configured device rather
            # than assuming CUDA.
            img = transform(img).unsqueeze(0).unsqueeze(0).to(self.device)
            image_emb, _ = self.model.encode_img(img)
        img_list.append(image_emb)
        conv.messages.append([
            conv.roles[0],
            "<Image><ImageHere></Image>\n",
        ])
        return "Received.", img_list, conv

    def get_context_emb(self, conv, img_list):
        """Build the input embedding sequence: text segments interleaved with visual embeddings.

        The prompt is split at every ``<ImageHere>``/``<VideoHere>`` placeholder.
        Each text segment is tokenized (only the first one gets special tokens)
        and embedded. The i-th visual embedding in ``img_list`` is inserted
        after the i-th segment. The result is concatenated along dim 1.
        """
        prompt = get_prompt(conv)
        # Treat both placeholder kinds alike, so conversations mixing images and
        # videos still line up with img_list (appended in upload order).
        prompt_segs = prompt.replace('<ImageHere>', '<VideoHere>').split('<VideoHere>')
        assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of visual placeholders and videos."
        seg_tokens = [
            self.model.llama_tokenizer(
                seg, return_tensors="pt", add_special_tokens=i == 0).to(self.device).input_ids
            # only add bos to the first seg
            for i, seg in enumerate(prompt_segs)
        ]
        seg_embs = [self.model.llama_model.model.embed_tokens(seg_t) for seg_t in seg_tokens]
        mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]]
        return torch.cat(mixed_embs, dim=1)