""" | |
Core emoji processing logic for the Emoji Mashup application. | |
""" | |
from io import BytesIO

import requests
from PIL import Image
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

from config import CONFIG, EMBEDDING_MODELS
from utils import (logger, kitchen_txt_to_dict,
                   save_embeddings_to_pickle, load_embeddings_from_pickle,
                   get_embeddings_pickle_path)
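
# NOTE: this module assumes config.py provides the CONFIG keys
# "model_name", "emotion_file", "item_file", "emoji_kitchen_url", and
# "default_size", and that EMBEDDING_MODELS maps short keys to
# {"id": <sentence-transformers model name>} entries (all of these are
# referenced below).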


class EmojiProcessor:
    def __init__(self, model_name=None, model_key=None, use_cached_embeddings=True):
        """Initialize the emoji processor with the specified model.

        Args:
            model_name: Direct name of the sentence transformer model to use
            model_key: Key from EMBEDDING_MODELS to use (takes precedence over model_name)
            use_cached_embeddings: Whether to use cached embeddings from pickle files
        """
        # Get model name from the key if provided
        if model_key and model_key in EMBEDDING_MODELS:
            model_name = EMBEDDING_MODELS[model_key]['id']
        elif not model_name:
            model_name = CONFIG["model_name"]

        logger.info(f"Loading model: {model_name}")
        self.model = SentenceTransformer(model_name)
        self.current_model_name = model_name
        self.emotion_dict = {}
        self.event_dict = {}
        self.emotion_embeddings = {}
        self.event_embeddings = {}
        self.use_cached_embeddings = use_cached_embeddings

    def load_emoji_dictionaries(self, emotion_file=CONFIG["emotion_file"], item_file=CONFIG["item_file"]):
        """Load emoji dictionaries from text files.

        Args:
            emotion_file: Path to the emotion emoji file
            item_file: Path to the item emoji file
        """
        logger.info("Loading emoji dictionaries")
        self.emotion_dict = kitchen_txt_to_dict(emotion_file)
        self.event_dict = kitchen_txt_to_dict(item_file)

        # Load or compute embeddings
        self._load_or_compute_embeddings()
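
    # The loaders above are expected to yield plain {emoji: description}
    # dicts, e.g. {"😀": "grinning face", ...} (illustrative entry; the
    # real contents and file format belong to kitchen_txt_to_dict in utils.py).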

    def _load_or_compute_embeddings(self):
        """Load embeddings from pickle files if available, otherwise compute them."""
        if self.use_cached_embeddings:
            # Try to load emotion embeddings
            emotion_pickle_path = get_embeddings_pickle_path(self.current_model_name, "emotion")
            loaded_emotion_embeddings = load_embeddings_from_pickle(emotion_pickle_path)

            # Try to load event embeddings
            event_pickle_path = get_embeddings_pickle_path(self.current_model_name, "event")
            loaded_event_embeddings = load_embeddings_from_pickle(event_pickle_path)

            # Check if we need to compute any embeddings
            compute_emotion = loaded_emotion_embeddings is None
            compute_event = loaded_event_embeddings is None

            if not compute_emotion:
                # Verify all emoji keys are present in the loaded embeddings
                for emoji in self.emotion_dict.keys():
                    if emoji not in loaded_emotion_embeddings:
                        logger.info(f"Cached emotion embeddings missing emoji: {emoji}, will recompute")
                        compute_emotion = True
                        break
                if not compute_emotion:
                    self.emotion_embeddings = loaded_emotion_embeddings

            if not compute_event:
                # Verify all emoji keys are present in the loaded embeddings
                for emoji in self.event_dict.keys():
                    if emoji not in loaded_event_embeddings:
                        logger.info(f"Cached event embeddings missing emoji: {emoji}, will recompute")
                        compute_event = True
                        break
                if not compute_event:
                    self.event_embeddings = loaded_event_embeddings

            # Compute any missing embeddings
            if compute_emotion:
                logger.info(f"Computing emotion embeddings for model: {self.current_model_name}")
                self.emotion_embeddings = {emoji: self.model.encode(desc) for emoji, desc in self.emotion_dict.items()}
                # Save for future use
                save_embeddings_to_pickle(self.emotion_embeddings, emotion_pickle_path)

            if compute_event:
                logger.info(f"Computing event embeddings for model: {self.current_model_name}")
                self.event_embeddings = {emoji: self.model.encode(desc) for emoji, desc in self.event_dict.items()}
                # Save for future use
                save_embeddings_to_pickle(self.event_embeddings, event_pickle_path)
        else:
            # Compute embeddings without caching
            logger.info("Computing embeddings for emoji dictionaries (no caching)")
            self.emotion_embeddings = {emoji: self.model.encode(desc) for emoji, desc in self.emotion_dict.items()}
            self.event_embeddings = {emoji: self.model.encode(desc) for emoji, desc in self.event_dict.items()}
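
    # Cache layout assumption: get_embeddings_pickle_path (utils.py) yields
    # one pickle path per (model, dictionary) pair, each holding a
    # {emoji: embedding-vector} dict matching what self.model.encode produces.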

    def switch_model(self, model_key):
        """Switch to a different embedding model.

        Args:
            model_key: Key from EMBEDDING_MODELS to use

        Returns:
            True if model was switched successfully, False otherwise
        """
        if model_key not in EMBEDDING_MODELS:
            logger.error(f"Unknown model key: {model_key}")
            return False

        model_name = EMBEDDING_MODELS[model_key]['id']
        if model_name == self.current_model_name:
            logger.info(f"Model {model_key} is already loaded")
            return True

        try:
            logger.info(f"Switching to model: {model_name}")
            self.model = SentenceTransformer(model_name)
            self.current_model_name = model_name
            # Load or recompute embeddings with new model
            if self.emotion_dict and self.event_dict:
                self._load_or_compute_embeddings()
            return True
        except Exception as e:
            logger.error(f"Error switching model: {e}")
            return False
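
    # Usage sketch (hypothetical key): processor.switch_model("minilm")
    # succeeds only if "minilm" is a key defined in EMBEDDING_MODELS (config.py).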

    def find_top_emojis(self, embedding, emoji_embeddings, top_n=1):
        """Find top matching emojis based on cosine similarity.

        Args:
            embedding: Sentence embedding to compare
            emoji_embeddings: Dictionary of emoji embeddings
            top_n: Number of top emojis to return

        Returns:
            List of top matching emojis
        """
        similarities = [
            (emoji, cosine_similarity([embedding], [e_embed])[0][0])
            for emoji, e_embed in emoji_embeddings.items()
        ]
        similarities.sort(key=lambda x: x[1], reverse=True)
        return [emoji for emoji, _ in similarities[:top_n]]
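
    # Example (illustrative sentence): find_top_emojis(model.encode("so happy"),
    # self.emotion_embeddings, top_n=3) returns the three emotion emojis whose
    # description embeddings are most cosine-similar to the input embedding.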

    def get_emoji_mashup_url(self, emoji1, emoji2, size=CONFIG["default_size"]):
        """Generate URL for emoji mashup.

        Args:
            emoji1: First emoji character
            emoji2: Second emoji character
            size: Image size in pixels

        Returns:
            URL for the emoji mashup
        """
        return f"{CONFIG['emoji_kitchen_url'].format(emoji1=emoji1, emoji2=emoji2)}?size={size}"

    def fetch_mashup_image(self, url):
        """Fetch emoji mashup image from URL.

        Args:
            url: URL of the emoji mashup image

        Returns:
            PIL Image object or None if fetch failed
        """
        try:
            # Bound the request so a slow endpoint cannot hang the caller
            response = requests.get(url, timeout=10)
            if response.status_code == 200 and "image" in response.headers.get("Content-Type", ""):
                return Image.open(BytesIO(response.content))
            else:
                logger.warning(
                    f"Failed to fetch image: status {response.status_code}, "
                    f"Content-Type {response.headers.get('Content-Type', 'unknown')}"
                )
                return None
        except Exception as e:
            logger.error(f"Error fetching image: {e}")
            return None

    def sentence_to_emojis(self, sentence):
        """Process sentence to find matching emojis and generate mashup.

        Args:
            sentence: User input text

        Returns:
            Tuple of (emotion emoji, event emoji, mashup image)
        """
        if not sentence.strip():
            return "β", "β", None

        try:
            embedding = self.model.encode(sentence)
            top_emotion = self.find_top_emojis(embedding, self.emotion_embeddings, top_n=1)[0]
            top_event = self.find_top_emojis(embedding, self.event_embeddings, top_n=1)[0]
            mashup_url = self.get_emoji_mashup_url(top_emotion, top_event)
            mashup_image = self.fetch_mashup_image(mashup_url)
            return top_emotion, top_event, mashup_image
        except Exception as e:
            logger.error(f"Error processing sentence: {e}")
            return "β", "β", None