File size: 3,615 Bytes
18f531e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
# recommender_system.py
import pandas as pd
from surprise import Dataset, Reader, SVD
from surprise.model_selection import train_test_split
import joblib
from datetime import datetime, timedelta
import os
import random
from sqlalchemy.orm import Session
from typing import List
from orders.models import Order, Meal
from users.models import User


class MealRecommender:
    """Recommend meals to a user from order history using a Surprise SVD model.

    The fitted model is persisted with joblib at ``model_path`` and the time
    of the last successful training is stored as an ISO timestamp in
    ``last_train_path``.  The model is retrained once ``retrain_interval``
    has elapsed.  When there are no orders to train on, recommendations fall
    back to a random sample of meals.
    """

    # Number of best-predicted meals handed to the preference re-ranking.
    CANDIDATE_POOL_SIZE = 20
    # Number of meals returned to the caller.
    MAX_RECOMMENDATIONS = 5

    def __init__(self, db: "Session"):
        """Store the DB session and load (or train) the model.

        Args:
            db: Session used to query ``Order`` and ``Meal`` rows.
        """
        self.db = db
        self.model_path = 'recommendation_model.joblib'
        self.last_train_path = 'last_train_time.txt'
        self.retrain_interval = timedelta(days=1)
        self.algo = self.load_or_train_model()

    def fetch_data(self):
        """Return all orders as a DataFrame of (user_id, meal_id, quantity)."""
        orders = self.db.query(Order).all()
        rows = [(order.user_id, order.meal_id, order.quantity) for order in orders]
        return pd.DataFrame(rows, columns=['user_id', 'meal_id', 'quantity'])

    def train_model(self):
        """Fit an SVD model on the full order history and persist it.

        Returns:
            The fitted algorithm, or ``None`` when there are no orders (in
            that case ``self.algo`` is also reset to ``None`` and nothing is
            written to disk).
        """
        data = self.fetch_data()
        if data.empty:
            self.algo = None
            return None

        # NOTE(review): the ordered quantity is used as the rating while the
        # scale is fixed at 1-5; quantities outside that range are not
        # clamped here -- confirm this is intended.
        reader = Reader(rating_scale=(1, 5))
        dataset = Dataset.load_from_df(data[['user_id', 'meal_id', 'quantity']], reader)

        trainset = dataset.build_full_trainset()
        algo = SVD()
        algo.fit(trainset)

        joblib.dump(algo, self.model_path)
        self._update_last_train_time()
        return algo

    def load_or_train_model(self):
        """Return a usable model, retraining when stale or not yet saved.

        Returns:
            The loaded or freshly trained algorithm, or ``None`` when
            training is required but there is no order data.
        """
        if self._should_retrain():
            return self.train_model()
        try:
            # joblib.load unpickles the file; it is only safe because the
            # model file is written by this class, never by untrusted input.
            return joblib.load(self.model_path)
        except FileNotFoundError:
            # Timestamp exists but the model file is gone: rebuild it.
            return self.train_model()

    def _should_retrain(self):
        """Return True when no valid, recent training timestamp exists."""
        try:
            with open(self.last_train_path, 'r') as f:
                last_train_time = datetime.fromisoformat(f.read().strip())
        except (FileNotFoundError, ValueError):
            # A missing, empty or corrupt timestamp file means the age of the
            # model is unknown, so retrain rather than crash.
            return True
        return datetime.now() - last_train_time > self.retrain_interval

    def _update_last_train_time(self):
        """Record the current time as the last training time."""
        with open(self.last_train_path, 'w') as f:
            f.write(datetime.now().isoformat())

    def get_recommendations(self, user: "User"):
        """Return up to ``MAX_RECOMMENDATIONS`` meals for ``user``.

        The ``CANDIDATE_POOL_SIZE`` meals with the highest predicted rating
        are re-ranked by how many of the user's preferences they match; ties
        keep the predicted-rating order.  Falls back to random meals when no
        model is available.

        Args:
            user: The user to recommend for; ``user.id`` and
                ``user.preferences`` are read.

        Returns:
            A list of ``Meal`` objects, best first.
        """
        if self._should_retrain():
            self.algo = self.train_model()

        if self.algo is None:
            return self.get_random_recommendations()

        meals_by_id = {meal.id: meal for meal in self.db.query(Meal).all()}

        # Raw ids must be passed exactly as they were given at training time
        # (the values stored on Order rows).  Converting them to str makes
        # every user and meal unknown to the model, so all estimates collapse
        # to the same default value.
        predictions = [self.algo.predict(user.id, meal_id) for meal_id in meals_by_id]
        predictions.sort(key=lambda pred: pred.est, reverse=True)

        # Build the candidate list from the ranked predictions so the
        # predicted order survives (an IN (...) query would return the rows
        # in arbitrary order).
        candidates = [
            meals_by_id[pred.iid]
            for pred in predictions[:self.CANDIDATE_POOL_SIZE]
        ]

        ranked = self.adjust_for_preferences(user, candidates)
        return ranked[:self.MAX_RECOMMENDATIONS]

    def adjust_for_preferences(self, user: "User", recommendations: "List[Meal]") -> "List[Meal]":
        """Re-rank meals by the number of user preferences they mention.

        A preference matches a meal when it occurs, case-insensitively, in
        the meal's name or description.  The sort is stable, so meals with
        the same score keep their incoming order.

        Args:
            user: The user whose ``preferences`` (iterable of strings, or a
                falsy value for none) drive the re-ranking.
            recommendations: Candidate meals, already in priority order.

        Returns:
            A new list with the same meals, highest preference score first.
        """
        preferences = user.preferences if user.preferences else []
        wanted = [preferred.lower() for preferred in preferences]
        preference_scores = {}

        for meal in recommendations:
            # A missing name/description simply cannot match anything.
            name = (meal.name or '').lower()
            description = (meal.description or '').lower()
            preference_scores[meal.id] = sum(
                1 for term in wanted if term in name or term in description
            )

        return sorted(
            recommendations,
            key=lambda meal: preference_scores[meal.id],
            reverse=True,
        )

    def get_random_recommendations(self):
        """Return up to ``MAX_RECOMMENDATIONS`` meals chosen at random."""
        all_meals = self.db.query(Meal).all()
        return random.sample(all_meals, min(self.MAX_RECOMMENDATIONS, len(all_meals)))