llm-arch / src /datatypes.py
alfraser's picture
Fixed bugs in the dataload process with referencing the new json folder and then looking up the available databases.
7a2c982
raw
history blame
4.47 kB
import sqlite3
from typing import List
from src.common import *
class DataLoader:
    """Load the product catalogue from a SQLite file into the model classes.

    All state lives on the class itself:

    * ``active_db`` -- name of the current database (file stem, no ``.db``).
    * ``db_dir``    -- directory holding the ``.db`` files (``<data_dir>/sqlite``).
    * ``db_file``   -- full path of the current database file.
    * ``loaded``    -- whether ``load_data`` has completed successfully.
    """
    active_db = "all_products"
    db_dir = os.path.join(data_dir, 'sqlite')
    db_file = os.path.join(db_dir, f"{active_db}.db")
    loaded = False

    @classmethod
    def set_db_name(cls, name: str):
        """Switch to ``<db_dir>/<name>.db`` and reload all data from it.

        Does nothing when ``name`` is already the active database.

        :param name: database name without the ``.db`` extension, as returned
            by ``available_dbs()``.
        """
        if name != cls.active_db:
            new_file = os.path.join(cls.db_dir, f"{name}.db")
            print(f"Switching database file from {cls.db_file} to {new_file}")
            cls.db_file = new_file
            DataLoader.load_data(reload=True)
            cls.active_db = name

    @staticmethod
    def current_db() -> str:
        """Return the name of the active database (without ``.db``)."""
        # active_db is stored as the bare name (see the class default and
        # set_db_name), so it must not be sliced.
        return DataLoader.active_db

    @staticmethod
    def available_dbs() -> List[str]:
        """List database names (``.db`` stripped) found in ``db_dir``.

        Only files whose name contains ``'products'`` and ends in ``.db`` are
        included.
        """
        return [f[:-3] for f in os.listdir(DataLoader.db_dir) if ('products' in f) and f.endswith('.db')]

    @staticmethod
    def load_data(reload=False):
        """Populate ``Category/Feature/Product/Review.all`` from ``db_file``.

        Also wires up the cross references between the objects: category
        features/products, product features/reviews and feature products.

        :param reload: when False (the default), return immediately if data has
            already been loaded; when True, always wipe and reload.
        :raises sqlite3.Error: if the database cannot be read.
        :raises KeyError: if a row references an id that was not loaded.
        """
        if DataLoader.loaded and not reload:
            return
        # Wipe out any prior data
        Review.all = {}
        Feature.all = {}
        Product.all = {}
        Category.all = {}
        # The registries are now empty, so mark them unloaded until this load
        # completes. A failed reload must not leave stale/partial data that a
        # later load_data() call would skip over.
        DataLoader.loaded = False
        print(f"Loading {DataLoader.db_file}")
        con = sqlite3.connect(DataLoader.db_file)
        try:
            cur = con.cursor()
            # Columns are read positionally, in the order of each
            # constructor's parameters.
            categories = cur.execute('SELECT * FROM categories').fetchall()
            for c in categories:
                Category.all[c[0]] = Category(c[0], c[1])
            features = cur.execute('SELECT * FROM features').fetchall()
            for f in features:
                feat = Feature(f[0], f[1], Category.all[f[2]])
                Feature.all[f[0]] = feat
                Category.all[f[2]].features.append(feat)
            products = cur.execute('SELECT * FROM products').fetchall()
            for p in products:
                prod = Product(p[0], p[1], p[2], p[3], Category.all[p[4]])
                Product.all[p[0]] = prod
                Category.all[p[4]].products.append(prod)
            # product_features: column 0 is unused here (presumably a row id);
            # columns 1 and 2 are the product id and feature id.
            prod_feats = cur.execute('SELECT * FROM product_features').fetchall()
            for pf in prod_feats:
                Product.all[pf[1]].features.append(Feature.all[pf[2]])
                Feature.all[pf[2]].products.append(Product.all[pf[1]])
            # reviews: (id, product id, rating, review text).
            reviews = cur.execute('SELECT * FROM reviews').fetchall()
            for r in reviews:
                rev = Review(r[0], r[2], r[3], Product.all[r[1]])
                Review.all[r[0]] = rev
                Product.all[r[1]].reviews.append(rev)
        finally:
            # Read-only access, so nothing to commit; always release the file.
            con.close()
        print("Data loaded")
        DataLoader.loaded = True
