File size: 1,674 Bytes
53d6474
eb09c16
 
cad1126
 
 
87bd002
1674572
6f8418a
87bd002
6f8418a
 
 
87bd002
6f8418a
eb09c16
6f8418a
 
 
2471c01
6f8418a
87bd002
1674572
6f8418a
 
87bd002
6f8418a
 
87bd002
6f8418a
 
 
 
 
 
f718f04
1674572
f718f04
2471c01
1674572
 
f718f04
 
 
0930360
 
f718f04
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import gradio as gr
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
import torch
import theme

# Build the project's custom Gradio theme once; every interface below reuses
# it. NOTE: this rebinds the name `theme` from the imported module to the
# Theme instance, so the module object is no longer reachable under that name.
theme = theme.Theme()

# Cell 1: Image Classification Model
# Hugging Face image-classification pipeline for the hot-dog classifier.
# Created once at import time and shared by every call to predict_image.
image_pipeline = pipeline(
    model="julien-c/hotdog-not-hotdog",
    task="image-classification",
)

def predict_image(input_img):
    """Classify an image as hot dog / not hot dog.

    Args:
        input_img: The image delivered by the ``gr.Image`` input (configured
            with ``type="pil"``), or ``None`` when the user submits without
            selecting an image.

    Returns:
        A ``(image, scores)`` pair: the input image echoed back for the
        "Processed Image" output, and a ``{label: score}`` mapping for the
        ``gr.Label`` output. Both are ``None`` when no image was provided.
    """
    # Gradio passes None for an empty image input; running the pipeline on it
    # would raise, so clear both outputs instead of surfacing an error.
    if input_img is None:
        return None, None
    predictions = image_pipeline(input_img)
    return input_img, {p["label"]: p["score"] for p in predictions}

# Image tab: one image input (upload or webcam, handed to predict_image as a
# PIL image) and two outputs -- the echoed image and a label widget that
# shows the top two class scores.
image_gradio_app = gr.Interface(
    fn=predict_image,
    inputs=gr.Image(
        label="Select hot dog candidate",
        sources=['upload', 'webcam'],
        type="pil",
    ),
    outputs=[
        gr.Image(label="Processed Image"),
        gr.Label(label="Result", num_top_classes=2),
    ],
    title="Hot Dog? Or Not?",
    theme=theme,
)

# Cell 2: Chatbot Model
# Tokenizer and causal-LM weights are loaded from the same DialoGPT
# checkpoint, once at import time.
_DIALOGPT_CHECKPOINT = "microsoft/DialoGPT-medium"
tokenizer = AutoTokenizer.from_pretrained(_DIALOGPT_CHECKPOINT)
chatbot_model = AutoModelForCausalLM.from_pretrained(_DIALOGPT_CHECKPOINT)

def _flatten_chat_history(history):
    """Return the text of every prior turn in *history*, oldest first.

    Accepts Gradio's "tuples" history format (``[user, bot]`` pairs) and its
    "messages" format (``{"role": ..., "content": ...}`` dicts). Non-text
    entries (e.g. ``None`` placeholders) are skipped.
    """
    turns = []
    for entry in history or []:
        if isinstance(entry, dict):
            candidates = [entry.get("content")]
        else:
            candidates = list(entry)
        turns.extend(text for text in candidates if isinstance(text, str))
    return turns


def predict_chatbot(input, history=None):
    """Generate a DialoGPT reply to *input*, conditioned on the prior chat.

    Used as the ``fn`` of ``gr.ChatInterface``, which calls it with the new
    user message and the conversation so far and displays the returned
    string as the bot's reply.

    Args:
        input: The new user message.
        history: Prior conversation as Gradio passes it (``[user, bot]``
            pairs or role/content message dicts). ``None`` or empty starts a
            fresh conversation. It is never mutated.

    Returns:
        The model's reply as plain text.
    """
    # Generation is capped at max_length=1000 *total* tokens, so keep only the
    # most recent context; otherwise long chats leave no room for a reply.
    max_context_tokens = 800

    turns = _flatten_chat_history(history) + [input]
    # Every turn is terminated by the EOS token, as DialoGPT expects.
    prompt = "".join(turn + tokenizer.eos_token for turn in turns)
    bot_input_ids = tokenizer.encode(prompt, return_tensors="pt")
    bot_input_ids = bot_input_ids[:, -max_context_tokens:]

    output_ids = chatbot_model.generate(
        bot_input_ids,
        max_length=1000,
        pad_token_id=tokenizer.eos_token_id,
    )
    # generate() returns prompt + continuation; decode only the new tokens and
    # drop the trailing EOS marker.
    new_tokens = output_ids[0, bot_input_ids.shape[-1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)

# Chat tab: gr.ChatInterface provides its own textbox and chat display and
# invokes predict_chatbot for each user turn.
chatbot_gradio_app = gr.ChatInterface(
    fn=predict_chatbot,
    title="Greta",
    theme=theme,
)

# Combine both interfaces into a single tabbed app and start the server.
# The app is bound to the conventional name `demo` so it can be reused
# (e.g. by tests or other scripts) and found by `gradio` reload mode, which
# looks for a Blocks object named `demo` by default.
demo = gr.TabbedInterface(
    [image_gradio_app, chatbot_gradio_app],
    tab_names=["image", "chatbot"],
    theme=theme,
)
demo.launch()