File size: 4,638 Bytes
53dc0ac
 
 
 
 
 
 
8a677b0
7a2c982
 
53dc0ac
 
 
 
 
 
 
7a2c982
53dc0ac
 
 
 
 
7a2c982
53dc0ac
 
 
7a2c982
53dc0ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8531ccc
 
 
 
 
 
53dc0ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8531ccc
53dc0ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8531ccc
 
 
 
 
53dc0ac
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
import sqlite3
from typing import List

from src.common import *


class DataLoader:
    """Load the product dataset from a SQLite file into the in-memory registries.

    All state is class-level: which database is active, the path of its file,
    and whether ``load_data`` has already run. Loading fills ``Category.all``,
    ``Feature.all``, ``Product.all`` and ``Review.all`` and links the objects
    to each other (category <-> features/products, product <-> features,
    product -> reviews).
    """

    # Active database name: the file stem, stored WITHOUT the ".db" suffix.
    active_db = "01_all_products_dataset"
    # Directory holding the *.db files (data_dir presumably comes from
    # `from src.common import *` — confirm).
    db_dir = os.path.join(data_dir, 'sqlite')
    db_file = os.path.join(db_dir, f"{active_db}.db")
    # True once load_data has completed successfully.
    loaded = False

    @classmethod
    def set_db_name(cls, name: str):
        """Switch to the database ``name`` (file stem, no ".db") and reload all data.

        Does nothing if ``name`` is already the active database.
        """
        if name == cls.active_db:
            return
        # Build the path once so the logged path is exactly the one that gets loaded.
        new_file = os.path.join(cls.db_dir, f"{name}.db")
        print(f"Switching database file from {cls.db_file} to {new_file}")
        cls.db_file = new_file
        cls.load_data(reload=True)
        # Mark the name active only after a successful load, so a failed
        # switch can be retried with the same name.
        cls.active_db = name

    @staticmethod
    def current_db() -> str:
        """Return the active database name without a ".db" suffix.

        ``active_db`` is normally stored without the suffix, so it is only
        stripped when actually present (a blind ``[:-3]`` slice would cut off
        the last three characters of the real name).
        """
        name = DataLoader.active_db
        if name.endswith(".db"):
            name = name[:-3]
        return name

    @staticmethod
    def available_dbs() -> List[str]:
        """Return the stems of the ``*.db`` files in ``db_dir`` whose name contains 'products'."""
        return [f[:-3] for f in os.listdir(DataLoader.db_dir) if ('products' in f) and f.endswith('.db')]

    @staticmethod
    def load_data(reload=False):
        """Populate all model registries from ``DataLoader.db_file``.

        :param reload: when False and data is already loaded, return immediately;
            when True, wipe and reload from the current ``db_file``.
        """
        if DataLoader.loaded and not reload:
            return

        # Wipe out any prior data. The registries are now empty, so they are no
        # longer "loaded" until the reload below completes.
        Review.all = {}
        Feature.all = {}
        Product.all = {}
        Category.all = {}
        DataLoader.loaded = False

        print(f"Loading {DataLoader.db_file}")
        # NOTE(review): sqlite3.connect creates an empty file if the path does
        # not exist; the first SELECT then fails with "no such table".
        con = sqlite3.connect(DataLoader.db_file)
        try:
            cur = con.cursor()

            # SELECT * relies on column order. Rows are read as (id, name).
            for c in cur.execute('SELECT * FROM categories').fetchall():
                Category.all[c[0]] = Category(c[0], c[1])

            # Rows read as (id, name, category_id).
            for f in cur.execute('SELECT * FROM features').fetchall():
                feat = Feature(f[0], f[1], Category.all[f[2]])
                Feature.all[f[0]] = feat
                Category.all[f[2]].features.append(feat)

            # Rows read as (id, name, description, price, category_id).
            for p in cur.execute('SELECT * FROM products').fetchall():
                prod = Product(p[0], p[1], p[2], p[3], Category.all[p[4]])
                Product.all[p[0]] = prod
                Category.all[p[4]].products.append(prod)

            # Rows read as (<unused>, product_id, feature_id).
            for pf in cur.execute('SELECT * FROM product_features').fetchall():
                Product.all[pf[1]].features.append(Feature.all[pf[2]])
                Feature.all[pf[2]].products.append(Product.all[pf[1]])

            # Rows read as (id, product_id, rating, review_text).
            for r in cur.execute('SELECT * FROM reviews').fetchall():
                rev = Review(r[0], r[2], r[3], Product.all[r[1]])
                Review.all[r[0]] = rev
                Product.all[r[1]].reviews.append(rev)
        finally:
            # The connection's context manager only commits/rolls back; it
            # does not close, so close explicitly.
            con.close()

        print("Data loaded")
        DataLoader.loaded = True