class Category:
    """A product category holding back-references to its features and products.

    ``Category.all`` is the class-wide registry of categories (keyed by id
    when filled by ``DataLoader.load_data``).
    """
    all = {}

    @staticmethod
    def all_sorted():
        """Return a new list of every registered category, ordered by name."""
        return sorted(Category.all.values(), key=lambda cat: cat.name)

    @staticmethod
    def by_name(name: str):
        """Return the first registered category whose name equals ``name``.

        Returns None when no category matches.
        """
        return next((cat for cat in Category.all.values() if cat.name == name), None)

    def __init__(self, id, name):
        self.id, self.name = id, name
        # Filled in by the loader as features/products are attached.
        self.features, self.products = [], []

    @property
    def feature_count(self):
        """Number of features attached to this category."""
        count = len(self.features)
        return count

    @property
    def product_count(self):
        """Number of products attached to this category."""
        count = len(self.products)
        return count
class Feature:
    """A feature belonging to one category, tracking the products that have it.

    ``Feature.all`` is the class-wide registry of features (keyed by id when
    filled by ``DataLoader.load_data``).
    """
    all = {}

    def __init__(self, id, name, category):
        self.id, self.name, self.category = id, name, category
        # Filled in by the loader as products are linked to this feature.
        self.products = []

    @property
    def product_count(self):
        """Number of products linked to this feature."""
        linked = self.products
        return len(linked)

    def __repr__(self):
        # A feature shows up as its bare name, e.g. when a list is printed.
        return self.name
class Product:
    """A product in a category, with its features and reviews.

    ``Product.all`` is the class-wide registry of products (keyed by id when
    filled by ``DataLoader.load_data``).
    """
    all = {}

    def __init__(self, id, name, description, price, category):
        self.id = id
        self.name = name
        self.description = description
        # Price is kept rounded to two decimal places.
        self.price = round(price, 2)
        self.category = category
        # Filled in by the loader as features/reviews are attached.
        self.features = []
        self.reviews = []

    @property
    def feature_count(self):
        """Number of features this product has."""
        return len(self.features)

    @property
    def review_count(self):
        """Number of reviews attached to this product."""
        return len(self.reviews)

    @property
    def average_rating(self):
        """Mean review rating rounded to 2 decimal places (0.0 if no reviews)."""
        # A property getter is only ever called with ``self``, so the former
        # ``decimals`` parameter could never be supplied. Precision is now
        # configurable through rounded_average_rating().
        return self.rounded_average_rating()

    def rounded_average_rating(self, decimals=2):
        """Return the mean of the reviews' ``rating`` values, rounded.

        :param decimals: number of decimal places to round to (default 2).
        :returns: the rounded mean as a float, or 0.0 when there are no reviews.
        """
        if not self.reviews:
            return 0.0
        mean = sum(r.rating for r in self.reviews) / len(self.reviews)
        return float(round(mean, decimals))

    @staticmethod
    def for_ids(ids: List[str]):
        """Return the registered products for ``ids``, in the same order.

        :raises KeyError: if any id is not in ``Product.all``.
        """
        return [Product.all[i] for i in ids]
class Review:
    """A single review: a rating, its free text and the reviewed product.

    ``Review.all`` is the class-wide registry of reviews (keyed by id when
    filled by ``DataLoader.load_data``).
    """
    all = {}

    def __init__(self, id, rating, review_text, product):
        self.id, self.rating = id, rating
        self.review_text, self.product = review_text, product
if __name__ == "__main__":
    # Manual smoke run: load all data from DataLoader.db_file into the model
    # classes, then print a marker line once loading has finished.
    DataLoader.load_data()
    print('test')