ProPerNounpYK commited on
Commit
da6626c
·
verified ·
1 Parent(s): 0d7e301

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -25
app.py CHANGED
@@ -1,34 +1,18 @@
1
- import torch
2
- from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
3
- from PIL import Image
4
  import gradio as gr
 
5
 
6
- # Load text-to-image model
7
- text_to_image_model = torch.hub.load("ProPerNounpYK/texttoimage", "generator", map_location=torch.device("cuda" if torch.cuda.is_available() else "cpu"))
8
 
9
- # Load chat model
10
- chat_model = AutoModelForSeq2SeqLM.from_pretrained("ProPerNounpYK/chat")
11
- chat_tokenizer = AutoTokenizer.from_pretrained("ProPerNounpYK/chat")
12
 
13
- # Create multimodal interface
14
  interface = gr.Interface(
15
- fn=lambda input_text, input_image: generate_response(input_text, input_image),
16
- inputs=["text", "image"],
17
  outputs=["text", "image"],
18
- title="Multimodal Conversational AI",
19
- description="Talk to me, and I'll respond with images!"
20
  )
21
 
22
- def generate_response(input_text, input_image):
23
- # Process input text using chat model
24
- chat_output = chat_model(input_text)
25
- chat_response = chat_tokenizer.decode(chat_output, skip_special_tokens=True)
26
-
27
- # Process input image using text-to-image model
28
- generated_image = text_to_image_model(input_text, input_image)
29
-
30
- # Return response as a tuple of text and image
31
- return chat_response, generated_image
32
-
33
- # Launch the interface
34
  interface.launch()
 
 
 
 
1
  import gradio as gr
2
+ from transformers import pipeline
3
 
4
+ # Text-to-Image model
5
+ text_to_image = pipeline("text-to-image", model="ProPerNounpYK/texttoimage")
6
 
7
+ # Chat model
8
+ chat = pipeline("conversational", model="ProPerNounpYK/chat")
 
9
 
 
10
  interface = gr.Interface(
11
+ fn=lambda input: chat(input, text_to_image(input)),
12
+ inputs="text",
13
  outputs=["text", "image"],
14
+ title="Text-to-Image Chat",
15
+ description="Type something and get a response with an image!"
16
  )
17
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  interface.launch()