"""Launch a Gradio web demo for Tiger (Generative Multimodal Dialogue Model).

The script loads three components and wraps them in a ``Chat`` object served
through a Gradio Blocks UI:

* a T5-based intent prediction classifier,
* a GPT-2 (DialoGPT) textual dialogue response generator,
* a Stable Diffusion text-to-image pipeline.
"""
import argparse
import gradio as gr
import torch
from model import IntentPredictModel
from transformers import (T5Tokenizer, GPT2Tokenizer, GPT2Config, GPT2LMHeadModel)
from diffusers import StableDiffusionPipeline
from chatbot import Chat

# Network interface the demo server binds to by default. This address is
# machine-specific; override it with --server_name (e.g. "0.0.0.0" to listen
# on all interfaces, or "127.0.0.1" for local-only access).
DEFAULT_SERVER_NAME = "219.216.64.177"

# Markdown snippets rendered, in order, at the top of the demo page.
TITLE = "\n\nDemo of Tiger\n\n"
DESCRIPTION1 = "\n\nThis is the demo of Tiger (Generative Multimodal Dialogue Model).\n\n"
DESCRIPTION2 = "\n\nInput text start chatting!\n\n"
DESCRIPTION_INPUT = "\n\nInput: text\n\n"
DESCRIPTION_OUTPUT = "\n\nOutput: text / image\n\n"


def _load_intent_predictor(args):
    """Load the intent prediction classifier and its tokenizer.

    Args:
        args: Parsed CLI namespace. Reads ``intent_predict_model_name``,
            ``intent_predict_model_weights_path`` and ``device``.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    print("Loading Intent Prediction Classifier...")
    # Truncate from the left so the end of an over-long input is kept
    # (presumably the most recent dialogue turns; confirm against Chat).
    tokenizer = T5Tokenizer.from_pretrained(args.intent_predict_model_name, truncation_side='left')
    tokenizer.add_special_tokens({'sep_token': '[SEP]'})

    model = IntentPredictModel(pretrained_model_name_or_path=args.intent_predict_model_name, num_classes=2)
    # NOTE: torch.load unpickles the file; only point this at trusted weight files.
    state_dict = torch.load(args.intent_predict_model_weights_path, map_location=args.device)
    model.load_state_dict(state_dict)
    print("Intent Prediction Classifier loading completed.")
    return model, tokenizer


def _load_text_dialog_model(args):
    """Load the GPT-2 dialogue response generator and its tokenizer.

    Args:
        args: Parsed CLI namespace. Reads ``text_dialog_model_name`` (tokenizer
            and base config) and ``text_dialog_model_weights_path`` (weights).

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    print("Loading Textual Dialogue Response Generator...")
    tokenizer = GPT2Tokenizer.from_pretrained(args.text_dialog_model_name, truncation_side='left')
    tokenizer.add_tokens(['[UTT]', '[DST]'])
    print(len(tokenizer))

    config = GPT2Config.from_pretrained(args.text_dialog_model_name)
    # Enlarge the vocabulary so the model is built with an embedding row for
    # every tokenizer id, including the tokens added above.
    if len(tokenizer) > config.vocab_size:
        config.vocab_size = len(tokenizer)

    model = GPT2LMHeadModel.from_pretrained(args.text_dialog_model_weights_path, config=config)
    print("Textual Dialogue Response Generator loading completed.")
    return model, tokenizer


def _load_text2image_model(args):
    """Load the Stable Diffusion pipeline from ``args.text2image_model_weights_path``."""
    print("Loading Text-to-Image Translator...")
    pipeline = StableDiffusionPipeline.from_pretrained(args.text2image_model_weights_path, torch_dtype=torch.float32)
    print("Text-to-Image Translator loading completed.")
    return pipeline


