# art / app.py
# ProPerNounpYK's picture
# Update app.py
# e946bb1 verified
# raw
# history blame
# 1.23 kB
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from PIL import Image
import gradio as gr
# Pick the GPU when one is present, otherwise fall back to the CPU; the
# text-to-image weights are mapped onto this device when loaded.
# NOTE(review): torch.hub.load defaults to source="github" and forwards extra
# keyword arguments (here `map_location`) to the repo's `generator`
# entrypoint — confirm "ProPerNounpYK/texttoimage" is a hub-enabled repo
# whose hubconf.py `generator` accepts `map_location`.
_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
text_to_image_model = torch.hub.load(
    "ProPerNounpYK/texttoimage",
    "generator",
    map_location=_device,
)

# The chat model and its tokenizer are both loaded from the same checkpoint id.
_chat_checkpoint = "ProPerNounpYK/chat"
chat_model = AutoModelForSeq2SeqLM.from_pretrained(_chat_checkpoint)
chat_tokenizer = AutoTokenizer.from_pretrained(_chat_checkpoint)
# Create multimodal interface: a text input and an image input go in, and a
# text reply plus an image come out.
interface = gr.Interface(
    # The lambda defers the lookup of `generate_response`, which is defined
    # further down this file; referencing it directly here would raise
    # NameError at import time because its `def` has not executed yet.
    fn=lambda input_text, input_image: generate_response(input_text, input_image),
    # Input components, passed to the callback positionally as
    # (input_text, input_image).
    inputs=["text", "image"],
    # Output components; these line up with the (text, image) tuple that
    # generate_response returns.
    outputs=["text", "image"],
    title="Multimodal Conversational AI",
    description="Talk to me, and I'll respond with images!"
)
def generate_response(input_text, input_image):
    """Produce a chat reply and a generated image for one user turn.

    Args:
        input_text: The user's message from the interface's text input.
            It is used both as the chat prompt and as the text-to-image
            prompt.
        input_image: The value from the interface's image input, forwarded
            unchanged to ``text_to_image_model``.

    Returns:
        A ``(chat_response, generated_image)`` tuple, in the same order as
        the interface's ``["text", "image"]`` outputs.
    """
    # A seq2seq model cannot be called on a raw string, and its forward pass
    # yields logits rather than token ids. Tokenize the prompt onto the model's
    # device, let `generate` produce output ids, then decode the first (only)
    # sequence in the batch.
    # NOTE(review): reply length follows the checkpoint's generation_config
    # defaults; pass e.g. max_new_tokens to `generate` if replies are cut short.
    encoded = chat_tokenizer(input_text, return_tensors="pt").to(chat_model.device)
    output_ids = chat_model.generate(**encoded)
    chat_response = chat_tokenizer.decode(output_ids[0], skip_special_tokens=True)

    # NOTE(review): the call signature of the torch.hub `generator` entrypoint
    # is not visible here — confirm it accepts (text, image) and returns a
    # value Gradio's image output can display (e.g. PIL image / numpy array).
    generated_image = text_to_image_model(input_text, input_image)

    return chat_response, generated_image
# Launch the interface only when this file is executed as a script
# (`python app.py`), so importing the module elsewhere (for reuse or tests)
# does not start a web server as a side effect.
if __name__ == "__main__":
    interface.launch()