# emoji-predictor / app.py
import gradio as gr
import torch
import os
from PIL import Image
from pathlib import Path
from transformers import CLIPProcessor, CLIPModel
checkpoint = "vincentclaes/emoji-predictor"
# The emoji images are stored as 0.png, 1.png, ... in ./emojis.
_, _, files = next(os.walk("./emojis"))
no_of_emojis = len(files)
emojis_as_images = [Image.open(f"emojis/{i}.png") for i in range(no_of_emojis)]
K = 4
processor = CLIPProcessor.from_pretrained(checkpoint)
model = CLIPModel.from_pretrained(checkpoint)
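# The processor turns the input text and the candidate emoji images into model tensors;
# the model then scores every (text, emoji image) pair so the emojis can be ranked per text.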
def concat_images(*images):
    """Generate a 2x2 composite of all supplied images.
    https://stackoverflow.com/a/71315656/1771155
    """
    # Use the widest image width as the cell width.
    width = max(image.width for image in images)
    # Use the tallest image height as the cell height.
    height = max(image.height for image in images)
    # The composite is a 2x2 grid, so it is twice the cell width and height.
    composite = Image.new('RGB', (2 * width, 2 * height))
    assert K == 4, "We expect 4 suggestions, other numbers won't work."
    for i, image in enumerate(images):
        if i == 0:
            composite.paste(image, (0, 0))
        elif i == 1:
            composite.paste(image, (width, 0))
        elif i == 2:
            composite.paste(image, (0, height))
        elif i == 3:
            composite.paste(image, (width, height))
    return composite
def get_emoji(text, model=model, processor=processor, emojis=emojis_as_images, K=4):
    # Preprocess the text and every candidate emoji image into model tensors.
    inputs = processor(text=text, images=emojis, return_tensors="pt", padding=True, truncation=True)
    outputs = model(**inputs)
    # Similarity score of the text against each emoji image.
    logits_per_text = outputs.logits_per_text
    # We take the softmax to get the label probabilities.
    probs = logits_per_text.softmax(dim=1)
    # Indices of the top K most probable emojis for the (single) input text.
    top_k_indices = torch.topk(probs[0], K).indices.tolist()
    images = [Image.open(f"emojis/{i}.png") for i in top_k_indices]
    return concat_images(*images)
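# Minimal usage sketch (assuming the emoji PNGs are present next to this script;
# the file name "suggestions.png" is just an example):
#   composite = get_emoji("I'm so happy for you!")  # 2x2 PIL image with the top-4 emojis
#   composite.save("suggestions.png")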
text = gr.inputs.Textbox(placeholder="Enter some text and we will try to predict an emoji...")
title = "Predicting an Emoji"
description = """You provide a sentence and our few-shot fine-tuned CLIP model will suggest 4 of the following emojis:
\n❀️ 😍 πŸ˜‚ πŸ’• πŸ”₯ 😊 😎 ✨ πŸ’™ 😘 πŸ“· πŸ‡ΊπŸ‡Έ β˜€ πŸ’œ πŸ˜‰ πŸ’― 😁 πŸŽ„ πŸ“Έ 😜 ☹️ 😭 πŸ˜” 😑 πŸ’’ 😀 😳 πŸ™ƒ 😩 😠 πŸ™ˆ πŸ™„\n
"""
article = """
\n
++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
\n
#### Let's connect on LinkedIn: https://www.linkedin.com/in/vincent-claes-0b346337/
\n
# Context
I fine-tuned OpenAI's CLIP model on both text (tweets) and images of emojis!\n
The current model you can play with is fine-tuned on 15 samples per emoji.
- model: https://huggingface.co/vincentclaes/emoji-predictor \n
- dataset: https://huggingface.co/datasets/vincentclaes/emoji-predictor \n
- profile: https://huggingface.co/vincentclaes \n
# Precision
Below is a table with the precision of predictions and suggestions
for a range of sample counts per emoji that CLIP was fine-tuned on.
### Prediction vs. Suggestion
- The column "Prediction" indicates the precision for predicting the right emoji.
- Since there can be some confusion about the right emoji for a tweet,
I also present 4 suggestions. If 1 of the 4 suggestions matches the label,
I count it as a valid prediction. See the column "Suggestion" and the sketch after this list.
- Randomly predicting an emoji would have a precision of 1/32, or about 0.03.
- Randomly suggesting 4 emojis would have a precision of 4/32, or 0.125.
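For clarity, here is a minimal sketch of how the "Suggestion" precision is computed (illustrative only; `labels` and `ranked_emojis` are hypothetical names, not the actual evaluation code):
```python
# a tweet counts as a hit if its labelled emoji is among the top-4 suggestions
hits = sum(label in ranked[:4] for label, ranked in zip(labels, ranked_emojis))
suggestion_precision = hits / len(labels)
```
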
| Samples | Prediction | Suggestion |
|--------- |------------ |------------ |
| 0 | 0.13 | 0.33 |
| 1 | 0.11 | 0.30 |
| 5 | 0.14 | 0.38 |
| 10 | 0.20 | 0.45 |
| 15 | 0.22 | 0.51 |
| 20 | 0.19 | 0.49 |
| 25 | 0.24 | 0.54 |
| 50 | 0.23 | 0.53 |
| 100 | 0.25 | 0.57 |
| 250 | 0.29 | 0.62 |
| 500 | 0.29 | 0.63 |
"""
examples = [
"I'm so happy for you!",
"I'm not feeling great today.",
"This makes me angry!",
"Can I follow you?",
"I'm so bored right now ...",
]
gr.Interface(fn=get_emoji, inputs=text, outputs=gr.Image(shape=(72,72)),
examples=examples, title=title, description=description,
article=article).launch()