solanaexpert commited on
Commit
56ad009
·
verified ·
1 Parent(s): 8870588

Create randomforestML.py

Browse files
Files changed (1) hide show
  1. randomforestML.py +86 -0
randomforestML.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pandas as pd
3
+ import numpy as np
4
+ from datetime import datetime, timedelta
5
+ from binance.client import Client
6
+ from sklearn.model_selection import train_test_split
7
+ from sklearn.ensemble import RandomForestClassifier
8
+ from sklearn.metrics import classification_report
9
+ import ta
10
+
11
+ # Connect to Binance (Fill your own API keys if live)
12
+ # client = Client(api_key, api_secret)
13
+ client = Client()
14
+
15
+ # File to store the historical data
16
+ DATA_FILE = "btc_data.csv"
17
+ symbol = "BTCUSDT"
18
+ interval = Client.KLINE_INTERVAL_4HOUR
19
+
20
+ # Load existing data or download fresh
21
+ if os.path.exists(DATA_FILE):
22
+ print("Loading existing data...")
23
+ df = pd.read_csv(DATA_FILE, index_col=0, parse_dates=True)
24
+ last_timestamp = df.index[-1]
25
+ # Binance gives data in 15min intervals, so move forward
26
+ start_time = last_timestamp + timedelta(minutes=15)
27
+ start_str = start_time.strftime("%d %B %Y %H:%M:%S")
28
+
29
+ print(f"Downloading new data from {start_str}...")
30
+ new_klines = client.get_historical_klines(symbol, interval, start_str)
31
+ if new_klines:
32
+ new_df = pd.DataFrame(new_klines, columns=['timestamp', 'open', 'high', 'low', 'close', 'volume',
33
+ 'close_time', 'quote_av', 'trades', 'tb_base_av', 'tb_quote_av', 'ignore'])
34
+ new_df = new_df[['timestamp', 'open', 'high', 'low', 'close', 'volume']]
35
+ new_df[['open', 'high', 'low', 'close', 'volume']] = new_df[['open', 'high', 'low', 'close', 'volume']].astype(float)
36
+ new_df['timestamp'] = pd.to_datetime(new_df['timestamp'], unit='ms')
37
+ new_df = new_df.set_index('timestamp')
38
+
39
+ # Append and remove any duplicates (just in case)
40
+ df = pd.concat([df, new_df])
41
+ df = df[~df.index.duplicated(keep='first')]
42
+ df.to_csv(DATA_FILE)
43
+ else:
44
+ print("Downloading all data from scratch...")
45
+ klinesT = client.get_historical_klines(symbol, interval, "01 December 2021")
46
+ df = pd.DataFrame(klinesT, columns=['timestamp', 'open', 'high', 'low', 'close', 'volume',
47
+ 'close_time', 'quote_av', 'trades', 'tb_base_av', 'tb_quote_av', 'ignore'])
48
+ df = df[['timestamp', 'open', 'high', 'low', 'close', 'volume']]
49
+ df[['open', 'high', 'low', 'close', 'volume']] = df[['open', 'high', 'low', 'close', 'volume']].astype(float)
50
+ df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms')
51
+ df = df.set_index('timestamp')
52
+ df.to_csv(DATA_FILE)
53
+
54
+ # Feature Engineering: Add technical indicators
55
+ df['rsi'] = ta.momentum.RSIIndicator(df['close'], window=14).rsi()
56
+ df['sma_fast'] = df['close'].rolling(window=5).mean()
57
+ df['sma_slow'] = df['close'].rolling(window=20).mean()
58
+ df['macd'] = ta.trend.MACD(df['close']).macd()
59
+ df['ema'] = df['close'].ewm(span=10, adjust=False).mean()
60
+
61
+ # Create target: 1 if next close > current close, else 0
62
+ df['target'] = np.where(df['close'].shift(-1) > df['close'], 1, 0)
63
+
64
+ # Drop rows with NaN values
65
+ df = df.dropna()
66
+
67
+ # Features and Target
68
+ features = ['rsi', 'sma_fast', 'sma_slow', 'macd', 'ema']
69
+ X = df[features]
70
+ y = df['target']
71
+
72
+ # Train/Test split
73
+ X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, shuffle=False)
74
+
75
+ # Train Random Forest
76
+ model = RandomForestClassifier(n_estimators=100, random_state=42)
77
+ model.fit(X_train, y_train)
78
+
79
+ # Evaluate
80
+ y_pred = model.predict(X_test)
81
+ print(classification_report(y_test, y_pred))
82
+
83
+ # Predict next movement
84
+ latest_features = X.iloc[-1].values.reshape(1, -1)
85
+ predicted_direction = model.predict(latest_features)
86
+ print(f"Predicted next movement: {'UP' if predicted_direction[0] == 1 else 'DOWN'}")