File size: 9,177 Bytes
cf957e4
 
 
 
 
 
 
 
 
cfb0d15
cf957e4
cfb0d15
 
 
 
cf957e4
 
cfb0d15
cf957e4
 
 
cfb0d15
 
 
cf957e4
cfb0d15
 
 
 
 
 
cf957e4
 
cfb0d15
cf957e4
 
 
 
cfb0d15
cf957e4
 
 
 
 
 
 
 
 
 
 
 
cfb0d15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cf957e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
"""
Core emoji processing logic for the Emoji Mashup application.
"""

from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
import requests
from PIL import Image
from io import BytesIO
import os

from config import CONFIG, EMBEDDING_MODELS
from utils import (logger, kitchen_txt_to_dict, 
                  save_embeddings_to_pickle, load_embeddings_from_pickle,
                  get_embeddings_pickle_path)

class EmojiProcessor:
    """Match free-text sentences to emotion/item emojis and build Emoji Kitchen mashups.

    Workflow: load emoji->description dictionaries, embed the descriptions with a
    sentence-transformer model (optionally cached to pickle files per model), then
    for a user sentence pick the nearest emotion and item emoji by cosine similarity
    and fetch the corresponding mashup image from the Emoji Kitchen endpoint.
    """

    def __init__(self, model_name=None, model_key=None, use_cached_embeddings=True):
        """Initialize the emoji processor with the specified model.

        Args:
            model_name: Direct name of the sentence transformer model to use
            model_key: Key from EMBEDDING_MODELS to use (takes precedence over model_name)
            use_cached_embeddings: Whether to use cached embeddings from pickle files
        """
        # Resolution order: explicit model_key > explicit model_name > configured default.
        if model_key and model_key in EMBEDDING_MODELS:
            model_name = EMBEDDING_MODELS[model_key]['id']
        elif not model_name:
            model_name = CONFIG["model_name"]

        logger.info(f"Loading model: {model_name}")
        self.model = SentenceTransformer(model_name)
        self.current_model_name = model_name
        self.emotion_dict = {}        # emoji -> description text (emotion set)
        self.event_dict = {}          # emoji -> description text (item/event set)
        self.emotion_embeddings = {}  # emoji -> embedding vector (emotion set)
        self.event_embeddings = {}    # emoji -> embedding vector (item/event set)
        self.use_cached_embeddings = use_cached_embeddings

    def load_emoji_dictionaries(self, emotion_file=None, item_file=None):
        """Load emoji dictionaries from text files, then load/compute their embeddings.

        Args:
            emotion_file: Path to the emotion emoji file (defaults to CONFIG["emotion_file"])
            item_file: Path to the item emoji file (defaults to CONFIG["item_file"])
        """
        # Resolve defaults at call time rather than at def time, so changes to
        # CONFIG after import are honored (def-time defaults are frozen once).
        if emotion_file is None:
            emotion_file = CONFIG["emotion_file"]
        if item_file is None:
            item_file = CONFIG["item_file"]

        logger.info("Loading emoji dictionaries")
        self.emotion_dict = kitchen_txt_to_dict(emotion_file)
        self.event_dict = kitchen_txt_to_dict(item_file)

        # Load or compute embeddings
        self._load_or_compute_embeddings()

    def _load_or_compute_embeddings(self):
        """Load embeddings from pickle files if available, otherwise compute them."""
        self.emotion_embeddings = self._resolve_embeddings("emotion", self.emotion_dict)
        self.event_embeddings = self._resolve_embeddings("event", self.event_dict)

    def _resolve_embeddings(self, kind, emoji_dict):
        """Return embeddings for one dictionary, preferring a valid pickle cache.

        Args:
            kind: Cache label, "emotion" or "event" (used in the pickle path)
            emoji_dict: Mapping of emoji -> description to embed

        Returns:
            Dict mapping each emoji in emoji_dict to its embedding vector.
        """
        pickle_path = None
        if self.use_cached_embeddings:
            pickle_path = get_embeddings_pickle_path(self.current_model_name, kind)
            cached = load_embeddings_from_pickle(pickle_path)
            if cached is not None:
                # A cache is only usable if it covers every emoji we currently have.
                missing = next((e for e in emoji_dict if e not in cached), None)
                if missing is None:
                    return cached
                logger.info(f"Cached {kind} embeddings missing emoji: {missing}, will recompute")

        logger.info(f"Computing {kind} embeddings for model: {self.current_model_name}")
        embeddings = {emoji: self.model.encode(desc) for emoji, desc in emoji_dict.items()}
        if pickle_path is not None:
            # Save for future use
            save_embeddings_to_pickle(embeddings, pickle_path)
        return embeddings

    def switch_model(self, model_key):
        """Switch to a different embedding model.

        Args:
            model_key: Key from EMBEDDING_MODELS to use

        Returns:
            True if model was switched successfully, False otherwise
        """
        if model_key not in EMBEDDING_MODELS:
            logger.error(f"Unknown model key: {model_key}")
            return False

        model_name = EMBEDDING_MODELS[model_key]['id']
        if model_name == self.current_model_name:
            logger.info(f"Model {model_key} is already loaded")
            return True

        try:
            logger.info(f"Switching to model: {model_name}")
            self.model = SentenceTransformer(model_name)
            self.current_model_name = model_name

            # Embeddings are model-specific, so refresh them if dictionaries are loaded.
            if self.emotion_dict and self.event_dict:
                self._load_or_compute_embeddings()

            return True
        except Exception:
            # logger.exception records the traceback, unlike logger.error.
            logger.exception("Error switching model")
            return False

    def find_top_emojis(self, embedding, emoji_embeddings, top_n=1):
        """Find top matching emojis based on cosine similarity.

        Args:
            embedding: Sentence embedding to compare
            emoji_embeddings: Dictionary of emoji embeddings
            top_n: Number of top emojis to return

        Returns:
            List of top matching emojis (empty if emoji_embeddings is empty)
        """
        if not emoji_embeddings:
            return []
        emojis = list(emoji_embeddings.keys())
        # One vectorized call against all candidates instead of one call per emoji.
        scores = cosine_similarity([embedding], list(emoji_embeddings.values()))[0]
        ranked = sorted(zip(emojis, scores), key=lambda pair: pair[1], reverse=True)
        return [emoji for emoji, _ in ranked[:top_n]]

    def get_emoji_mashup_url(self, emoji1, emoji2, size=None):
        """Generate URL for emoji mashup.

        Args:
            emoji1: First emoji character
            emoji2: Second emoji character
            size: Image size in pixels (defaults to CONFIG["default_size"])

        Returns:
            URL for the emoji mashup
        """
        # Resolve the default at call time so CONFIG changes take effect.
        if size is None:
            size = CONFIG["default_size"]
        return f"{CONFIG['emoji_kitchen_url'].format(emoji1=emoji1, emoji2=emoji2)}?size={size}"

    def fetch_mashup_image(self, url):
        """Fetch emoji mashup image from URL.

        Args:
            url: URL of the emoji mashup image

        Returns:
            PIL Image object or None if fetch failed
        """
        try:
            # Always set a timeout: requests.get with no timeout can block forever.
            timeout = CONFIG.get("request_timeout", 10)
            response = requests.get(url, timeout=timeout)
            if response.status_code == 200 and "image" in response.headers.get("Content-Type", ""):
                return Image.open(BytesIO(response.content))
            logger.warning(f"Failed to fetch image: Status code {response.status_code}")
            return None
        except Exception:
            logger.exception("Error fetching image")
            return None

    def sentence_to_emojis(self, sentence):
        """Process sentence to find matching emojis and generate mashup.

        Args:
            sentence: User input text

        Returns:
            Tuple of (emotion emoji, event emoji, mashup image); placeholders
            ("❓"/"❌") with a None image on empty input or processing errors.
        """
        if not sentence.strip():
            return "❓", "❓", None

        try:
            embedding = self.model.encode(sentence)

            top_emotion = self.find_top_emojis(embedding, self.emotion_embeddings, top_n=1)[0]
            top_event = self.find_top_emojis(embedding, self.event_embeddings, top_n=1)[0]

            mashup_url = self.get_emoji_mashup_url(top_emotion, top_event)
            mashup_image = self.fetch_mashup_image(mashup_url)

            return top_emotion, top_event, mashup_image
        except Exception:
            logger.exception("Error processing sentence")
            return "❌", "❌", None