reab5555 committed
Commit ffaac50
1 Parent(s): de7ec56

Update clean.py

Files changed (1)
clean.py +262 -267
clean.py CHANGED
@@ -1,267 +1,262 @@
-import re
-
-from pyspark.sql import SparkSession
-from pyspark.sql.functions import col, isnan, when, count, lower, regexp_replace, to_date, to_timestamp, udf, \
-    levenshtein, array, lit, trim, size, coalesce
-from pyspark.sql.types import DoubleType, IntegerType, StringType, DateType, TimestampType, ArrayType
-from pyspark.sql.utils import AnalysisException
-import time
-from time import perf_counter
-
-# Constants
-EMPTY_THRESHOLD = 0.5
-LOW_COUNT_THRESHOLD = 2
-VALID_DATA_THRESHOLD = 0.5
-
-def print_dataframe_info(df, step=""):
-    num_columns = len(df.columns)
-    num_rows = df.count()
-    num_cells = num_columns * num_rows
-    print(f"{step}Dataframe info:")
-    print(f" Number of columns: {num_columns}")
-    print(f" Number of rows: {num_rows}")
-    print(f" Total number of cells: {num_cells}")
-
-
-def check_and_normalize_column_headers(df):
-    print("Checking and normalizing column headers...")
-
-    for old_name in df.columns:
-        # Create the new name using string manipulation
-        new_name = old_name.lower().replace(' ', '_')
-
-        # Remove any non-alphanumeric characters (excluding underscores)
-        new_name = re.sub(r'[^0-9a-zA-Z_]', '', new_name)
-
-        # Rename the column
-        df = df.withColumnRenamed(old_name, new_name)
-
-    print("Column names have been normalized.")
-    return df
-
-
-def remove_empty_columns(df, threshold=EMPTY_THRESHOLD):
-    print(f"Removing columns with less than {threshold * 100}% valid data...")
-
-    # Calculate the percentage of non-null values for each column
-    df_stats = df.select(
-        [((count(when(col(c).isNotNull(), c)) / count('*')) >= threshold).alias(c) for c in df.columns])
-    valid_columns = [c for c in df_stats.columns if df_stats.select(c).first()[0]]
-
-    return df.select(valid_columns)
-
-
-def remove_empty_rows(df, threshold=EMPTY_THRESHOLD):
-    print(f"Removing rows with less than {threshold * 100}% valid data...")
-
-    # Count the number of non-null values for each row
-    expr = sum([when(col(c).isNotNull(), lit(1)).otherwise(lit(0)) for c in df.columns])
-    df_valid_count = df.withColumn('valid_count', expr)
-
-    # Filter rows based on the threshold
-    total_columns = len(df.columns)
-    df_filtered = df_valid_count.filter(col('valid_count') >= threshold * total_columns)
-
-    print('count of valid rows:', df_filtered.count())
-
-    return df_filtered.drop('valid_count')
-
-
-def drop_rows_with_nas(df, threshold=VALID_DATA_THRESHOLD):
-    print(f"Dropping rows with NAs for columns with more than {threshold * 100}% valid data...")
-
-    # Calculate the percentage of non-null values for each column
-    df_stats = df.select([((count(when(col(c).isNotNull(), c)) / count('*'))).alias(c) for c in df.columns])
-
-    # Get columns with more than threshold valid data
-    valid_columns = [c for c in df_stats.columns if df_stats.select(c).first()[0] > threshold]
-
-    # Drop rows with NAs only for the valid columns
-    for column in valid_columns:
-        df = df.filter(col(column).isNotNull())
-
-    return df
-
-def check_typos(df, column_name, threshold=2, top_n=100):
-    # Check if the column is of StringType
-    if not isinstance(df.schema[column_name].dataType, StringType):
-        print(f"Skipping typo check for column {column_name} as it is not a string type.")
-        return None
-
-    print(f"Checking for typos in column: {column_name}")
-
-    try:
-        # Get value counts for the specific column
-        value_counts = df.groupBy(column_name).count().orderBy("count", ascending=False)
-
-        # Take top N most frequent values
-        top_values = [row[column_name] for row in value_counts.limit(top_n).collect()]
-
-        # Broadcast the top values to all nodes
-        broadcast_top_values = df.sparkSession.sparkContext.broadcast(top_values)
-
-        # Define UDF to find similar strings
-        @udf(returnType=ArrayType(StringType()))
-        def find_similar_strings(value):
-            if value is None:
-                return []
-            similar = []
-            for top_value in broadcast_top_values.value:
-                if value != top_value and levenshtein(value, top_value) <= threshold:
-                    similar.append(top_value)
-            return similar
-
-        # Apply the UDF to the column
-        df_with_typos = df.withColumn("possible_typos", find_similar_strings(col(column_name)))
-
-        # Filter rows with possible typos and select only the relevant columns
-        typos_df = df_with_typos.filter(size("possible_typos") > 0).select(column_name, "possible_typos")
-
-        # Check if there are any potential typos
-        typo_count = typos_df.count()
-        if typo_count > 0:
-            print(f"Potential typos found in column {column_name}: {typo_count}")
-            typos_df.show(10, truncate=False)
-            return typos_df
-        else:
-            print(f"No potential typos found in column {column_name}")
-            return None
-
-    except AnalysisException as e:
-        print(f"Error analyzing column {column_name}: {str(e)}")
-        return None
-    except Exception as e:
-        print(f"Unexpected error in check_typos for column {column_name}: {str(e)}")
-        return None
-
-
-def transform_string_column(df, column_name):
-    print(f"Transforming string column: {column_name}")
-    # Lower case transformation (if applicable)
-    df = df.withColumn(column_name, lower(col(column_name)))
-    # Remove leading and trailing spaces
-    df = df.withColumn(column_name, trim(col(column_name)))
-    # Replace multiple spaces with a single space
-    df = df.withColumn(column_name, regexp_replace(col(column_name), "\\s+", " "))
-    # Remove special characters except those used in dates and times
-    df = df.withColumn(column_name, regexp_replace(col(column_name), "[^a-zA-Z0-9\\s/:.-]", ""))
-    return df
-
-
-def clean_column(df, column_name):
-    print(f"Cleaning column: {column_name}")
-    start_time = perf_counter()
-    # Get the data type of the current column
-    column_type = df.schema[column_name].dataType
-
-    if isinstance(column_type, StringType):
-        typos_df = check_typos(df, column_name)
-        if typos_df is not None and typos_df.count() > 0:
-            print(f"Detailed typos for column {column_name}:")
-            typos_df.show(truncate=False)
-        df = transform_string_column(df, column_name)
-
-    elif isinstance(column_type, (DoubleType, IntegerType)):
-        # For numeric columns, we'll do a simple null check
-        df = df.withColumn(column_name, when(col(column_name).isNull(), lit(None)).otherwise(col(column_name)))
-
-    end_time = perf_counter()
-    print(f"Time taken to clean {column_name}: {end_time - start_time:.6f} seconds")
-    return df
-
-# Update the remove_outliers function to work on a single column
-def remove_outliers(df, column):
-    print(f"Removing outliers from column: {column}")
-
-    stats = df.select(column).summary("25%", "75%").collect()
-    q1 = float(stats[0][1])
-    q3 = float(stats[1][1])
-    iqr = q3 - q1
-    lower_bound = q1 - 1.5 * iqr
-    upper_bound = q3 + 1.5 * iqr
-    df = df.filter((col(column) >= lower_bound) & (col(column) <= upper_bound))
-
-    return df
-
-def calculate_nonconforming_cells(df):
-    nonconforming_cells = {}
-    for column in df.columns:
-        nonconforming_count = df.filter(col(column).isNull() | isnan(column)).count()
-        nonconforming_cells[column] = nonconforming_count
-    return nonconforming_cells
-
-def get_numeric_columns(df):
-    return [field.name for field in df.schema.fields if isinstance(field.dataType, (IntegerType, DoubleType))]
-
-def remove_duplicates_from_primary_key(df, primary_key_column):
-    print(f"Removing duplicates based on primary key column: {primary_key_column}")
-    return df.dropDuplicates([primary_key_column])
-
-def clean_data(spark, df, primary_key_column, progress):
-    start_time = time.time()
-    process_times = {}
-
-    print("Starting data validation and cleaning...")
-    print_dataframe_info(df, "Initial - ")
-
-    # Calculate nonconforming cells before cleaning
-    nonconforming_cells_before = calculate_nonconforming_cells(df)
-
-    # Step 1: Normalize column headers
-    progress(0.1, desc="Normalizing column headers")
-    step_start_time = time.time()
-    df = check_and_normalize_column_headers(df)
-    process_times['Normalize headers'] = time.time() - step_start_time
-
-    # Step 2: Remove empty columns
-    progress(0.2, desc="Removing empty columns")
-    step_start_time = time.time()
-    df = remove_empty_columns(df)
-    print('2) count of valid rows:', df.count())
-    process_times['Remove empty columns'] = time.time() - step_start_time
-
-    # Step 3: Remove empty rows
-    progress(0.3, desc="Removing empty rows")
-    step_start_time = time.time()
-    df = remove_empty_rows(df)
-    print('3) count of valid rows:', df.count())
-    process_times['Remove empty rows'] = time.time() - step_start_time
-
-    # Step 4: Drop rows with NAs for columns with more than 50% valid data
-    progress(0.4, desc="Dropping rows with NAs")
-    step_start_time = time.time()
-    df = drop_rows_with_nas(df)
-    print('4) count of valid rows:', df.count())
-    process_times['Drop rows with NAs'] = time.time() - step_start_time
-
-    # Step 5: Clean columns (including typo checking and string transformation)
-    column_cleaning_times = {}
-    total_columns = len(df.columns)
-    for index, column in enumerate(df.columns):
-        progress(0.5 + (0.2 * (index / total_columns)), desc=f"Cleaning column: {column}")
-        column_start_time = time.time()
-        df = clean_column(df, column)
-        print('5) count of valid rows:', df.count())
-        column_cleaning_times[f"Clean column: {column}"] = time.time() - column_start_time
-    process_times.update(column_cleaning_times)
-
-    # Step 6: Remove outliers from numeric columns (excluding primary key)
-    progress(0.7, desc="Removing outliers")
-    step_start_time = time.time()
-    numeric_columns = get_numeric_columns(df)
-    numeric_columns = [col for col in numeric_columns if col != primary_key_column]
-    for column in numeric_columns:
-        df = remove_outliers(df, column)
-    print('6) count of valid rows:', df.count())
-    process_times['Remove outliers'] = time.time() - step_start_time
-
-    # Step 7: Remove duplicates from primary key column
-    progress(0.8, desc="Removing duplicates from primary key")
-    step_start_time = time.time()
-    df = remove_duplicates_from_primary_key(df, primary_key_column)
-    print('7) count of valid rows:', df.count())
-
-    print("Cleaning process completed.")
-    print_dataframe_info(df, "Final - ")
-
-    return df, nonconforming_cells_before, process_times
 
+import re
+
+import time
+from time import perf_counter
+
+# Constants
+EMPTY_THRESHOLD = 0.5
+LOW_COUNT_THRESHOLD = 2
+VALID_DATA_THRESHOLD = 0.5
+
+def print_dataframe_info(df, step=""):
+    num_columns = len(df.columns)
+    num_rows = df.count()
+    num_cells = num_columns * num_rows
+    print(f"{step}Dataframe info:")
+    print(f" Number of columns: {num_columns}")
+    print(f" Number of rows: {num_rows}")
+    print(f" Total number of cells: {num_cells}")
+
+
+def check_and_normalize_column_headers(df):
+    print("Checking and normalizing column headers...")
+
+    for old_name in df.columns:
+        # Create the new name using string manipulation
+        new_name = old_name.lower().replace(' ', '_')
+
+        # Remove any non-alphanumeric characters (excluding underscores)
+        new_name = re.sub(r'[^0-9a-zA-Z_]', '', new_name)
+
+        # Rename the column
+        df = df.withColumnRenamed(old_name, new_name)
+
+    print("Column names have been normalized.")
+    return df
+
+
+def remove_empty_columns(df, threshold=EMPTY_THRESHOLD):
+    print(f"Removing columns with less than {threshold * 100}% valid data...")
+
+    # Calculate the percentage of non-null values for each column
+    df_stats = df.select(
+        [((count(when(col(c).isNotNull(), c)) / count('*')) >= threshold).alias(c) for c in df.columns])
+    valid_columns = [c for c in df_stats.columns if df_stats.select(c).first()[0]]
+
+    return df.select(valid_columns)
+
+
+def remove_empty_rows(df, threshold=EMPTY_THRESHOLD):
+    print(f"Removing rows with less than {threshold * 100}% valid data...")
+
+    # Count the number of non-null values for each row
+    expr = sum([when(col(c).isNotNull(), lit(1)).otherwise(lit(0)) for c in df.columns])
+    df_valid_count = df.withColumn('valid_count', expr)
+
+    # Filter rows based on the threshold
+    total_columns = len(df.columns)
+    df_filtered = df_valid_count.filter(col('valid_count') >= threshold * total_columns)
+
+    print('count of valid rows:', df_filtered.count())
+
+    return df_filtered.drop('valid_count')
+
+
+def drop_rows_with_nas(df, threshold=VALID_DATA_THRESHOLD):
+    print(f"Dropping rows with NAs for columns with more than {threshold * 100}% valid data...")
+
+    # Calculate the percentage of non-null values for each column
+    df_stats = df.select([((count(when(col(c).isNotNull(), c)) / count('*'))).alias(c) for c in df.columns])
+
+    # Get columns with more than threshold valid data
+    valid_columns = [c for c in df_stats.columns if df_stats.select(c).first()[0] > threshold]
+
+    # Drop rows with NAs only for the valid columns
+    for column in valid_columns:
+        df = df.filter(col(column).isNotNull())
+
+    return df
+
+def check_typos(df, column_name, threshold=2, top_n=100):
+    # Check if the column is of StringType
+    if not isinstance(df.schema[column_name].dataType, StringType):
+        print(f"Skipping typo check for column {column_name} as it is not a string type.")
+        return None
+
+    print(f"Checking for typos in column: {column_name}")
+
+    try:
+        # Get value counts for the specific column
+        value_counts = df.groupBy(column_name).count().orderBy("count", ascending=False)
+
+        # Take top N most frequent values
+        top_values = [row[column_name] for row in value_counts.limit(top_n).collect()]
+
+        # Broadcast the top values to all nodes
+        broadcast_top_values = df.sparkSession.sparkContext.broadcast(top_values)
+
+        # Define UDF to find similar strings
+        @udf(returnType=ArrayType(StringType()))
+        def find_similar_strings(value):
+            if value is None:
+                return []
+            similar = []
+            for top_value in broadcast_top_values.value:
+                if value != top_value and levenshtein(value, top_value) <= threshold:
+                    similar.append(top_value)
+            return similar
+
+        # Apply the UDF to the column
+        df_with_typos = df.withColumn("possible_typos", find_similar_strings(col(column_name)))
+
+        # Filter rows with possible typos and select only the relevant columns
+        typos_df = df_with_typos.filter(size("possible_typos") > 0).select(column_name, "possible_typos")
+
+        # Check if there are any potential typos
+        typo_count = typos_df.count()
+        if typo_count > 0:
+            print(f"Potential typos found in column {column_name}: {typo_count}")
+            typos_df.show(10, truncate=False)
+            return typos_df
+        else:
+            print(f"No potential typos found in column {column_name}")
+            return None
+
+    except AnalysisException as e:
+        print(f"Error analyzing column {column_name}: {str(e)}")
+        return None
+    except Exception as e:
+        print(f"Unexpected error in check_typos for column {column_name}: {str(e)}")
+        return None
+
+
+def transform_string_column(df, column_name):
+    print(f"Transforming string column: {column_name}")
+    # Lower case transformation (if applicable)
+    df = df.withColumn(column_name, lower(col(column_name)))
+    # Remove leading and trailing spaces
+    df = df.withColumn(column_name, trim(col(column_name)))
+    # Replace multiple spaces with a single space
+    df = df.withColumn(column_name, regexp_replace(col(column_name), "\\s+", " "))
+    # Remove special characters except those used in dates and times
+    df = df.withColumn(column_name, regexp_replace(col(column_name), "[^a-zA-Z0-9\\s/:.-]", ""))
+    return df
+
+
+def clean_column(df, column_name):
+    print(f"Cleaning column: {column_name}")
+    start_time = perf_counter()
+    # Get the data type of the current column
+    column_type = df.schema[column_name].dataType
+
+    if isinstance(column_type, StringType):
+        typos_df = check_typos(df, column_name)
+        if typos_df is not None and typos_df.count() > 0:
+            print(f"Detailed typos for column {column_name}:")
+            typos_df.show(truncate=False)
+        df = transform_string_column(df, column_name)
+
+    elif isinstance(column_type, (DoubleType, IntegerType)):
+        # For numeric columns, we'll do a simple null check
+        df = df.withColumn(column_name, when(col(column_name).isNull(), lit(None)).otherwise(col(column_name)))
+
+    end_time = perf_counter()
+    print(f"Time taken to clean {column_name}: {end_time - start_time:.6f} seconds")
+    return df
+
+# Update the remove_outliers function to work on a single column
+def remove_outliers(df, column):
+    print(f"Removing outliers from column: {column}")
+
+    stats = df.select(column).summary("25%", "75%").collect()
+    q1 = float(stats[0][1])
+    q3 = float(stats[1][1])
+    iqr = q3 - q1
+    lower_bound = q1 - 1.5 * iqr
+    upper_bound = q3 + 1.5 * iqr
+    df = df.filter((col(column) >= lower_bound) & (col(column) <= upper_bound))
+
+    return df
+
+def calculate_nonconforming_cells(df):
+    nonconforming_cells = {}
+    for column in df.columns:
+        nonconforming_count = df.filter(col(column).isNull() | isnan(column)).count()
+        nonconforming_cells[column] = nonconforming_count
+    return nonconforming_cells
+
+def get_numeric_columns(df):
+    return [field.name for field in df.schema.fields if isinstance(field.dataType, (IntegerType, DoubleType))]
+
+def remove_duplicates_from_primary_key(df, primary_key_column):
+    print(f"Removing duplicates based on primary key column: {primary_key_column}")
+    return df.dropDuplicates([primary_key_column])
+
+def clean_data(spark, df, primary_key_column, progress):
+    start_time = time.time()
+    process_times = {}
+
+    print("Starting data validation and cleaning...")
+    print_dataframe_info(df, "Initial - ")
+
+    # Calculate nonconforming cells before cleaning
+    nonconforming_cells_before = calculate_nonconforming_cells(df)
+
+    # Step 1: Normalize column headers
+    progress(0.1, desc="Normalizing column headers")
+    step_start_time = time.time()
+    df = check_and_normalize_column_headers(df)
+    process_times['Normalize headers'] = time.time() - step_start_time
+
+    # Step 2: Remove empty columns
+    progress(0.2, desc="Removing empty columns")
+    step_start_time = time.time()
+    df = remove_empty_columns(df)
+    print('2) count of valid rows:', df.count())
+    process_times['Remove empty columns'] = time.time() - step_start_time
+
+    # Step 3: Remove empty rows
+    progress(0.3, desc="Removing empty rows")
+    step_start_time = time.time()
+    df = remove_empty_rows(df)
+    print('3) count of valid rows:', df.count())
+    process_times['Remove empty rows'] = time.time() - step_start_time
+
+    # Step 4: Drop rows with NAs for columns with more than 50% valid data
+    progress(0.4, desc="Dropping rows with NAs")
+    step_start_time = time.time()
+    df = drop_rows_with_nas(df)
+    print('4) count of valid rows:', df.count())
+    process_times['Drop rows with NAs'] = time.time() - step_start_time
+
+    # Step 5: Clean columns (including typo checking and string transformation)
+    column_cleaning_times = {}
+    total_columns = len(df.columns)
+    for index, column in enumerate(df.columns):
+        progress(0.5 + (0.2 * (index / total_columns)), desc=f"Cleaning column: {column}")
+        column_start_time = time.time()
+        df = clean_column(df, column)
+        print('5) count of valid rows:', df.count())
+        column_cleaning_times[f"Clean column: {column}"] = time.time() - column_start_time
+    process_times.update(column_cleaning_times)
+
+    # Step 6: Remove outliers from numeric columns (excluding primary key)
+    progress(0.7, desc="Removing outliers")
+    step_start_time = time.time()
+    numeric_columns = get_numeric_columns(df)
+    numeric_columns = [col for col in numeric_columns if col != primary_key_column]
+    for column in numeric_columns:
+        df = remove_outliers(df, column)
+    print('6) count of valid rows:', df.count())
+    process_times['Remove outliers'] = time.time() - step_start_time
+
+    # Step 7: Remove duplicates from primary key column
+    progress(0.8, desc="Removing duplicates from primary key")
+    step_start_time = time.time()
+    df = remove_duplicates_from_primary_key(df, primary_key_column)
+    print('7) count of valid rows:', df.count())
+
+    print("Cleaning process completed.")
+    print_dataframe_info(df, "Final - ")
+
+    return df, nonconforming_cells_before, process_times
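
For reference, here is a minimal usage sketch (not part of the commit) showing how clean_data might be driven. The SparkSession setup, the input path "input.csv", the primary-key column name "id", and the print-based progress callback are all assumptions; the callback simply mirrors the progress(fraction, desc=...) calls made inside clean_data.

# Hypothetical driver for clean.py -- illustrative only.
from pyspark.sql import SparkSession

from clean import clean_data  # assumes clean.py is importable as a module named `clean`

spark = SparkSession.builder.appName("clean-demo").getOrCreate()
df = spark.read.csv("input.csv", header=True, inferSchema=True)  # assumed input file

def progress(fraction, desc=""):
    # Stand-in for a UI progress callback (e.g. a Gradio progress object).
    print(f"[{fraction:.0%}] {desc}")

cleaned_df, nonconforming_before, process_times = clean_data(
    spark, df, primary_key_column="id", progress=progress  # "id" is a placeholder key column
)
cleaned_df.show(5)
print(process_times)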