DeepMostInnovations committed on
Commit ab29941 · verified · 1 Parent(s): b8e56d4

Create clean_dataset.py

Files changed (1)
  1. clean_dataset.py +447 -0
clean_dataset.py ADDED
@@ -0,0 +1,447 @@
+ import pandas as pd
+ import numpy as np
+ import json
+ import os
+ import argparse
+ import logging
+ from tqdm import tqdm
+ import chardet
+ import csv
+
+ # Configure logging
+ logging.basicConfig(
+     level=logging.INFO,
+     format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
+     handlers=[
+         logging.FileHandler("dataset_cleaner.log"),
+         logging.StreamHandler()
+     ]
+ )
+
+ logger = logging.getLogger(__name__)
+
+ class SaaSDatasetCleaner:
+     """
+     Class for cleaning and validating the SaaS sales conversation dataset.
+     Handles issues resulting from interrupted generations.
+     """
+
+     def __init__(self, input_file, output_file=None, chunk_size=1000, encoding='utf-8', skip_encoding_check=False):
+         """
+         Initialize the cleaner.
+
+         Args:
+             input_file: Path to the input CSV file
+             output_file: Path to save the cleaned dataset (defaults to 'cleaned_' + input_file)
+             chunk_size: Number of rows to process at once
+             encoding: File encoding (defaults to utf-8)
+             skip_encoding_check: Whether to skip encoding detection and line-by-line processing
+         """
+         self.input_file = input_file
+         self.output_file = output_file or f"cleaned_{os.path.basename(input_file)}"
+         self.chunk_size = chunk_size
+         self.encoding = encoding
+         self.skip_encoding_check = skip_encoding_check
+         self.stats = {
+             'total_rows': 0,
+             'valid_rows': 0,
+             'invalid_json': 0,
+             'missing_values': 0,
+             'invalid_embeddings': 0,
+             'duplicates': 0,
+             'encoding_errors': 0,
+             'recovered_rows': 0
+         }
+
+         # If no encoding was supplied and checks are not skipped, detect the encoding
+         if not self.skip_encoding_check and not self.encoding:
+             self.detect_encoding()
+
+         # Get the columns and prepare for processing
+         self.initialize_columns()
+
+     def detect_encoding(self):
+         """Detect the file encoding."""
+         logger.info("Detecting file encoding...")
+
+         # Read a sample of the file to detect encoding
+         with open(self.input_file, 'rb') as f:
+             sample = f.read(min(10000000, os.path.getsize(self.input_file)))  # Read up to 10MB
+
+         result = chardet.detect(sample)
+         self.encoding = result['encoding']
+         confidence = result['confidence']
+
+         logger.info(f"Detected encoding: {self.encoding} with confidence: {confidence:.2f}")
+
+         # If confidence is low, try common encodings
+         if confidence < 0.7:
+             logger.warning("Low confidence in encoding detection. Will try multiple encodings.")
+             self.encoding = None  # Will try multiple encodings later
+
+     def initialize_columns(self):
+         """Initialize column information."""
+         # Try to read the header with different encodings if needed;
+         # if an encoding was supplied (or checks are skipped), use it directly
+         encodings_to_try = [self.encoding or 'utf-8'] if (self.skip_encoding_check or self.encoding) else ['utf-8', 'latin1', 'iso-8859-1', 'cp1252']
+
+         for enc in encodings_to_try:
+             try:
+                 # Try to read just the header
+                 with open(self.input_file, 'r', encoding=enc, errors='replace') as f:
+                     reader = csv.reader(f)
+                     self.columns = next(reader)
+
+                 self.encoding = enc
+                 logger.info(f"Successfully read header with encoding: {enc}")
+
+                 # Identify embedding columns
+                 self.embedding_cols = [col for col in self.columns if col.startswith('embedding_')]
+                 logger.info(f"Found {len(self.embedding_cols)} embedding columns")
+
+                 return
+
+             except Exception as e:
+                 logger.warning(f"Failed to read header with encoding {enc}: {str(e)}")
+
+         # If we get here, all encodings failed
+         logger.error("Could not read column headers with any encoding")
+         self.columns = []
+         self.embedding_cols = []
+
+     def process_line_by_line(self):
+         """Process the file line by line to handle encoding issues."""
+         logger.info("Processing file line by line to handle encoding issues...")
+
+         # Open the output file
+         with open(self.output_file, 'w', encoding='utf-8', newline='') as out_file:
+             writer = None  # Will be initialized after reading the header
+
+             # Process the input file
+             with open(self.input_file, 'rb') as in_file:
+                 # Process line by line
+                 line_count = 0
+                 valid_count = 0
+
+                 for line in tqdm(in_file, desc="Reading lines"):
+                     line_count += 1
+
+                     # Try to decode with multiple encodings
+                     decoded_line = None
+                     for enc in ['utf-8', 'latin1', 'iso-8859-1', 'cp1252']:
+                         try:
+                             decoded_line = line.decode(enc)
+                             break
+                         except UnicodeDecodeError:
+                             continue
+
+                     if decoded_line is None:
+                         # Could not decode with any encoding, skip the line
+                         self.stats['encoding_errors'] += 1
+                         continue
+
+                     # Parse the CSV line
+                     try:
+                         reader = csv.reader([decoded_line])
+                         row = next(reader)
+
+                         # Initialize the writer with headers if this is the first line
+                         if line_count == 1:
+                             writer = csv.writer(out_file)
+                             writer.writerow(row)  # Write headers
+                             continue
+
+                         # Basic validation - check the number of columns
+                         if len(row) != len(self.columns):
+                             logger.debug(f"Line {line_count}: Column count mismatch. Expected {len(self.columns)}, got {len(row)}")
+                             continue
+
+                         # Write the row
+                         writer.writerow(row)
+                         valid_count += 1
+
+                     except Exception as e:
+                         logger.debug(f"Error processing line {line_count}: {str(e)}")
+                         self.stats['encoding_errors'] += 1
+
+         self.stats['total_rows'] = line_count - 1  # Subtract the header row
+         self.stats['recovered_rows'] = valid_count
+
+         logger.info(f"Processed {line_count} lines, recovered {valid_count} valid rows")
+         logger.info(f"Found {self.stats['encoding_errors']} lines with encoding errors")
+
+     def _validate_json_fields(self, df):
+         """Validate and clean JSON fields."""
+         # List of columns that should contain JSON
+         json_columns = ['scenario', 'conversation', 'probability_trajectory']
+
+         for col in json_columns:
+             if col not in df.columns:
+                 continue
+
+             # Create a valid indicator
+             df[f'{col}_valid'] = True
+
+             # Check each value (iterate over index labels so .at is correct for chunked reads)
+             for idx, value in zip(df.index, df[col]):
+                 try:
+                     if pd.isna(value):
+                         df.at[idx, f'{col}_valid'] = False
+                         self.stats['invalid_json'] += 1
+                         continue
+
+                     # Attempt to parse JSON
+                     json.loads(value)
+                 except (TypeError, ValueError):
+                     df.at[idx, f'{col}_valid'] = False
+                     self.stats['invalid_json'] += 1
+
+         # Create an overall valid flag
+         valid_flags = [f'{col}_valid' for col in json_columns if f'{col}_valid' in df.columns]
+         if valid_flags:
+             df['json_valid'] = df[valid_flags].all(axis=1)
+         else:
+             df['json_valid'] = True
+
+         # Clean up the temporary columns
+         for col in json_columns:
+             if f'{col}_valid' in df.columns:
+                 df = df.drop(columns=[f'{col}_valid'])
+
+         return df
+
+     def _validate_embeddings(self, df):
+         """Check if embeddings are valid."""
+         if not self.embedding_cols:
+             return df
+
+         # Check whether the first embedding column has a value as a simple sanity check
+         if 'embedding_0' in df.columns:
+             df['embeddings_valid'] = ~df['embedding_0'].isna()
+         else:
+             df['embeddings_valid'] = True
+
+         # Count invalid embeddings
+         self.stats['invalid_embeddings'] += (~df['embeddings_valid']).sum()
+
+         return df
+
+     def _check_missing_values(self, df):
+         """Check for missing values in important columns."""
+         important_cols = [
+             'company_id', 'company_name', 'product_name', 'conversation_id',
+             'conversation', 'full_text', 'outcome'
+         ]
+
+         # Filter to columns that actually exist
+         important_cols = [col for col in important_cols if col in df.columns]
+
+         if not important_cols:
+             df['missing_important'] = False
+             return df
+
+         # Create a flag for rows with missing important values
+         missing_flags = df[important_cols].isna().any(axis=1)
+         df['missing_important'] = missing_flags
+
+         # Count missing values
+         self.stats['missing_values'] += missing_flags.sum()
+
+         return df
+
+     def _flag_valid_rows(self, df):
+         """Create a single flag for valid rows."""
+         # A row is valid if it has valid JSON, valid embeddings, and no missing important values
+         row_valid = pd.Series(True, index=df.index)
+
+         if 'json_valid' in df.columns:
+             row_valid &= df['json_valid']
+
+         if 'embeddings_valid' in df.columns:
+             row_valid &= df['embeddings_valid']
+
+         if 'missing_important' in df.columns:
+             row_valid &= ~df['missing_important']
+
+         df['row_valid'] = row_valid
+
+         # Update the valid rows count
+         self.stats['valid_rows'] += df['row_valid'].sum()
+
+         return df
+
+     def _remove_duplicates(self, df):
+         """Flag duplicate conversation IDs within the chunk (flagged rows are dropped later)."""
+         if 'conversation_id' in df.columns:
+             # Check for duplicates
+             dup_mask = df.duplicated(subset=['conversation_id'], keep='first')
+             df['is_duplicate'] = dup_mask
+
+             # Count duplicates
+             self.stats['duplicates'] += dup_mask.sum()
+         else:
+             df['is_duplicate'] = False
+
+         return df
+
+     def clean_dataset(self):
+         """
+         Clean the dataset by first fixing encoding issues, then validating the data.
+         """
+         logger.info(f"Starting to clean dataset: {self.input_file}")
+
+         # Check that the input file exists
+         if not os.path.exists(self.input_file):
+             logger.error(f"Input file not found: {self.input_file}")
+             return
+
+         # If we're not skipping encoding checks, process line by line first
+         if not self.skip_encoding_check:
+             self.process_line_by_line()
+             intermediate_file = self.output_file
+             self.output_file = f"validated_{os.path.basename(self.input_file)}"
+         else:
+             logger.info("Skipping encoding check as requested")
+             # Use the input file directly as the intermediate file
+             intermediate_file = self.input_file
+
+             # Count rows in the file for progress tracking
+             with open(intermediate_file, 'r', encoding=self.encoding) as f:
+                 self.stats['total_rows'] = sum(1 for _ in f) - 1  # Subtract the header row
+             self.stats['recovered_rows'] = self.stats['total_rows']
+
+         logger.info(f"Total rows to validate: {self.stats['total_rows']}")
+
+         # Now that we have a cleaned file with proper encoding, validate the data
+         logger.info("Beginning data validation on recovered rows...")
+
+         # Total number of rows for progress tracking
+         total_rows = self.stats['recovered_rows']
+
+         # Process the dataset in chunks
+         try:
+             # Create a chunked reader - the intermediate file has a known encoding at this point.
+             # on_bad_lines='skip' skips malformed lines (replaces error_bad_lines=False from older pandas)
+             reader = pd.read_csv(
+                 intermediate_file,
+                 chunksize=self.chunk_size,
+                 encoding='utf-8',
+                 low_memory=False,      # Avoid dtype warnings
+                 on_bad_lines='skip'    # Skip bad lines (pandas >= 1.3)
+             )
+
+             # Header flag for the first chunk
+             first_chunk = True
+
+             # Process each chunk
+             with tqdm(total=total_rows, desc="Validating data") as pbar:
+                 for chunk_num, chunk in enumerate(reader):
+                     logger.debug(f"Processing chunk {chunk_num + 1}")
+
+                     # Run validation steps
+                     chunk = self._validate_json_fields(chunk)
+                     chunk = self._validate_embeddings(chunk)
+                     chunk = self._check_missing_values(chunk)
+                     chunk = self._remove_duplicates(chunk)
+                     chunk = self._flag_valid_rows(chunk)
+
+                     # Filter to valid, non-duplicate rows only
+                     valid_chunk = chunk[chunk['row_valid'] & ~chunk['is_duplicate']]
+
+                     # Remove the validation columns
+                     for col in ['json_valid', 'embeddings_valid', 'missing_important', 'row_valid', 'is_duplicate']:
+                         if col in valid_chunk.columns:
+                             valid_chunk = valid_chunk.drop(columns=[col])
+
+                     # Write the cleaned chunk
+                     valid_chunk.to_csv(
+                         self.output_file,
+                         mode='w' if first_chunk else 'a',
+                         header=first_chunk,
+                         index=False,
+                         encoding='utf-8'
+                     )
+
+                     # Update the first chunk flag
+                     if first_chunk:
+                         first_chunk = False
+
+                     # Update progress
+                     pbar.update(len(chunk))
+
+             logger.info(f"Dataset cleaning complete. Results saved to {self.output_file}")
+
+             # Print statistics
+             logger.info("Cleaning Statistics:")
+             logger.info(f"- Total rows processed: {self.stats['total_rows']}")
+             logger.info(f"- Rows recovered from encoding issues: {self.stats['recovered_rows']}")
+             logger.info(f"- Encoding errors: {self.stats['encoding_errors']}")
+             logger.info(f"- Valid rows after validation: {self.stats['valid_rows']}")
+             logger.info(f"- Rows with invalid JSON: {self.stats['invalid_json']}")
+             logger.info(f"- Rows with missing values: {self.stats['missing_values']}")
+             logger.info(f"- Rows with invalid embeddings: {self.stats['invalid_embeddings']}")
+             logger.info(f"- Duplicate rows: {self.stats['duplicates']}")
+
+             # Create a summary file
+             with open(f"{self.output_file}_summary.txt", 'w') as f:
+                 f.write("Dataset Cleaning Summary\n")
+                 f.write("========================\n\n")
+                 f.write(f"Input file: {self.input_file}\n")
+                 f.write(f"Output file: {self.output_file}\n\n")
+                 f.write(f"Total rows processed: {self.stats['total_rows']}\n")
+                 f.write(f"Rows recovered from encoding issues: {self.stats['recovered_rows']}\n")
+                 f.write(f"Encoding errors: {self.stats['encoding_errors']}\n")
+                 f.write(f"Valid rows after validation: {self.stats['valid_rows']}\n")
+                 f.write(f"Rows with invalid JSON: {self.stats['invalid_json']}\n")
+                 f.write(f"Rows with missing values: {self.stats['missing_values']}\n")
+                 f.write(f"Rows with invalid embeddings: {self.stats['invalid_embeddings']}\n")
+                 f.write(f"Duplicate rows: {self.stats['duplicates']}\n")
+
+             return self.stats
+
+         except Exception as e:
+             logger.error(f"Error validating dataset: {str(e)}")
+             raise
+
+ def main():
+     """Main function to run the dataset cleaner."""
+     parser = argparse.ArgumentParser(description="Clean and validate SaaS sales conversation dataset")
+     parser.add_argument("input_file", type=str, help="Path to the input CSV file")
+     parser.add_argument("--output_file", type=str, default=None,
+                         help="Path to save cleaned dataset (defaults to 'cleaned_' + input_file)")
+     parser.add_argument("--chunk_size", type=int, default=1000,
+                         help="Number of rows to process at once")
+     parser.add_argument("--encoding", type=str, default='utf-8',
+                         help="File encoding (defaults to utf-8)")
+     parser.add_argument("--skip_encoding_check", action="store_true",
+                         help="Skip encoding detection and line-by-line processing")
+
+     args = parser.parse_args()
+
+     # Create and run the cleaner
+     cleaner = SaaSDatasetCleaner(
+         input_file=args.input_file,
+         output_file=args.output_file,
+         chunk_size=args.chunk_size,
+         encoding=args.encoding,
+         skip_encoding_check=args.skip_encoding_check
+     )
+
+     cleaner.clean_dataset()
+
+ if __name__ == "__main__":
+     main()
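
For quick reference, the script is run from the command line as python clean_dataset.py <input_csv>, with optional --output_file, --chunk_size, --encoding, and --skip_encoding_check flags. Below is a minimal programmatic sketch; the file name raw_conversations.csv and the chunk size are placeholders for illustration, not values taken from the commit.

    from clean_dataset import SaaSDatasetCleaner

    # Placeholder path; point this at the actual dataset CSV.
    cleaner = SaaSDatasetCleaner(
        input_file="raw_conversations.csv",
        chunk_size=5000,               # larger chunks mean fewer passes but more memory
        skip_encoding_check=False,     # keep the byte-level line-by-line recovery pass
    )
    stats = cleaner.clean_dataset()    # writes the cleaned CSV and a *_summary.txt, returns the stats dict
    print(stats)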