Added the script to shape the data for testing, and the associated SQLite database containing the test data
data/sqlite/02_baseline_products_dataset.db
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:c614f87d479e5ffd3dab7ec185811dc59c52a27a41eed7c5788f23674c6d77fd
 size 17260544
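The database is stored via Git LFS, so the change above only rewrites the pointer file (object hash and size), not the database bytes themselves. As a minimal sketch (a hypothetical local check, not part of the commit, assuming the file at this path still holds the pointer text rather than the fetched binary), the pointer can be parsed like this:

# Hypothetical check: parse a Git LFS pointer file and print the oid and size
# it records (the values shown in the diff above).
pointer_path = "data/sqlite/02_baseline_products_dataset.db"

fields = {}
with open(pointer_path, "r", encoding="utf-8") as f:
    for line in f:
        key, _, value = line.strip().partition(" ")
        fields[key] = value

print(fields.get("oid"))   # e.g. sha256:c614f87d479e5ffd3dab7ec185811dc59c52a27a41eed7c5788f23674c6d77fd
print(fields.get("size"))  # e.g. 17260544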
data/vector_stores/products_tvs_chroma/08d5b637-758f-478e-9873-811a4c46eaff/link_lists.bin
DELETED
File without changes
src/data_synthesis/select_test_data_from_all_products.py
ADDED
@@ -0,0 +1,173 @@
"""
This file is a script intended to be run offline, not as part of the application.

It takes the initially generated "all_products" dataset and refines it down to a
more curated test set. The goals of this refinement are:

- size and efficiency: reduce the amount of data to speed up training and iteration
  on the core part of the project, which is architecture assessment
- more realistic data: produce a more realistic distribution. For example, the
  statistical generation process clustered a lot of products around a 3.5 rating
- hold some data back for specific tests, e.g. testing the subsequent addition of data
"""
import os
from random import randint, random, sample, shuffle
import shutil
import sqlite3

from typing import List

from src.common import data_dir

def src_db() -> str:
    return os.path.join(data_dir, 'sqlite', '01_all_products_dataset.db')


def dst_db() -> str:
    return os.path.join(data_dir, 'sqlite', '02_baseline_products_dataset.db')


def dst_conn() -> sqlite3.Connection:
    return sqlite3.connect(dst_db())


def copy_all_products_db() -> None:
    shutil.copy(src_db(), dst_db())


def join_strings_for_sql(items: List[str]) -> str:
    """
    Joins a list into a comma-separated string with each item enclosed in double quotes,
    for use in an SQL statement.
    """
    return ', '.join([f'"{i}"' for i in items])

def execute_sqls(sql_statements: List[str]):
    db_connection = dst_conn()
    db_cursor = db_connection.cursor()
    for s in sql_statements:
        db_cursor.execute(s)
    db_connection.commit()


def execute_select(sql_statement: str) -> List[List]:
    db_connection = dst_conn()
    db_cursor = db_connection.cursor()
    db_cursor.execute(sql_statement)
    return db_cursor.fetchall()

def drop_categories(keeping: List[str]) -> None:
    # Drop from the tables sequentially to avoid foreign key constraint issues
    keep_in = f"({join_strings_for_sql(keeping)})"
    sql_statements = []
    sql = f"delete from reviews where product_id in (select id from products where category_id in (select id from categories where name not in {keep_in}))"
    sql_statements.append(sql)
    sql = f"delete from product_features where product_id in (select id from products where category_id in (select id from categories where name not in {keep_in}))"
    sql_statements.append(sql)
    sql = f"delete from products where id in (select id from products where category_id in (select id from categories where name not in {keep_in}))"
    sql_statements.append(sql)
    sql = f"delete from features where category_id in (select id from categories where name not in {keep_in})"
    sql_statements.append(sql)
    sql = f"delete from categories where name not in {keep_in}"
    sql_statements.append(sql)
    execute_sqls(sql_statements)
    print(f'Narrowed categories down to {len(keeping)}')

def drop_products(ids: List[int]) -> None:
    ids_in = f"({join_strings_for_sql(ids)})"

    sql_statements = []
    sql = f"delete from reviews where product_id in {ids_in}"
    sql_statements.append(sql)
    sql = f"delete from product_features where product_id in {ids_in}"
    sql_statements.append(sql)
    sql = f"delete from products where id in {ids_in}"
    sql_statements.append(sql)
    execute_sqls(sql_statements)

def winnow_prods_per_category(min_count: int = 10, max_count: int = 25):
    sql = "select name from categories"
    categories = [r[0] for r in execute_select(sql)]
    for c in categories:
        target_prod_count = randint(min_count, max_count)
        sql = f'select id from products where category_id in (select id from categories where name = "{c}")'
        current_ids = [r[0] for r in execute_select(sql)]
        if len(current_ids) > target_prod_count:
            num_to_winnow = len(current_ids) - target_prod_count
            # Sample without replacement so exactly num_to_winnow distinct products are dropped
            ids_to_drop = sample(current_ids, k=num_to_winnow)
            drop_products(ids_to_drop)
            print(f'Winnowed {c} to {target_prod_count} products')
        else:
            print(f'{c} already at {len(current_ids)} - nothing to winnow')

def avg_rating(review_ratings1: List[List[int]], review_ratings2: List[List[int]]) -> float:
    review_count = len(review_ratings1) + len(review_ratings2)
    review_sum = sum([r[1] for r in review_ratings1]) + sum([r[1] for r in review_ratings2])
    return review_sum / review_count

def get_review_ids_to_drop(review_ratings: List[List[int]], target_review_count: int, target_avg_rating: float) -> List[int]:
    ids_to_drop = []
    rated_lower_than_target = [r for r in review_ratings if r[1] < target_avg_rating]
    rated_higher_than_target = [r for r in review_ratings if r[1] >= target_avg_rating]
    while len(rated_higher_than_target) + len(rated_lower_than_target) > target_review_count:
        if avg_rating(rated_higher_than_target, rated_lower_than_target) >= target_avg_rating:
            # Average is at or above target, so discard a higher-rated review to pull it down
            if len(rated_higher_than_target) == 0:
                break  # Avoid getting stuck in a loop when there's no remaining way to reach the target
            shuffle(rated_higher_than_target)
            ids_to_drop.append(rated_higher_than_target.pop()[0])
        else:
            # Average is below target, so discard a lower-rated review to pull it up
            if len(rated_lower_than_target) == 0:
                break  # Avoid getting stuck in a loop when there's no remaining way to reach the target
            shuffle(rated_lower_than_target)
            ids_to_drop.append(rated_lower_than_target.pop()[0])
    return ids_to_drop

def drop_reviews(review_ids: list[int]):
    if len(review_ids) == 0:
        print("Warning - got zero reviews to drop")
    else:
        ids_in = f'({", ".join([str(r) for r in review_ids])})'
        sql = f'delete from reviews where id in {ids_in}'
        execute_sqls([sql])

def drop_reviews_to_balance_avg_rating(min_review_count: int = 5,
                                       max_review_count: int = 30,
                                       min_target_avg_rating: float = 3.4,
                                       max_target_avg_rating: float = 4.9):
    sql = "select id, name from products"
    product_id_names = [(r[0], r[1]) for r in execute_select(sql)]
    for prod_id, prod_name in product_id_names:
        target_review_count = randint(min_review_count, max_review_count)
        target_avg_rating = min_target_avg_rating + (random() * (max_target_avg_rating - min_target_avg_rating))
        sql = f'select id, rating from reviews where product_id = {prod_id}'
        review_ratings = execute_select(sql)
        review_ids_to_drop = get_review_ids_to_drop(review_ratings, target_review_count, target_avg_rating)
        print(f'Dropping {len(review_ids_to_drop)} reviews for {prod_name} trying to get to average review of ~{target_avg_rating:.1f}')
        drop_reviews(review_ids_to_drop)

if __name__ == "__main__":
    copy_all_products_db()

    # Keep half the product categories - more recognisable ones
    cats = [
        "Dishwashers",
        "TVs",
        "Washing Machines",
        "Vacuum Cleaners",
        "Irons",
        "Electric Kettles",
        "Microwaves",
        "Food Processors",
        "Coffee Machines",
        "Toasters"
    ]
    drop_categories(keeping=cats)
    winnow_prods_per_category()
    drop_reviews_to_balance_avg_rating()
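For reference, a minimal sketch (not part of the commit) of exercising these helpers ad hoc with non-default settings; it assumes the repository root is on PYTHONPATH and that src.common.data_dir resolves as in the script above, and the parameter values are purely illustrative:

# Hypothetical ad-hoc usage of the script's helpers; the values below are
# illustrative, not the defaults actually used in __main__.
from src.data_synthesis.select_test_data_from_all_products import (
    copy_all_products_db,
    drop_categories,
    winnow_prods_per_category,
    drop_reviews_to_balance_avg_rating,
    get_review_ids_to_drop,
)

# Rebuild the baseline DB, keeping fewer categories and tighter per-category limits.
copy_all_products_db()
drop_categories(keeping=["TVs", "Coffee Machines"])
winnow_prods_per_category(min_count=5, max_count=10)
drop_reviews_to_balance_avg_rating(max_review_count=15)

# The review-trimming heuristic in isolation: given (id, rating) pairs, aim for
# roughly 3 reviews averaging about 4.0; it greedily discards reviews from
# whichever side of the target keeps the running average moving toward it.
ratings = [(1, 5), (2, 5), (3, 4), (4, 2), (5, 1)]
print(get_review_ids_to_drop(ratings, target_review_count=3, target_avg_rating=4.0))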