class Category:
    """A product category together with its features and products.

    ``Category.all`` maps category id -> Category; DataLoader.load_data fills it.
    """

    all = {}

    @staticmethod
    def all_sorted():
        """Return all registered categories as a new list ordered by name."""
        return sorted(Category.all.values(), key=lambda c: c.name)

    @staticmethod
    def by_name(name: str):
        """Return the first registered category whose name equals ``name``, or None."""
        return next((c for c in Category.all.values() if c.name == name), None)

    def __init__(self, id, name):
        self.id = id
        self.name = name
        self.features = []
        self.products = []

    @property
    def feature_count(self):
        """Number of features attached to this category."""
        return len(self.features)

    @property
    def product_count(self):
        """Number of products in this category."""
        return len(self.products)

    @property
    def singular_name(self):
        """Return the name with one trailing "s" removed, if present.

        This is naive singularisation: "Phones" -> "Phone", but "Glasses" -> "Glasse".
        An empty name is returned unchanged instead of raising IndexError.
        """
        if self.name.endswith("s"):
            return self.name[:-1]  # Clip the s
        return self.name


class Feature:
    """A product feature that belongs to a single category.

    ``Feature.all`` maps feature id -> Feature; DataLoader.load_data fills it.
    ``products`` collects every product that has this feature.
    """

    all = {}

    def __init__(self, id, name, category):
        self.id, self.name, self.category = id, name, category
        self.products = []

    def __repr__(self):
        # Show just the feature name, e.g. when a product's feature list is printed.
        return self.name

    @property
    def product_count(self):
        """How many products carry this feature."""
        return len(self.products)


class Product:
    """A product with its price, category, features and reviews.

    ``Product.all`` maps product id -> Product; DataLoader.load_data fills it.
    """

    all = {}

    def __init__(self, id, name, description, price, category):
        self.id = id
        self.name = name
        self.description = description
        # Price is kept rounded to two decimal places.
        self.price = round(price, 2)
        self.category = category
        self.features = []
        self.reviews = []

    @property
    def feature_count(self):
        """Number of features this product has."""
        return len(self.features)

    @property
    def review_count(self):
        """Number of reviews attached to this product."""
        return len(self.reviews)

    @property
    def average_rating(self, decimals=2):
        """Mean review rating rounded to two decimals; 0.0 when there are no reviews.

        A property getter is only ever called with ``self``, so ``decimals``
        cannot be supplied through attribute access. It is kept for signature
        compatibility. Use :meth:`rounded_average_rating` to pick the precision.
        """
        return self.rounded_average_rating(decimals)

    def rounded_average_rating(self, decimals: int = 2) -> float:
        """Return the mean ``rating`` of this product's reviews, rounded to ``decimals``.

        Returns 0.0 for a product without reviews.
        """
        if not self.reviews:
            return 0.0
        total = sum(r.rating for r in self.reviews)
        return float(round(total / len(self.reviews), decimals))

    @staticmethod
    def for_ids(ids: List[str]):
        """Return the registered products for ``ids`` in order; an unknown id raises KeyError."""
        return [Product.all[i] for i in ids]

    @staticmethod
    def all_as_list():
        """Return all registered products as a new list (registry insertion order)."""
        return list(Product.all.values())


class Review:
    """One review of a product: a rating, free text, and the reviewed Product.

    ``Review.all`` maps review id -> Review; DataLoader.load_data fills it.
    """

    all = {}

    def __init__(self, id, rating, review_text, product):
        self.id, self.rating = id, rating
        self.review_text, self.product = review_text, product