TroL / app.py
BK-Lee's picture
v1
3fb84e5
raw
history blame
6.61 kB
# A100 Zero GPU
import spaces
# TroL Package
import torch
from PIL import Image
from utils.utils import *
import torch.nn.functional as F
from trol.load_trol import load_trol
from torchvision.transforms.functional import pil_to_tensor
# Gradio Package
import time
import gradio as gr
from threading import Thread
from accelerate import Accelerator
from transformers import TextIteratorStreamer
from torchvision.transforms.functional import pil_to_tensor
# flash attention
import os
import subprocess

# Install flash-attn at startup. FLASH_ATTENTION_SKIP_CUDA_BUILD asks its build
# to skip compiling CUDA kernels. The override is merged into the current
# environment: passing only the override dict would hand the shell an
# environment without PATH/HOME/virtualenv variables, so `pip` might not
# resolve to the pip that belongs to this interpreter.
subprocess.run(
    'pip install flash-attn --no-build-isolation',
    env={**os.environ, 'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
    shell=True,
)

# accel: provides the target device used for parameters and input tensors
accel = Accelerator()

# User prompt (demo defaults; not referenced by the handlers visible in this file)
prompt_type = "with_image"  # Select one option "text_only", "with_image"
img_path = 'figures/demo.png'
question = "What is the troll doing? Provide the detail in the image and imagine what the event happens."

# loading models: all three sizes are loaded once at import time so that a
# request only has to pick one of them
model_1_8, tokenizer_1_8 = load_trol(link='TroL-1.8B')
model_3_8, tokenizer_3_8 = load_trol(link='TroL-3.8B')
model_7, tokenizer_7 = load_trol(link='TroL-7B')
def threading_function(inputs, image_token_number, streamer, device, model, tokenizer, temperature, new_max_token, top_p):
    """Prepare model inputs and run ``model.generate``; intended as a thread target.

    Args:
        inputs: Request payload forwarded to ``model.eval_process``.
        image_token_number: Forwarded to ``eval_process`` as ``img_token_number``.
        streamer: Streamer handed to ``generate``; may be ``None``.
        device: Device forwarded to ``eval_process``.
        model: Object exposing ``eval_process`` and ``generate``.
        tokenizer: Tokenizer forwarded to ``eval_process``.
        temperature: Sampling temperature. ``None`` or a positive value selects
            sampling; a value <= 0 selects greedy decoding, because sampling
            with a non-positive temperature is rejected by the generation code.
        new_max_token: Passed to ``generate`` as ``max_new_tokens``.
        top_p: Nucleus-sampling threshold; only passed when sampling.

    Returns:
        Whatever ``model.generate`` returns.

    Raises:
        Exception: Re-raises any error from ``eval_process`` or ``generate``
            after closing the streamer.
    """
    try:
        # propagation: eval_process returns the keyword arguments for generate
        generation_kwargs = model.eval_process(inputs=inputs,
                                               data='demo',
                                               tokenizer=tokenizer,
                                               device=device,
                                               img_token_number=image_token_number)
        generation_kwargs.update({
            'streamer': streamer,
            'max_new_tokens': new_max_token,
            'use_cache': True,
        })
        if temperature is None or temperature > 0:
            generation_kwargs.update({
                'do_sample': True,
                'top_p': top_p,
                'temperature': temperature,
            })
        else:
            # Temperature 0 (allowed by the UI slider) means "deterministic":
            # decode greedily and omit the sampling-only arguments.
            generation_kwargs['do_sample'] = False
        return model.generate(**generation_kwargs)
    except Exception:
        # This runs on a worker thread while another thread iterates the
        # streamer. Without an end-of-stream signal that consumer would wait
        # forever for text that will never arrive.
        end_stream = getattr(streamer, 'end', None)
        if callable(end_stream):
            end_stream()
        raise
def _select_model(link):
    """Return ``(model, tokenizer, repo_id)`` for the model size named in *link*."""
    # Substring match, tested in this order.
    candidates = (
        ("1.8B", model_1_8, tokenizer_1_8, "BK-Lee/TroL-1.8B"),
        ("3.8B", model_3_8, tokenizer_3_8, "BK-Lee/TroL-3.8B"),
        ("7B", model_7, tokenizer_7, "BK-Lee/TroL-7B"),
    )
    for tag, model, tokenizer, path in candidates:
        if tag in link:
            return model, tokenizer, path
    raise ValueError(f"Unknown model size: {link!r}")


def _load_trol_gating(model, path):
    """Initialize the TroL gating module of *model* and load its weights from the hub."""
    from huggingface_hub import hf_hub_download

    # Fetch and deserialize once; both attribute layouts below use the same file.
    state_dict = torch.load(hf_hub_download(repo_id=path, filename="trol_gating.pt"))
    try:
        model.model.initialize_trol_gating()
        model.model.trol_gating.load_state_dict(state_dict)
    except Exception:
        # Fallback layout: the gating module sits under ``language_model.model``.
        model.language_model.model.initialize_trol_gating()
        model.language_model.model.trol_gating.load_state_dict(state_dict)


def _build_inputs(message, link):
    """Turn a chat message into ``(inputs, image_token_number)`` for generation."""
    files = message['files']
    if len(files) > 1:
        raise ValueError("No way!")
    if len(files) == 0:
        return [{'question': message['text']}], None

    # Image Load
    image = pil_to_tensor(Image.open(files[0]).convert("RGB"))
    image_token_number = None
    if "3.8B" not in link:
        # Every size except 3.8B gets a 490x490 image and an explicit token count.
        image_token_number = 1225
        image = F.interpolate(image.unsqueeze(0), size=(490, 490), mode='bicubic').squeeze(0)
    return [{'image': image.to(accel.device), 'question': message['text']}], image_token_number


@spaces.GPU
def bot_streaming(message, history, link, temperature, new_max_token, top_p):
    """Gradio chat handler: answer one (optionally single-image) message.

    Args:
        message: Dict with ``'text'`` and ``'files'`` keys.
        history: Previous turns; unused, each message is answered on its own.
        link: Model size label; must contain "1.8B", "3.8B" or "7B".
        temperature: Sampling temperature forwarded to generation.
        new_max_token: Maximum number of new tokens to generate.
        top_p: Nucleus-sampling threshold forwarded to generation.

    Yields:
        Growing prefixes of the response, one more character each time, to
        produce a typing effect in the UI.

    Raises:
        ValueError: If *link* names none of the known model sizes.
    """
    # model selection
    model, tokenizer, path = _select_model(link)

    # trol gating load
    _load_trol_gating(model, path)

    for param in model.parameters():
        # X -> float16 conversion. bfloat16 is listed on purpose: the earlier
        # substring test on the dtype name ('float16' in 'torch.bfloat16')
        # matched it as well. float16 parameters need no conversion.
        if param.dtype in (torch.float32, torch.bfloat16):
            param.data = param.data.to(torch.float16)
        # cpu -> gpu
        if not param.is_cuda:
            param.data = param.data.to(accel.device)

    try:
        # prompt type -> input prompt
        inputs, image_token_number = _build_inputs(message, link)

        # Text Generation
        with torch.inference_mode():
            streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)

            # Threading generation: generate() feeds the streamer from a worker
            # thread while this thread drains it.
            thread = Thread(target=threading_function, kwargs=dict(inputs=inputs,
                                                                   image_token_number=image_token_number,
                                                                   streamer=streamer,
                                                                   model=model,
                                                                   tokenizer=tokenizer,
                                                                   device=accel.device,
                                                                   temperature=temperature,
                                                                   new_max_token=new_max_token,
                                                                   top_p=top_p))
            thread.start()

            # generated text
            generated_text = "".join(streamer)
            # Make sure generate() has fully returned before moving on.
            thread.join()

        # Text decoding
        response = output_filtering(generated_text, model)
    except Exception as err:
        # Keep the user-facing message generic, but record the real cause in
        # the private log instead of discarding it.
        print(f'Generation failed: {err!r}')
        response = "There may be unsupported format: ex) pdf, video, sound. Only supported is a single image in this version."

    # private log print
    text = message['text']
    files = message['files']
    print('-----------------------------')
    print(f'Link: {link}')
    print(f'Text: {text}')
    print(f'MM Files: {files}')
    print(f'Response: {response}')
    print('-----------------------------\n')

    # Replay the finished response character by character.
    buffer = ""
    for character in response:
        buffer += character
        time.sleep(0.012)
        yield buffer
