Spaces:
Runtime error
Runtime error
| import sqlite3 | |
| from typing import List | |
| from src.common import * | |
class DataLoader:
    """Load the product catalogue from a SQLite file into in-memory registries.

    All state lives on the class itself: ``active_db`` is the selected database
    name (file stem, without ".db"), ``db_file`` is its path inside ``db_dir``,
    and ``loaded`` records whether ``load_data`` has completed. Loading fills
    ``Category.all``, ``Feature.all``, ``Product.all`` and ``Review.all``.
    """

    # Name (file stem, no ".db" suffix) of the currently selected database.
    active_db = "all_products"
    # data_dir is not defined in this file; presumably it comes from
    # `from src.common import *` (as does `os`) — verify.
    db_dir = os.path.join(data_dir, 'sqlite')
    db_file = os.path.join(db_dir, f"{active_db}.db")
    # True once load_data() has finished successfully.
    loaded = False

    @classmethod
    def set_db_name(cls, name: str):
        """Switch to the database called *name* (file stem) and reload all data.

        Does nothing when *name* is already the active database. If the reload
        raises, ``db_file`` is pointed back at the previous file and the
        exception propagates.
        """
        if name == cls.active_db:
            return
        new_file = os.path.join(cls.db_dir, f"{name}.db")
        print(f"Switching database file from {cls.db_file} to {new_file}")
        old_file = cls.db_file
        cls.db_file = new_file
        try:
            cls.load_data(reload=True)
        except Exception:
            # load_data clears `loaded` before rebuilding, so restoring the old
            # path lets a later load_data() call reload the previous database.
            cls.db_file = old_file
            raise
        cls.active_db = name

    @staticmethod
    def current_db() -> str:
        """Return the name (file stem, without ".db") of the active database."""
        # active_db is stored without the ".db" suffix (see the default value
        # and set_db_name), so it is returned as-is rather than trimmed again.
        return DataLoader.active_db

    @staticmethod
    def available_dbs() -> List[str]:
        """Return the stems of the "*.db" files in db_dir whose name contains 'products'."""
        return [f[:-3] for f in os.listdir(DataLoader.db_dir)
                if 'products' in f and f.endswith('.db')]

    @staticmethod
    def load_data(reload=False):
        """Populate the in-memory registries from ``DataLoader.db_file``.

        Args:
            reload: when False and data has already been loaded, return
                immediately; when True, always wipe and reload.
        """
        if DataLoader.loaded and not reload:
            return
        # Clear the flag while rebuilding so a failed load is retried by the
        # next call instead of being silently skipped.
        DataLoader.loaded = False
        # Wipe out any prior data
        Review.all = {}
        Feature.all = {}
        Product.all = {}
        Category.all = {}
        print(f"Loading {DataLoader.db_file}")
        con = sqlite3.connect(DataLoader.db_file)
        try:
            cur = con.cursor()

            # categories columns used: 0=id, 1=name
            for c in cur.execute('SELECT * FROM categories').fetchall():
                Category.all[c[0]] = Category(c[0], c[1])

            # features columns used: 0=id, 1=name, 2=category id
            for f in cur.execute('SELECT * FROM features').fetchall():
                feat = Feature(f[0], f[1], Category.all[f[2]])
                Feature.all[f[0]] = feat
                Category.all[f[2]].features.append(feat)

            # products columns used: 0=id, 1=name, 2=description, 3=price,
            # 4=category id
            for p in cur.execute('SELECT * FROM products').fetchall():
                prod = Product(p[0], p[1], p[2], p[3], Category.all[p[4]])
                Product.all[p[0]] = prod
                Category.all[p[4]].products.append(prod)

            # product_features columns used: 1=product id, 2=feature id
            for pf in cur.execute('SELECT * FROM product_features').fetchall():
                Product.all[pf[1]].features.append(Feature.all[pf[2]])
                Feature.all[pf[2]].products.append(Product.all[pf[1]])

            # reviews columns used: 0=id, 1=product id, 2=rating, 3=review text
            for r in cur.execute('SELECT * FROM reviews').fetchall():
                rev = Review(r[0], r[2], r[3], Product.all[r[1]])
                Review.all[r[0]] = rev
                Product.all[r[1]].reviews.append(rev)
        finally:
            # sqlite3's connection context manager only commits/rolls back and
            # never closes, so close explicitly to avoid leaking the handle.
            con.close()
        print("Data loaded")
        DataLoader.loaded = True
class Category:
    """A product category that collects its features and products."""

    # id -> Category registry, filled by the data loader.
    all = {}

    @staticmethod
    def all_sorted():
        """Return every registered category, sorted by name."""
        return sorted(Category.all.values(), key=lambda cat: cat.name)

    @staticmethod
    def by_name(name: str):
        """Return the first registered category named *name*, or None if absent."""
        return next((cat for cat in Category.all.values() if cat.name == name), None)

    def __init__(self, id, name):
        self.id = id
        self.name = name
        self.features = []
        self.products = []

    def feature_count(self):
        """Return the number of features attached to this category."""
        return len(self.features)

    def product_count(self):
        """Return the number of products attached to this category."""
        return len(self.products)
class Feature:
    """A feature belonging to one category, linked to the products that have it."""

    # id -> Feature registry, filled by the data loader.
    all = {}

    def __init__(self, id, name, category):
        self.id, self.name, self.category = id, name, category
        self.products = []

    def product_count(self):
        """Return how many products are linked to this feature."""
        linked = self.products
        return len(linked)

    def __repr__(self):
        # The feature's name doubles as its representation.
        return self.name
class Product:
    """A catalogue product with its category, features and reviews."""

    # id -> Product registry, filled by the data loader.
    all = {}

    def __init__(self, id, name, description, price, category):
        self.id = id
        self.name = name
        self.description = description
        # Price is kept rounded to two decimal places.
        self.price = round(price, 2)
        self.category = category
        self.features = []
        self.reviews = []

    def feature_count(self):
        """Return the number of features linked to this product."""
        return len(self.features)

    def review_count(self):
        """Return the number of reviews for this product."""
        return len(self.reviews)

    def average_rating(self, decimals=2):
        """Return the mean ``rating`` of this product's reviews.

        Args:
            decimals: number of decimal places to round the mean to.

        Returns:
            The rounded mean as a float, or 0.0 when there are no reviews.
        """
        # review_count must be *called*; comparing/dividing by the bound method
        # itself never equals 0 and raises TypeError on division.
        count = self.review_count()
        if count == 0:
            return 0.0
        return float(round(sum(r.rating for r in self.reviews) / count, decimals))

    @staticmethod
    def for_ids(ids: List[str]):
        """Return the registered products for *ids*, preserving their order.

        Raises:
            KeyError: if an id is not present in ``Product.all``.
        """
        return [Product.all[i] for i in ids]
class Review:
    """One rating plus review text attached to a product."""

    # id -> Review registry, filled by the data loader.
    all = {}

    def __init__(self, id, rating, review_text, product):
        self.id, self.rating = id, rating
        self.review_text, self.product = review_text, product
def _main():
    """Script entry point: load the active database, then print a marker line."""
    DataLoader.load_data()
    print('test')


if __name__ == "__main__":
    _main()