Spaces:
Running
on
Zero
Running
on
Zero
File size: 5,704 Bytes
f04732f a3db70a c36d5bb a3db70a f04732f 08bc998 a3db70a c36d5bb a3db70a c36d5bb a3db70a c36d5bb a3db70a 4be8019 a3db70a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 |
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, StoppingCriteria
from threading import Thread
import re
import time
from PIL import Image
import torch
import spaces
import subprocess
import os  # used below to pass the parent environment through to the pip subprocess

# Install flash-attn at start-up, before the model is loaded.
# `env=` REPLACES the child's environment instead of extending it, so the
# current environment is merged in; otherwise PATH / HOME / virtualenv
# variables are dropped and the shell may not find the intended `pip`.
# NOTE(review): FLASH_ATTENTION_SKIP_CUDA_BUILD presumably makes the flash-attn
# build skip compiling its CUDA kernels -- confirm against flash-attn's setup.
# The return code is not checked, so a failed install does not abort start-up.
subprocess.run(
    'pip install flash-attn --no-build-isolation',
    env={**os.environ, 'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
    shell=True,
)

# Create new tensors on the GPU by default (e.g. the prompt ids built per request).
torch.set_default_device('cuda')

# Tokenizer and model both come from the same Hub repository; the repo ships
# custom code, hence trust_remote_code=True.
tokenizer = AutoTokenizer.from_pretrained(
    'qnguyen3/nanoLLaVA',
    trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    'qnguyen3/nanoLLaVA',
    torch_dtype=torch.float16,
    device_map='auto',
    trust_remote_code=True)
class KeywordsStoppingCriteria(StoppingCriteria):
    """Stopping criterion that fires when a keyword shows up at the end of the output.

    Each keyword is tokenized once at construction time. During generation a
    sequence is considered finished when its trailing token ids equal one of
    the keyword id sequences, or when the decoded tail of the sequence
    contains one of the keyword strings.
    """

    def __init__(self, keywords, tokenizer, input_ids):
        """Tokenize the keywords and remember the prompt length.

        Args:
            keywords: Iterable of stop strings.
            tokenizer: Tokenizer used to encode the keywords and to decode
                the generated ids later on.
            input_ids: Prompt ids; only ``input_ids.shape[1]`` is read.
        """
        self.keywords = keywords
        self.keyword_ids = []
        self.max_keyword_len = 0
        for word in keywords:
            ids = tokenizer(word).input_ids
            # Strip a BOS token the tokenizer may have put in front of the keyword.
            if len(ids) > 1 and ids[0] == tokenizer.bos_token_id:
                ids = ids[1:]
            self.max_keyword_len = max(self.max_keyword_len, len(ids))
            self.keyword_ids.append(torch.tensor(ids))
        self.tokenizer = tokenizer
        self.start_len = input_ids.shape[1]

    def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True if row 0 of ``output_ids`` ends with / contains a keyword.

        NOTE(review): ``tail_len`` becomes zero or negative whenever
        ``output_ids`` is not longer than the prompt length recorded in
        ``__init__``; the slice ``[:, -tail_len:]`` then no longer means
        "the last few tokens". The token-id comparison below does not depend
        on it -- confirm what the model's ``generate`` passes here.
        """
        tail_len = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)

        # Keep the cached keyword ids on the same device as the outputs.
        device = output_ids.device
        self.keyword_ids = [ids.to(device) for ids in self.keyword_ids]

        # Exact match on token ids at the very end of the sequence.
        for ids in self.keyword_ids:
            tail = output_ids[0, -ids.shape[0]:]
            if torch.equal(tail, ids):
                return True

        # Fallback: substring match on the decoded tail.
        decoded = self.tokenizer.batch_decode(output_ids[:, -tail_len:], skip_special_tokens=True)[0]
        return any(word in decoded for word in self.keywords)

    def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True only when every row of the batch satisfies ``call_for_batch``."""
        # A list (not a generator) so that every row is evaluated, as before.
        per_row = [
            self.call_for_batch(output_ids[row].unsqueeze(0), scores)
            for row in range(output_ids.shape[0])
        ]
        return all(per_row)
@spaces.GPU
def bot_streaming(message, history):
    """Stream the model's answer to a multimodal chat message.

    Args:
        message: Mapping with a ``"text"`` entry and a ``"files"`` list; each
            file entry is read via ``["path"]``.
        history: Sequence of ``(user, assistant)`` pairs. A pair whose user
            part is a tuple is treated as an image upload, with the image
            path at index 0 of that tuple.

    Yields:
        The reply generated so far (cumulative text), once per streamed chunk.

    Raises:
        gr.Error: If neither the current message nor the history provides an
            image.
    """
    # Prefer the image attached to this message; otherwise fall back to the
    # most recent image found in the history. `image` must start as None so a
    # text-only conversation is detected instead of raising UnboundLocalError.
    image = None
    if message["files"]:
        image = message["files"][-1]["path"]
    else:
        for turn in history:
            if isinstance(turn[0], tuple):
                image = turn[0][0]

    if image is None:
        # gr.Error only reaches the UI when it is raised, not merely constructed.
        raise gr.Error("You need to upload an image for LLaVA to work.")

    messages = []
    if history:
        # NOTE(review): assumes history[0] is the image-upload turn and
        # history[1] is the first text turn -- verify against the Gradio
        # version in use (e.g. several files uploaded at once).
        messages.append({"role": "user", "content": f'<image>\n{history[1][0]}'})
        messages.append({"role": "assistant", "content": history[1][1]})
        for human, assistant in history[2:]:
            messages.append({"role": "user", "content": human})
            messages.append({"role": "assistant", "content": assistant})
        messages.append({"role": "user", "content": message['text']})
    else:
        messages.append({"role": "user", "content": f"<image>\n{message['text']}"})

    image = Image.open(image).convert("RGB")
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True)

    # Tokenize the text on both sides of the '<image>' marker and join the two
    # halves with the id -200 (presumably the model's image placeholder index
    # -- confirm against the model's remote code).
    text_chunks = [tokenizer(chunk).input_ids for chunk in text.split('<image>')]
    input_ids = torch.tensor(text_chunks[0] + [-200] + text_chunks[1], dtype=torch.long).unsqueeze(0)

    stop_str = '<|im_end|>'
    stopping_criteria = KeywordsStoppingCriteria([stop_str], tokenizer, input_ids)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    image_tensor = model.process_images([image], model.config).to(dtype=model.dtype)
    generation_kwargs = dict(
        input_ids=input_ids,
        images=image_tensor,
        streamer=streamer,
        max_new_tokens=100,
        stopping_criteria=[stopping_criteria],
    )

    # Run generation in a background thread so the streamer can be consumed here.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    # The streamer is created with skip_prompt=True, so everything it emits is
    # already the reply; slicing a prompt-length prefix off the buffer would
    # cut away the beginning of the answer.
    buffer = ""
    for new_text in streamer:
        buffer += new_text
        time.sleep(0.04)
        yield buffer
# Chat UI around bot_streaming. Title and description name the model this file
# actually loads ('qnguyen3/nanoLLaVA').
demo = gr.ChatInterface(
    fn=bot_streaming,
    title="nanoLLaVA",
    examples=[
        {"text": "What is on the flower?", "files": ["./bee.jpg"]},
        {"text": "How to make this pastry?", "files": ["./baklava.png"]},
    ],
    description="Try [nanoLLaVA](https://huggingface.co/qnguyen3/nanoLLaVA) in this demo. Upload an image and start chatting about it, or simply try one of the examples below. If you don't upload an image, you will receive an error.",
    stop_btn="Stop Generation",
    multimodal=True,
)
demo.launch(debug=True)