alfraser committed on
Commit
7e353fe
·
1 Parent(s): 5c71a71

Added the script to shape the data for testing and the associated sqlite containing the test data

Browse files
data/sqlite/02_baseline_products_dataset.db CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:a5ece96e2b662df011c9896e4c276053296e6ea28a8d207bdf19c5219734585d
3
  size 17260544
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c614f87d479e5ffd3dab7ec185811dc59c52a27a41eed7c5788f23674c6d77fd
3
  size 17260544
data/vector_stores/products_tvs_chroma/08d5b637-758f-478e-9873-811a4c46eaff/link_lists.bin DELETED
File without changes
src/data_synthesis/select_test_data_from_all_products.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This file is a script intended to be run offline, not as part of the application
3
+ This script takes the initially generated set of "all_products" and refines it
4
+ down to a more curated test set. The following are the goals in refining this down
5
+
6
+ - size and efficiency: reduce the amount of data to speed up training and iteration on the core part
7
+ of the project which is architecture assessment
8
+ - produce more realistic data: produce a more realistic distribution of data. For example the generation
9
+ process being statistical has clustered a lot of products around a 3.5 rating
10
+ - hold some data back for specific tests - e.g. if I want to test subsequent addition of data
11
+ """
12
+ import os
13
+ from random import randint, random, choices, shuffle
14
+ import shutil
15
+ import sqlite3
16
+
17
+ from typing import List
18
+
19
+ from src.common import data_dir
20
+
21
+
22
def src_db() -> str:
    """
    Path of the source database ('01_all_products_dataset.db') inside the 'sqlite' folder of data_dir
    """
    sqlite_dir = os.path.join(data_dir, 'sqlite')
    return os.path.join(sqlite_dir, '01_all_products_dataset.db')
24
+
25
+
26
def dst_db() -> str:
    """
    Path of the destination database ('02_baseline_products_dataset.db') inside the 'sqlite' folder of data_dir
    """
    sqlite_dir = os.path.join(data_dir, 'sqlite')
    return os.path.join(sqlite_dir, '02_baseline_products_dataset.db')
28
+
29
+
30
def dst_conn() -> sqlite3.Connection:
    """
    Opens and returns a new sqlite connection to the destination database file
    """
    target_path = dst_db()
    return sqlite3.connect(target_path)
32
+
33
+
34
def copy_all_products_db() -> None:
    """
    Copies the source database file over the destination database file
    """
    source_path, target_path = src_db(), dst_db()
    shutil.copy(source_path, target_path)
36
+
37
+
38
def join_strings_for_sql(items: List[str]) -> str:
    """
    Joins a list comma separated, with each item rendered as an SQL string literal,
    for use inside an sql statement (e.g. within an "in (...)" clause).

    Items are enclosed in single quotes because double quotes denote identifiers in SQL;
    sqlite only treats a double-quoted token as a string when no column of that name
    exists, so a value such as "name" would silently be read as a column reference.
    Any single quote inside an item is doubled so that it cannot terminate the literal.

    :param items: the values to render (non-string values are converted with str())
    :return: the joined literals, or an empty string for an empty list
    """
    literals = []
    for item in items:
        escaped = str(item).replace("'", "''")
        literals.append(f"'{escaped}'")
    return ', '.join(literals)
43
+
44
+
45
def execute_sqls(sql_statements: List[str]):
    """
    Executes the given statements in order on a single connection to the destination
    database and commits once after the last statement.

    The connection is always closed afterwards. If a statement raises, nothing is
    committed and the exception propagates to the caller.

    :param sql_statements: sql statements to run, in execution order
    """
    db_connection = dst_conn()
    try:
        db_cursor = db_connection.cursor()
        for s in sql_statements:
            db_cursor.execute(s)
        db_connection.commit()
    finally:
        # Release the file handle even when a statement fails
        db_connection.close()
51
+
52
+
53
def execute_select(sql_statement: str) -> List[List]:
    """
    Runs a single query against the destination database and returns every result row.

    The connection is always closed before returning, including when the query raises.

    :param sql_statement: the select statement to run
    :return: all rows, as produced by the cursor's fetchall()
    """
    db_connection = dst_conn()
    try:
        db_cursor = db_connection.cursor()
        db_cursor.execute(sql_statement)
        return db_cursor.fetchall()
    finally:
        # Release the file handle once the rows have been read
        db_connection.close()
58
+
59
+
60
def drop_categories(keeping: List[str]) -> None:
    """
    Deletes every category whose name is not in `keeping`, along with the reviews,
    product features, products and features that hang off those categories.
    """
    keep_in = f"({join_strings_for_sql(keeping)})"

    # Sub-queries shared by the delete statements below
    unwanted_categories = f"select id from categories where name not in {keep_in}"
    unwanted_products = f"select id from products where category_id in ({unwanted_categories})"

    # Drop from the tables sequentially to avoid foreign key constraint issues
    execute_sqls([
        f"delete from reviews where product_id in ({unwanted_products})",
        f"delete from product_features where product_id in ({unwanted_products})",
        f"delete from products where id in ({unwanted_products})",
        f"delete from features where category_id in ({unwanted_categories})",
        f"delete from categories where name not in {keep_in}",
    ])
    print(f'Narrowed categories down to {len(keeping)}')
76
+
77
+
78
def drop_products(ids: List[int]) -> None:
    """
    Deletes the products with the given ids, together with their reviews and
    product features.

    Ids are written into the sql as bare integers (the same way drop_reviews formats
    its ids) rather than as quoted strings, so matching does not depend on sqlite
    coercing text to numbers.

    :param ids: product ids to delete; an empty list is a no-op
    :raises ValueError: if an id cannot be converted to an integer
    """
    if not ids:
        return  # Nothing to delete - avoid opening a connection for no-op statements

    id_list = ", ".join(str(int(i)) for i in ids)
    ids_in = f"({id_list})"

    # Children first, then the products themselves, to avoid foreign key constraint issues
    sql_statements = [
        f"delete from reviews where product_id in {ids_in}",
        f"delete from product_features where product_id in {ids_in}",
        f"delete from products where id in {ids_in}",
    ]
    execute_sqls(sql_statements)
89
+
90
+
91
def winnow_prods_per_category(min_count: int = 10, max_count: int = 25):
    """
    Randomly reduces the number of products in each category.

    For every category a target product count is drawn uniformly from
    [min_count, max_count]. If the category currently holds more products than the
    target, the surplus is chosen at random and deleted via drop_products.

    :param min_count: smallest target product count for a category (inclusive)
    :param max_count: largest target product count for a category (inclusive)
    """
    # Local import: the file-level import brings in choices but not sample
    from random import sample

    sql = "select name from categories"
    categories = [r[0] for r in execute_select(sql)]
    for c in categories:
        target_prod_count = randint(min_count, max_count)
        sql = f'select id from products where category_id in (select id from categories where name = "{c}")'
        current_ids = [r[0] for r in execute_select(sql)]
        if len(current_ids) > target_prod_count:
            num_to_winnow = len(current_ids) - target_prod_count
            # sample() picks without replacement. choices() can return the same id
            # several times, which would leave more products than the target.
            ids_to_drop = sample(current_ids, k=num_to_winnow)
            drop_products(ids_to_drop)
            print(f'Winnowed {c} to {target_prod_count} products')
        else:
            print(f'{c} already at {len(current_ids)} - nothing to winnow')
105
+
106
+
107
def avg_rating(review_ratings1: List[List[int]], review_ratings2: List[List[int]]) -> float:
    """
    Mean rating across two lists of review rows, where the rating is the second
    element (index 1) of each row. Raises ZeroDivisionError when both lists are empty.
    """
    def rating_total(rows):
        total = 0
        for row in rows:
            total += row[1]
        return total

    combined_total = rating_total(review_ratings1) + rating_total(review_ratings2)
    combined_count = len(review_ratings1) + len(review_ratings2)
    return combined_total / combined_count
111
+
112
+
113
def get_review_ids_to_drop(review_ratings: List[List[int]], target_review_count: int, target_avg_rating: float) -> List[int]:
    """
    Chooses which reviews to delete so that at most target_review_count remain, steering
    the mean rating of the remaining reviews towards target_avg_rating.

    While too many reviews remain, one review is removed at random: a review rated at or
    above the target when the current mean is at or above the target, otherwise a review
    rated below the target. Each removal therefore moves the mean towards the target
    (or leaves it unchanged).

    :param review_ratings: rows of (review id, rating) - id at index 0, rating at index 1
    :param target_review_count: maximum number of reviews to keep
    :param target_avg_rating: mean rating to steer the remaining reviews towards
    :return: ids of the reviews to delete, in the order they were selected
    """
    ids_to_drop = []
    rated_higher_than_target = [r for r in review_ratings if r[1] >= target_avg_rating]
    rated_lower_than_target = [r for r in review_ratings if r[1] < target_avg_rating]

    # Shuffle once up front: popping from the end of a shuffled list is a uniform
    # random pick without replacement, so there is no need to re-shuffle per removal
    shuffle(rated_higher_than_target)
    shuffle(rated_lower_than_target)

    # Running totals avoid re-summing every remaining rating on each iteration
    remaining_count = len(review_ratings)
    remaining_sum = sum(r[1] for r in review_ratings)

    # max(..., 0) stops the loop before the count reaches zero and the mean is undefined
    while remaining_count > max(target_review_count, 0):
        if remaining_sum / remaining_count >= target_avg_rating:
            pool = rated_higher_than_target  # mean too high: remove a high rating
        else:
            pool = rated_lower_than_target  # mean too low: remove a low rating
        if not pool:
            # Not expected (a mean at or above the target implies some rating at or
            # above it, and vice versa); kept so rounding can never cause an IndexError
            break
        dropped = pool.pop()
        ids_to_drop.append(dropped[0])
        remaining_sum -= dropped[1]
        remaining_count -= 1
    return ids_to_drop
129
+
130
+
131
def drop_reviews(review_ids: list[int]):
    """
    Deletes the reviews with the given ids from the destination database.

    Prints a warning and does nothing when the list is empty.

    :param review_ids: ids of the reviews to delete
    """
    if len(review_ids) == 0:
        print("Warning - got zero reviews to drop")
        return

    ids_in = f'({", ".join([str(r) for r in review_ids])})'
    sql = f'delete from reviews where id in {ids_in}'
    # The statement was previously built but never run, so no review was ever deleted
    execute_sqls([sql])
137
+
138
+
139
def drop_reviews_to_balance_avg_rating(min_review_count: int = 5,
                                       max_review_count: int = 30,
                                       min_target_avg_rating: float = 3.4,
                                       max_target_avg_rating: float = 4.9):
    """
    For every product, draws a random target review count and a random target average
    rating within the given bounds, then drops the reviews selected by
    get_review_ids_to_drop for those targets.

    :param min_review_count: smallest target review count (inclusive)
    :param max_review_count: largest target review count (inclusive)
    :param min_target_avg_rating: lower bound of the target average rating
    :param max_target_avg_rating: upper bound of the target average rating
    """
    product_rows = execute_select("select id, name from products")
    for product_row in product_rows:
        prod_id = product_row[0]
        prod_name = product_row[1]

        # Draw this product's targets (the count is drawn first, then the rating)
        wanted_count = randint(min_review_count, max_review_count)
        wanted_avg = min_target_avg_rating + (random() * (max_target_avg_rating - min_target_avg_rating))

        current_reviews = execute_select(f'select id, rating from reviews where product_id = {prod_id}')
        surplus_ids = get_review_ids_to_drop(current_reviews, wanted_count, wanted_avg)
        print(f'Dropping {len(surplus_ids)} reviews for {prod_name} trying to get to average review of ~{wanted_avg:.1f}')
        drop_reviews(surplus_ids)
153
+
154
+
155
if __name__ == "__main__":
    # Step 1: start from a copy of the source database
    copy_all_products_db()

    # Step 2: keep half the product categories - more recognisable ones
    cats = [
        "Dishwashers", "TVs",
        "Washing Machines", "Vacuum Cleaners",
        "Irons", "Electric Kettles",
        "Microwaves", "Food Processors",
        "Coffee Machines", "Toasters",
    ]
    drop_categories(keeping=cats)

    # Step 3: thin out the products in each remaining category
    winnow_prods_per_category()

    # Step 4: trim reviews per product towards a random target average rating
    drop_reviews_to_balance_avg_rating()