def _build_demo(chat):
    """Build the Gradio Blocks UI whose text box is wired to ``chat.respond``.

    Args:
        chat: A ``Chat`` instance. Its ``respond`` method is called with
            ``(text, num_beams, seed, chatbot_history, state)`` and must return
            ``(text, chatbot_history, state)``.

    Returns:
        The (not yet launched) ``gr.Blocks`` app.
    """
    with gr.Blocks() as demo:
        for markdown_text in (TITLE, DESCRIPTION1, DESCRIPTION2, DESCRIPTION_INPUT, DESCRIPTION_OUTPUT):
            gr.Markdown(markdown_text)

        with gr.Row():
            # Left: generation controls.
            with gr.Column(scale=0.33):
                num_beams = gr.Slider(
                    minimum=1,
                    maximum=10,
                    value=5,
                    step=1,
                    interactive=True,
                    label="beam search numbers",
                )
                text2image_seed = gr.Slider(
                    minimum=1,
                    maximum=100,
                    value=42,
                    step=1,
                    interactive=True,
                    label="seed for text-to-image",
                )
                clear = gr.Button("Restart (Clear dialogue history)")
            # Right: the conversation itself.
            with gr.Column():
                chat_state = gr.State()
                chatbot = gr.Chatbot(label='Tiger')
                text_input = gr.Textbox(label='User', placeholder='Please input the text.')

        text_input.submit(chat.respond, [text_input, num_beams, text2image_seed, chatbot, chat_state], [text_input, chatbot, chat_state])
        # Restart must reset BOTH the visible transcript and the hidden session
        # state fed back into chat.respond. Clearing only the chatbot left the
        # old dialogue in chat_state. None is gr.State()'s initial value, i.e.
        # exactly what chat.respond receives on the first turn.
        clear.click(lambda: (None, None), None, [chatbot, chat_state], queue=False)
    return demo


def main(args):
    """Load all models, build the UI and launch the demo server.

    Args:
        args: Parsed CLI namespace (see ``_parse_args``). ``server_name`` and
            ``server_port`` are optional. If absent, the server uses
            ``DEFAULT_SERVER_NAME`` and Gradio's default port.
    """
    intent_predict_model, intent_predict_tokenizer = _load_intent_predictor(args)
    text_dialog_model, text_dialog_tokenizer = _load_text_dialog_model(args)
    text2image_model = _load_text2image_model(args)

    # Device placement of the models is left to Chat, which receives args.device.
    # NOTE(review): confirm Chat moves the models to that device.
    chat = Chat(intent_predict_model, intent_predict_tokenizer, text_dialog_model, text_dialog_tokenizer, text2image_model, args.device)

    demo = _build_demo(chat)
    # Blocks.queue() is the supported replacement for launch(enable_queue=True).
    demo.queue()
    demo.launch(
        share=True,
        server_name=getattr(args, "server_name", DEFAULT_SERVER_NAME),
        server_port=getattr(args, "server_port", None),
    )


def _parse_args():
    """Parse command-line options for the demo."""
    parser = argparse.ArgumentParser(description="Launch the Tiger multimodal dialogue demo.")
    parser.add_argument('--intent_predict_model_name', type=str, default="t5-base")
    parser.add_argument('--intent_predict_model_weights_path', type=str, default="model_weights/Tiger_t5_base_encoder.pth")
    parser.add_argument('--text_dialog_model_name', type=str, default="microsoft/DialoGPT-medium")
    parser.add_argument('--text_dialog_model_weights_path', type=str, default="model_weights/Tiger_DialoGPT_medium.pth")
    parser.add_argument('--text2image_model_weights_path', type=str, default="model_weights/stable-diffusion-2-1-realistic")
    parser.add_argument('--device', default="cuda:6",
                        help='torch device string, e.g. "cuda:0" or "cpu" (default is machine-specific).')
    parser.add_argument('--server_name', type=str, default=DEFAULT_SERVER_NAME,
                        help='Interface for the Gradio server to bind to, e.g. "0.0.0.0".')
    parser.add_argument('--server_port', type=int, default=None,
                        help="Port for the Gradio server (default: Gradio's own default).")
    return parser.parse_args()


if __name__ == "__main__":
    main(_parse_args())