Create clean_dataset.py
clean_dataset.py +447 -0
clean_dataset.py
ADDED
@@ -0,0 +1,447 @@
import pandas as pd
import numpy as np
import json
import os
import argparse
import logging
from tqdm import tqdm
import chardet
import csv

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler("dataset_cleaner.log"),
        logging.StreamHandler()
    ]
)

logger = logging.getLogger(__name__)

class SaaSDatasetCleaner:
    """
    Class for cleaning and validating the SaaS sales conversation dataset.
    Handles issues resulting from interrupted generations.
    """

    def __init__(self, input_file, output_file=None, chunk_size=1000, encoding='utf-8', skip_encoding_check=False):
        """
        Initialize the cleaner.

        Args:
            input_file: Path to the input CSV file
            output_file: Path to save cleaned dataset (defaults to 'cleaned_' + input_file)
            chunk_size: Number of rows to process at once
            encoding: File encoding (defaults to utf-8)
            skip_encoding_check: Whether to skip encoding detection and line-by-line processing
        """
        self.input_file = input_file
        self.output_file = output_file or f"cleaned_{os.path.basename(input_file)}"
        self.chunk_size = chunk_size
        self.encoding = encoding
        self.skip_encoding_check = skip_encoding_check
        self.stats = {
            'total_rows': 0,
            'valid_rows': 0,
            'invalid_json': 0,
            'missing_values': 0,
            'invalid_embeddings': 0,
            'duplicates': 0,
            'encoding_errors': 0,
            'recovered_rows': 0
        }

        # If not skipping encoding check, detect encoding
        if not self.skip_encoding_check and not self.encoding:
            self.detect_encoding()

        # Get the columns and prepare for processing
        self.initialize_columns()

    def detect_encoding(self):
        """Detect the file encoding."""
        logger.info("Detecting file encoding...")

        # Read a sample of the file to detect encoding
        with open(self.input_file, 'rb') as f:
            sample = f.read(min(10000000, os.path.getsize(self.input_file)))  # Read up to 10MB

        result = chardet.detect(sample)
        self.encoding = result['encoding']
        confidence = result['confidence']

        logger.info(f"Detected encoding: {self.encoding} with confidence: {confidence:.2f}")

        # If confidence is low, try common encodings
        if confidence < 0.7:
            logger.warning(f"Low confidence in encoding detection. Will try multiple encodings.")
            self.encoding = None  # Will try multiple encodings later

    def initialize_columns(self):
        """Initialize column information."""
        # Try to read the header with different encodings if needed
        encodings_to_try = ['utf-8'] if (self.skip_encoding_check or self.encoding) else ['utf-8', 'latin1', 'iso-8859-1', 'cp1252']

        for enc in encodings_to_try:
            try:
                # Try to read just the header
                with open(self.input_file, 'r', encoding=enc, errors='replace') as f:
                    reader = csv.reader(f)
                    self.columns = next(reader)

                self.encoding = enc
                logger.info(f"Successfully read header with encoding: {enc}")

                # Identify embedding columns
                self.embedding_cols = [col for col in self.columns if col.startswith('embedding_')]
                logger.info(f"Found {len(self.embedding_cols)} embedding columns")

                return

            except Exception as e:
                logger.warning(f"Failed to read header with encoding {enc}: {str(e)}")

        # If we get here, all encodings failed
        logger.error("Could not read column headers with any encoding")
        self.columns = []
        self.embedding_cols = []

    def process_line_by_line(self):
        """Process the file line by line to handle encoding issues."""
        logger.info("Processing file line by line to handle encoding issues...")
        # Note: this pass assumes each CSV record fits on a single physical line;
        # quoted fields that contain newlines will fail the column-count check below.

        # Open the output file
        with open(self.output_file, 'w', encoding='utf-8', newline='') as out_file:
            writer = None  # Will initialize after getting headers

            # Process the input file
            with open(self.input_file, 'rb') as in_file:
                # Process line by line
                line_count = 0
                valid_count = 0

                for line in tqdm(in_file, desc="Reading lines"):
                    line_count += 1

                    # Try to decode with multiple encodings
                    decoded_line = None
                    for enc in ['utf-8', 'latin1', 'iso-8859-1', 'cp1252']:
                        try:
                            decoded_line = line.decode(enc)
                            break
                        except UnicodeDecodeError:
                            continue

                    if decoded_line is None:
                        # Could not decode with any encoding, skip line
                        self.stats['encoding_errors'] += 1
                        continue

                    # Parse the CSV line
                    try:
                        reader = csv.reader([decoded_line])
                        row = next(reader)

                        # Initialize writer with headers if this is the first line
                        if line_count == 1:
                            writer = csv.writer(out_file)
                            writer.writerow(row)  # Write headers
                            continue

                        # Basic validation - check number of columns
                        if len(row) != len(self.columns):
                            logger.debug(f"Line {line_count}: Column count mismatch. Expected {len(self.columns)}, got {len(row)}")
                            continue

                        # Write the row (guard against a header line that never parsed)
                        if writer is not None:
                            writer.writerow(row)
                            valid_count += 1

                    except Exception as e:
                        logger.debug(f"Error processing line {line_count}: {str(e)}")
                        self.stats['encoding_errors'] += 1

        self.stats['total_rows'] = line_count - 1  # Subtract header
        self.stats['recovered_rows'] = valid_count

        logger.info(f"Processed {line_count} lines, recovered {valid_count} valid rows")
        logger.info(f"Found {self.stats['encoding_errors']} lines with encoding errors")

    def _validate_json_fields(self, df):
        """Validate and clean JSON fields."""
        # List of columns that should contain JSON
        json_columns = ['scenario', 'conversation', 'probability_trajectory']

        for col in json_columns:
            if col not in df.columns:
                continue

            # Create a valid indicator
            df[f'{col}_valid'] = True

            # Check each value (iterate over index labels so .at works for chunks
            # whose index does not start at 0)
            for idx, value in df[col].items():
                try:
                    if pd.isna(value):
                        df.at[idx, f'{col}_valid'] = False
                        self.stats['invalid_json'] += 1
                        continue

                    # Attempt to parse JSON
                    json.loads(value)
                except (TypeError, ValueError):
                    df.at[idx, f'{col}_valid'] = False
                    self.stats['invalid_json'] += 1

        # Create an overall valid flag
        valid_flags = [f'{col}_valid' for col in json_columns if f'{col}_valid' in df.columns]
        if valid_flags:
            df['json_valid'] = df[valid_flags].all(axis=1)
        else:
            df['json_valid'] = True

        # Clean up the temporary columns
        for col in json_columns:
            if f'{col}_valid' in df.columns:
                df = df.drop(columns=[f'{col}_valid'])

        return df

    def _validate_embeddings(self, df):
        """Check if embeddings are valid."""
        if not self.embedding_cols:
            return df

        # Check if the first embedding column has a value as a simple check
        if 'embedding_0' in df.columns:
            df['embeddings_valid'] = ~df['embedding_0'].isna()
        else:
            df['embeddings_valid'] = True

        # Count invalid embeddings
        self.stats['invalid_embeddings'] += (~df['embeddings_valid']).sum()

        return df

    def _check_missing_values(self, df):
        """Check for missing values in important columns."""
        important_cols = [
            'company_id', 'company_name', 'product_name', 'conversation_id',
            'conversation', 'full_text', 'outcome'
        ]

        # Filter to columns that actually exist
        important_cols = [col for col in important_cols if col in df.columns]

        if not important_cols:
            df['missing_important'] = False
            return df

        # Create a flag for rows with missing important values
        missing_flags = df[important_cols].isna().any(axis=1)
        df['missing_important'] = missing_flags

        # Count missing values
        self.stats['missing_values'] += missing_flags.sum()

        return df

    def _flag_valid_rows(self, df):
        """Create a single flag for valid rows."""
        # A row is valid if it has valid JSON, valid embeddings, and no missing important values
        df['row_valid'] = True

        if 'json_valid' in df.columns:
            df['row_valid'] &= df['json_valid']

        if 'embeddings_valid' in df.columns:
            df['row_valid'] &= df['embeddings_valid']

        if 'missing_important' in df.columns:
            df['row_valid'] &= ~df['missing_important']

        # Update valid rows count
        self.stats['valid_rows'] += df['row_valid'].sum()

        return df

    def _remove_duplicates(self, df):
        """Remove duplicate conversation IDs."""
        if 'conversation_id' in df.columns:
            # Check for duplicates (note: only duplicates within the current chunk are detected)
            dup_mask = df.duplicated(subset=['conversation_id'], keep='first')
            df['is_duplicate'] = dup_mask

            # Count duplicates
            self.stats['duplicates'] += dup_mask.sum()
        else:
            df['is_duplicate'] = False

        return df

    def clean_dataset(self):
        """
        Clean the dataset by first fixing encoding issues, then cleaning the data.
        """
        logger.info(f"Starting to clean dataset: {self.input_file}")

        # Check if the file exists
        if not os.path.exists(self.input_file):
            logger.error(f"Input file not found: {self.input_file}")
            return

        # If we're not skipping encoding checks, process line by line
        if not self.skip_encoding_check:
            self.process_line_by_line()
            intermediate_file = self.output_file
            self.output_file = f"validated_{os.path.basename(self.input_file)}"
        else:
            logger.info("Skipping encoding check as requested")
            # Use the input file directly as the intermediate file
            intermediate_file = self.input_file

            # Count rows in the file for progress tracking
            with open(intermediate_file, 'r', encoding=self.encoding, errors='replace') as f:
                self.stats['total_rows'] = sum(1 for _ in f) - 1  # Subtract header
            self.stats['recovered_rows'] = self.stats['total_rows']

        logger.info(f"Total rows to validate: {self.stats['total_rows']}")

        # Now that we have a cleaned file with proper encoding, process it for data validation
        logger.info("Beginning data validation on recovered rows...")

        # Total number of rows for progress tracking
        total_rows = self.stats['recovered_rows']

        # Process the dataset in chunks
        try:
            # Create a reader - the intermediate file is UTF-8 when it came from the recovery pass
            reader = pd.read_csv(
                intermediate_file,
                chunksize=self.chunk_size,
                encoding=self.encoding if self.skip_encoding_check else 'utf-8',
                low_memory=False,     # Avoid dtype warnings
                on_bad_lines='skip'   # Skip malformed lines (error_bad_lines=False before pandas 1.3)
            )

            # Create a header flag for the first chunk
            first_chunk = True

            # Process each chunk
            with tqdm(total=total_rows, desc="Validating data") as pbar:
                for chunk_num, chunk in enumerate(reader):
                    logger.debug(f"Processing chunk {chunk_num+1}")

                    # Run validation steps
                    chunk = self._validate_json_fields(chunk)
                    chunk = self._validate_embeddings(chunk)
                    chunk = self._check_missing_values(chunk)
                    chunk = self._remove_duplicates(chunk)
                    chunk = self._flag_valid_rows(chunk)

                    # Filter to valid rows only
                    valid_chunk = chunk[chunk['row_valid'] & ~chunk['is_duplicate']]

                    # Remove the validation columns
                    for col in ['json_valid', 'embeddings_valid', 'missing_important', 'row_valid', 'is_duplicate']:
                        if col in valid_chunk.columns:
                            valid_chunk = valid_chunk.drop(columns=[col])

                    # Write the cleaned chunk
                    valid_chunk.to_csv(
                        self.output_file,
                        mode='w' if first_chunk else 'a',
                        header=first_chunk,
                        index=False,
                        encoding='utf-8'
                    )

                    # Update the first chunk flag
                    if first_chunk:
                        first_chunk = False

                    # Update progress
                    pbar.update(len(chunk))

            logger.info(f"Dataset cleaning complete. Results saved to {self.output_file}")

            # Print statistics
            logger.info("Cleaning Statistics:")
            logger.info(f"- Total rows processed: {self.stats['total_rows']}")
            logger.info(f"- Rows recovered from encoding issues: {self.stats['recovered_rows']}")
            logger.info(f"- Encoding errors: {self.stats['encoding_errors']}")
            logger.info(f"- Valid rows after validation: {self.stats['valid_rows']}")
            logger.info(f"- Rows with invalid JSON: {self.stats['invalid_json']}")
            logger.info(f"- Rows with missing values: {self.stats['missing_values']}")
            logger.info(f"- Rows with invalid embeddings: {self.stats['invalid_embeddings']}")
            logger.info(f"- Duplicate rows: {self.stats['duplicates']}")

            # Create a summary file
            with open(f"{self.output_file}_summary.txt", 'w') as f:
                f.write("Dataset Cleaning Summary\n")
                f.write("=======================\n\n")
                f.write(f"Input file: {self.input_file}\n")
                f.write(f"Output file: {self.output_file}\n\n")
                f.write(f"Total rows processed: {self.stats['total_rows']}\n")
                f.write(f"Rows recovered from encoding issues: {self.stats['recovered_rows']}\n")
                f.write(f"Encoding errors: {self.stats['encoding_errors']}\n")
                f.write(f"Valid rows after validation: {self.stats['valid_rows']}\n")
                f.write(f"Rows with invalid JSON: {self.stats['invalid_json']}\n")
                f.write(f"Rows with missing values: {self.stats['missing_values']}\n")
                f.write(f"Rows with invalid embeddings: {self.stats['invalid_embeddings']}\n")
                f.write(f"Duplicate rows: {self.stats['duplicates']}\n")

            return self.stats

        except Exception as e:
            logger.error(f"Error validating dataset: {str(e)}")
            raise

def main():
    """Main function to run the dataset cleaner."""
    parser = argparse.ArgumentParser(description="Clean and validate SaaS sales conversation dataset")
    parser.add_argument("input_file", type=str, help="Path to the input CSV file")
    parser.add_argument("--output_file", type=str, default=None,
                        help="Path to save cleaned dataset (defaults to 'cleaned_' + input_file)")
    parser.add_argument("--chunk_size", type=int, default=1000,
                        help="Number of rows to process at once")
    parser.add_argument("--encoding", type=str, default='utf-8',
                        help="File encoding (defaults to utf-8)")
    parser.add_argument("--skip_encoding_check", action="store_true",
                        help="Skip encoding detection and line-by-line processing")

    args = parser.parse_args()

    # Create and run the cleaner
    cleaner = SaaSDatasetCleaner(
        input_file=args.input_file,
        output_file=args.output_file,
        chunk_size=args.chunk_size,
        encoding=args.encoding,
        skip_encoding_check=args.skip_encoding_check
    )

    cleaner.clean_dataset()

if __name__ == "__main__":
    main()
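
Usage note: besides the CLI entry point above, the cleaner can be driven programmatically. The sketch below is illustrative only; the file names are hypothetical placeholders, and it simply mirrors the constructor arguments defined in SaaSDatasetCleaner.__init__.

# Illustrative usage sketch (hypothetical file names).
from clean_dataset import SaaSDatasetCleaner

cleaner = SaaSDatasetCleaner(
    input_file="saas_conversations.csv",            # hypothetical input path
    output_file="recovered_saas_conversations.csv", # recovery-pass output; validated rows go to validated_<input>
    chunk_size=2000,
    encoding="utf-8",
    skip_encoding_check=False,                      # run the line-by-line recovery pass first
)
stats = cleaner.clean_dataset()                     # returns the stats dict on success
print(stats)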