FelixPhilip committed on
Commit
17c5050
·
1 Parent(s): 3388ab8
Files changed (1) hide show
  1. Oracle/deepfundingoracle.py +48 -1
Oracle/deepfundingoracle.py CHANGED
@@ -31,6 +31,7 @@ import time
31
  from sklearn.model_selection import train_test_split, GridSearchCV
32
  from sklearn.ensemble import RandomForestRegressor
33
  from sklearn.metrics import mean_squared_error
 
34
  import matplotlib.pyplot as plt
35
  import seaborn as sns
36
 
@@ -350,6 +351,42 @@ def clean_data(df):
350
  return df
351
 
352
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
353
  ##############################
354
  # RandomForest Regression
355
  ##############################
@@ -362,6 +399,12 @@ def train_predict_weight(df):
362
  target = "base_weight"
363
  feature_cols = [col for col in df.columns if col not in ["repo", "parent", "base_weight", "final_weight"]]
364
 
 
 
 
 
 
 
365
  X = df[feature_cols]
366
  y = df[target]
367
 
@@ -394,6 +437,11 @@ def train_predict_weight(df):
394
  print("[INFO] Feature importances:")
395
  print(importance_df)
396
 
 
 
 
 
 
397
  # Plot predictions vs. actual values
398
  plt.scatter(y_test, y_pred, alpha=0.5)
399
  plt.xlabel("Actual Base Weight")
@@ -433,4 +481,3 @@ if __name__ == "__main__":
433
  print("[INFO] Creating submission CSV...")
434
  create_submission_csv(df, output_file)
435
  print("[INFO] Process completed successfully.")
436
-
 
31
  from sklearn.model_selection import train_test_split, GridSearchCV
32
  from sklearn.ensemble import RandomForestRegressor
33
  from sklearn.metrics import mean_squared_error
34
+ from sklearn.preprocessing import StandardScaler
35
  import matplotlib.pyplot as plt
36
  import seaborn as sns
37
 
 
351
  return df
352
 
353
 
354
+ ##############################
355
+ # Feature Validation and Scaling
356
+ ##############################
357
def validate_features(df, exclude=("repo", "parent", "base_weight", "final_weight")):
    """
    Log summary statistics for the numeric feature columns and standardize them.

    Columns named in ``exclude`` are neither logged nor scaled. The default
    mirrors the non-feature columns that the training code leaves out of its
    feature list, so the regression target ('base_weight') keeps its original
    scale and a later variance check on it remains meaningful.

    Args:
        df: DataFrame holding the feature columns. It is modified in place.
        exclude: Column names that must never be scaled.

    Returns:
        The same DataFrame object, with the selected numeric columns replaced
        by their StandardScaler-transformed values.
    """
    print("[INFO] Validating and scaling features...")
    skipped = set(exclude)
    numeric_cols = [
        col
        for col in df.select_dtypes(include=[np.number]).columns
        if col not in skipped
    ]

    # StandardScaler raises on input with zero columns or zero rows, so there
    # is nothing useful to do in either case.
    if not numeric_cols or df.empty:
        print("[WARN] No numeric feature data to scale; skipping.")
        return df

    # Log feature distributions (computed before scaling).
    for col in numeric_cols:
        series = df[col]
        print(
            f"[DEBUG] Feature '{col}' - Mean: {series.mean()}, Std: {series.std()}, "
            f"Min: {series.min()}, Max: {series.max()}"
        )

    # Scale numeric features
    df[numeric_cols] = StandardScaler().fit_transform(df[numeric_cols])
    print("[INFO] Features scaled successfully.")
    return df
373
+
374
def validate_target(df):
    """
    Check that the target column 'base_weight' exists and varies enough to train on.

    Args:
        df: DataFrame expected to contain a 'base_weight' column.

    Returns:
        The DataFrame, unchanged.

    Raises:
        ValueError: If 'base_weight' is missing, or if its sample variance is
            below 1e-6 or undefined (NaN, e.g. fewer than two non-null values).
    """
    print("[INFO] Validating target variable 'base_weight'...")
    target = "base_weight"
    if target not in df.columns:
        raise ValueError(f"Target variable '{target}' not found in DataFrame.")

    variance = df[target].var()
    print(f"[DEBUG] Target variable variance: {variance}")
    # Written as `not (variance >= ...)` rather than `variance < ...` so that a
    # NaN variance, which compares False to everything, is rejected as well.
    if not (variance >= 1e-6):
        raise ValueError(f"Target variable '{target}' has insufficient variance.")
    return df
388
+
389
+
390
  ##############################
391
  # RandomForest Regression
392
  ##############################
 
399
  target = "base_weight"
400
  feature_cols = [col for col in df.columns if col not in ["repo", "parent", "base_weight", "final_weight"]]
401
 
402
+ # Validate and scale features
403
+ df = validate_features(df)
404
+
405
+ # Validate target variable
406
+ df = validate_target(df)
407
+
408
  X = df[feature_cols]
409
  y = df[target]
410
 
 
437
  print("[INFO] Feature importances:")
438
  print(importance_df)
439
 
440
+ # Drop irrelevant features
441
+ irrelevant_features = importance_df[importance_df["Importance"] < 0.01]["Feature"].tolist()
442
+ print(f"[INFO] Dropping irrelevant features: {irrelevant_features}")
443
+ df.drop(columns=irrelevant_features, inplace=True)
444
+
445
  # Plot predictions vs. actual values
446
  plt.scatter(y_test, y_pred, alpha=0.5)
447
  plt.xlabel("Actual Base Weight")
 
481
  print("[INFO] Creating submission CSV...")
482
  create_submission_csv(df, output_file)
483
  print("[INFO] Process completed successfully.")