# Extra controls, in the order bot_streaming receives them after
# (message, history): link, temperature, new_max_token, top_p.
_extra_inputs = [
    gr.Radio(["1.8B", "3.8B", "7B"], label="Size", info="Select one model size", value="7B"),
    gr.Slider(minimum=0, maximum=1, value=0.9, label="temperature"),
    gr.Slider(minimum=1, maximum=1024, value=128, label="new_max_token"),
    gr.Slider(minimum=0, maximum=1, value=0.95, label="top_p"),
]

# Text shown under the title of the chat page.
_description = "".join((
    "TroL is efficient 1.8B, 3.8B, and 7B size Large Language and Vision Models built on new propagation strategy. ",
    "Its inference speed highly depends on assinging non-scheduled GPU. (Therefore, once all GPUs are busy, then inference may be taken in infinity) ",
    "Note that, we don't support history-based conversation referring to previous dialogue",
))

# Multimodal chat UI backed by bot_streaming.
demo = gr.ChatInterface(
    fn=bot_streaming,
    additional_inputs=_extra_inputs,
    additional_inputs_accordion="Generation Hyperparameters",
    theme=gr.themes.Soft(),
    title="TroL",
    description=_description,
    stop_btn="Stop Generation",
    multimodal=True,
)

# Start the web server as soon as the module is executed.
demo.launch()