Feelings_to_Emoji / emoji_processor.py
Dan Mo
Add script to generate and save embeddings for models
cfb0d15
"""
Core emoji processing logic for the Emoji Mashup application.
"""
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
import requests
from PIL import Image
from io import BytesIO
import os
from config import CONFIG, EMBEDDING_MODELS
from utils import (logger, kitchen_txt_to_dict,
save_embeddings_to_pickle, load_embeddings_from_pickle,
get_embeddings_pickle_path)
class EmojiProcessor:
    """Match a sentence to an emotion emoji and an event emoji, then fetch
    the Emoji Kitchen mashup image for that pair.

    Emoji descriptions (loaded via ``kitchen_txt_to_dict``) are embedded with
    a SentenceTransformer model; a sentence is matched to the emoji whose
    description embedding has the highest cosine similarity to the sentence's
    embedding.
    """

    def __init__(self, model_name=None, model_key=None, use_cached_embeddings=True):
        """Initialize the emoji processor with the specified model.

        Args:
            model_name: Direct name of the sentence transformer model to use.
            model_key: Key from EMBEDDING_MODELS to use (takes precedence over
                model_name when the key is known).
            use_cached_embeddings: Whether to use cached embeddings from pickle
                files.
        """
        # Resolution order: a known model_key, then model_name, then the
        # configured default.
        if model_key and model_key in EMBEDDING_MODELS:
            model_name = EMBEDDING_MODELS[model_key]['id']
        else:
            if model_key:
                # Don't silently ignore a typo'd key.
                logger.warning(f"Unknown model key: {model_key}; ignoring it")
            if not model_name:
                model_name = CONFIG["model_name"]
        logger.info(f"Loading model: {model_name}")
        self.model = SentenceTransformer(model_name)
        self.current_model_name = model_name
        self.emotion_dict = {}
        self.event_dict = {}
        self.emotion_embeddings = {}
        self.event_embeddings = {}
        self.use_cached_embeddings = use_cached_embeddings

    def load_emoji_dictionaries(self, emotion_file=CONFIG["emotion_file"], item_file=CONFIG["item_file"]):
        """Load emoji dictionaries from text files and prepare their embeddings.

        Args:
            emotion_file: Path to the emotion emoji file.
            item_file: Path to the item emoji file.
        """
        logger.info("Loading emoji dictionaries")
        self.emotion_dict = kitchen_txt_to_dict(emotion_file)
        self.event_dict = kitchen_txt_to_dict(item_file)
        self._load_or_compute_embeddings()

    def _encode_dictionary(self, dictionary):
        """Encode each description in ``dictionary`` with the current model.

        Returns a dict mapping every emoji key to its description embedding.
        """
        return {emoji: self.model.encode(desc) for emoji, desc in dictionary.items()}

    def _load_or_compute_category(self, category, dictionary):
        """Return embeddings for one dictionary, preferring the pickle cache.

        The cache is used only if it contains every emoji in ``dictionary``;
        otherwise all embeddings are recomputed and written back to the cache.

        Args:
            category: Cache label passed to get_embeddings_pickle_path
                ("emotion" or "event").
            dictionary: Mapping of emoji -> description text.

        Returns:
            Dict mapping each emoji in ``dictionary`` to its embedding.
        """
        pickle_path = get_embeddings_pickle_path(self.current_model_name, category)
        cached = load_embeddings_from_pickle(pickle_path)
        if cached is not None:
            for emoji in dictionary:
                if emoji not in cached:
                    logger.info(f"Cached {category} embeddings missing emoji: {emoji}, will recompute")
                    break
            else:
                # Keep only emojis still present in the dictionary, so stale
                # cache entries (emojis since removed) can never be matched.
                return {emoji: cached[emoji] for emoji in dictionary}
        logger.info(f"Computing {category} embeddings for model: {self.current_model_name}")
        embeddings = self._encode_dictionary(dictionary)
        # Save for future use
        save_embeddings_to_pickle(embeddings, pickle_path)
        return embeddings

    def _load_or_compute_embeddings(self):
        """Populate emotion and event embeddings for the current model.

        Uses the per-model pickle cache when ``use_cached_embeddings`` is set,
        otherwise always recomputes without touching the cache.
        """
        if self.use_cached_embeddings:
            self.emotion_embeddings = self._load_or_compute_category("emotion", self.emotion_dict)
            self.event_embeddings = self._load_or_compute_category("event", self.event_dict)
        else:
            logger.info("Computing embeddings for emoji dictionaries (no caching)")
            self.emotion_embeddings = self._encode_dictionary(self.emotion_dict)
            self.event_embeddings = self._encode_dictionary(self.event_dict)

    def switch_model(self, model_key):
        """Switch to a different embedding model.

        On failure the previous model and embeddings are restored, so the
        processor is never left with a model whose embeddings don't match.

        Args:
            model_key: Key from EMBEDDING_MODELS to use.

        Returns:
            True if the model was switched (or already active), False otherwise.
        """
        if model_key not in EMBEDDING_MODELS:
            logger.error(f"Unknown model key: {model_key}")
            return False
        model_name = EMBEDDING_MODELS[model_key]['id']
        if model_name == self.current_model_name:
            logger.info(f"Model {model_key} is already loaded")
            return True
        # Snapshot state so a failure part-way through (e.g. while
        # re-embedding) can be rolled back.
        previous_state = (self.model, self.current_model_name,
                          self.emotion_embeddings, self.event_embeddings)
        try:
            logger.info(f"Switching to model: {model_name}")
            self.model = SentenceTransformer(model_name)
            self.current_model_name = model_name
            # Re-embed if ANY dictionary is loaded: vectors from the old model
            # are not comparable with (and may differ in size from) the new one's.
            if self.emotion_dict or self.event_dict:
                self._load_or_compute_embeddings()
            return True
        except Exception as e:
            (self.model, self.current_model_name,
             self.emotion_embeddings, self.event_embeddings) = previous_state
            logger.error(f"Error switching model: {e}")
            return False

    def find_top_emojis(self, embedding, emoji_embeddings, top_n=1):
        """Find top matching emojis based on cosine similarity.

        Args:
            embedding: Sentence embedding to compare.
            emoji_embeddings: Dictionary of emoji -> embedding.
            top_n: Number of top emojis to return.

        Returns:
            List of up to ``top_n`` emojis, most similar first; ties keep the
            dictionary's iteration order. Empty if ``emoji_embeddings`` is empty.
        """
        if not emoji_embeddings:
            return []
        emojis = list(emoji_embeddings)
        # One vectorized call scores every candidate at once instead of one
        # cosine_similarity call per emoji.
        scores = cosine_similarity([embedding], list(emoji_embeddings.values()))[0]
        # sorted() is stable (also with reverse=True), preserving tie order.
        ranked = sorted(range(len(emojis)), key=lambda i: scores[i], reverse=True)
        return [emojis[i] for i in ranked[:top_n]]

    def get_emoji_mashup_url(self, emoji1, emoji2, size=CONFIG["default_size"]):
        """Generate URL for emoji mashup.

        Args:
            emoji1: First emoji character.
            emoji2: Second emoji character.
            size: Image size in pixels, appended as a ``size`` query parameter.

        Returns:
            URL for the emoji mashup.
        """
        return f"{CONFIG['emoji_kitchen_url'].format(emoji1=emoji1, emoji2=emoji2)}?size={size}"

    def fetch_mashup_image(self, url, timeout=10):
        """Fetch emoji mashup image from URL.

        Args:
            url: URL of the emoji mashup image.
            timeout: Seconds to wait for the server; without it ``requests``
                can block indefinitely.

        Returns:
            PIL Image object, or None if the fetch or decode failed.
        """
        try:
            response = requests.get(url, timeout=timeout)
            content_type = response.headers.get("Content-Type", "")
            if response.status_code == 200 and "image" in content_type:
                image = Image.open(BytesIO(response.content))
                # Image.open is lazy; decode now so corrupt data is reported
                # here instead of failing later in the caller.
                image.load()
                return image
            logger.warning(f"Failed to fetch image: Status code {response.status_code}, "
                           f"Content-Type {content_type!r}")
            return None
        except Exception as e:
            logger.error(f"Error fetching image: {e}")
            return None

    def sentence_to_emojis(self, sentence):
        """Process sentence to find matching emojis and generate mashup.

        Args:
            sentence: User input text.

        Returns:
            Tuple of (emotion emoji, event emoji, mashup image). Returns
            ("❓", "❓", None) for empty input and ("❌", "❌", None) on error.
        """
        if not sentence or not sentence.strip():
            return "❓", "❓", None
        if not self.emotion_embeddings or not self.event_embeddings:
            # Otherwise this surfaces as an opaque IndexError below.
            logger.error("Emoji embeddings not loaded; call load_emoji_dictionaries() first")
            return "❌", "❌", None
        try:
            embedding = self.model.encode(sentence)
            top_emotion = self.find_top_emojis(embedding, self.emotion_embeddings, top_n=1)[0]
            top_event = self.find_top_emojis(embedding, self.event_embeddings, top_n=1)[0]
            mashup_url = self.get_emoji_mashup_url(top_emotion, top_event)
            mashup_image = self.fetch_mashup_image(mashup_url)
            return top_emotion, top_event, mashup_image
        except Exception as e:
            logger.error(f"Error processing sentence: {e}")
            return "❌", "❌", None