duongtruongbinh commited on
Commit
439a5cd
·
verified ·
1 Parent(s): 6e16701

Update index.html

Browse files
Files changed (1) hide show
  1. index.html +905 -18
index.html CHANGED
@@ -1,19 +1,906 @@
1
- <!doctype html>
2
- <html>
3
- <head>
4
- <meta charset="utf-8" />
5
- <meta name="viewport" content="width=device-width" />
6
- <title>My static Space</title>
7
- <link rel="stylesheet" href="style.css" />
8
- </head>
9
- <body>
10
- <div class="card">
11
- <h1>Welcome to your static Space!</h1>
12
- <p>You can modify this app directly by editing <i>index.html</i> in the Files and versions tab.</p>
13
- <p>
14
- Also don't forget to check the
15
- <a href="https://huggingface.co/docs/hub/spaces" target="_blank">Spaces documentation</a>.
16
- </p>
17
- </div>
18
- </body>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  </html>
 
1
+ <!DOCTYPE html>
2
+ <html lang="vi">
3
+ <head>
4
+ <meta charset="UTF-8">
5
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
6
+ <title>MLP Interactive Visualization</title>
7
+ <script src="https://cdn.tailwindcss.com"></script>
8
+ <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@latest/dist/tf.min.js"></script>
9
+ <link rel="preconnect" href="https://fonts.googleapis.com">
10
+ <link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
11
+ <link href="https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700;800&display=swap" rel="stylesheet">
12
+ <style>
13
+ body { font-family: 'Inter', sans-serif; background-color: #f1f5f9; }
14
+
15
+ /* Custom class for active class selection buttons */
16
+ .button-active-blue {
17
+ background-color: #2563eb !important;
18
+ color: white !important;
19
+ border-color: #2563eb !important;
20
+ box-shadow: 0 4px 6px -1px rgb(0 0 0 / 0.1), 0 2px 4px -2px rgb(0 0 0 / 0.1);
21
+ }
22
+ .button-active-red {
23
+ background-color: #dc2626 !important;
24
+ color: white !important;
25
+ border-color: #dc2626 !important;
26
+ box-shadow: 0 4px 6px -1px rgb(0 0 0 / 0.1), 0 2px 4px -2px rgb(0 0 0 / 0.1);
27
+ }
28
+
29
+ /* Cursor styles for data input */
30
+ #plotCanvas { touch-action: none; }
31
+ #plotCanvas.cursor-class-0 { cursor: url('data:image/svg+xml;utf8,<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="%233b82f6" stroke="white" stroke-width="2"><circle cx="12" cy="12" r="8"/><path d="M12 2v4M12 18v4M22 12h-4M6 12H2"/></svg>') 12 12, crosshair; }
32
+ #plotCanvas.cursor-class-1 { cursor: url('data:image/svg+xml;utf8,<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="%23ef4444" stroke="white" stroke-width="2"><circle cx="12" cy="12" r="8"/><path d="M12 2v4M12 18v4M22 12h-4M6 12H2"/></svg>') 12 12, crosshair; }
33
+
34
+ /* Scrollbar styling */
35
+ #hiddenLayersConfigContainer::-webkit-scrollbar { width: 6px; }
36
+ #hiddenLayersConfigContainer::-webkit-scrollbar-track { background: #e2e8f0; border-radius: 8px; }
37
+ #hiddenLayersConfigContainer::-webkit-scrollbar-thumb { background: #94a3b8; border-radius: 8px; }
38
+ #hiddenLayersConfigContainer::-webkit-scrollbar-thumb:hover { background: #64748b; }
39
+
40
+ /* Network Visualization SVG styles */
41
+ .neuron {
42
+ stroke-width: 1.5;
43
+ transition: stroke-width 0.2s ease-in-out;
44
+ }
45
+ .neuron:hover {
46
+ stroke-width: 4;
47
+ }
48
+ .neuron.input { fill: #60a5fa; stroke: #2563eb; }
49
+ .neuron.hidden-0 { fill: #818cf8; stroke: #4f46e5; }
50
+ .neuron.hidden-1 { fill: #a78bfa; stroke: #7c3aed; }
51
+ .neuron.hidden-2 { fill: #c084fc; stroke: #9333ea; }
52
+ .neuron.hidden-3 { fill: #e879f9; stroke: #c026d3; }
53
+ .neuron.hidden-other { fill: #f472b6; stroke: #db2777; }
54
+ .neuron.output { fill: #f87171; stroke: #dc2626; }
55
+
56
+ .connection {
57
+ stroke: #cbd5e1;
58
+ stroke-width: 0.75;
59
+ transition: stroke-opacity 0.2s;
60
+ }
61
+
62
+ .layer-label {
63
+ font-size: 11px;
64
+ font-weight: 600;
65
+ fill: #475569;
66
+ text-anchor: middle;
67
+ }
68
+ .neuron-count-label {
69
+ font-size: 10px;
70
+ font-weight: 500;
71
+ fill: #64748b;
72
+ text-anchor: middle;
73
+ }
74
+
75
+ /* Utility icon styles */
76
+ .title-icon { margin-left: 0.6rem; font-size: 1.1rem; }
77
+ .action-icon { margin-right: 0.35rem; }
78
+ .status-icon { margin-right: 0.5rem; flex-shrink: 0; font-size: 1.1rem; }
79
+ .loading-icon { animation: spin 1.5s linear infinite; display: inline-block; }
80
+ @keyframes spin { from { transform: rotate(0deg); } to { transform: rotate(360deg); } }
81
+
82
+ /* Confusion Matrix styles */
83
+ .cm-cell {
84
+ display: flex;
85
+ align-items: center;
86
+ justify-content: center;
87
+ flex-direction: column;
88
+ line-height: 1.2;
89
+ padding: 0.5rem;
90
+ border-radius: 0.25rem;
91
+ transition: all 0.2s ease;
92
+ }
93
+ .cm-value { font-size: 1.125rem; font-weight: 800; }
94
+ .cm-label { font-size: 0.65rem; text-transform: uppercase; letter-spacing: 0.05em; font-weight: 600; }
95
+ </style>
96
+ </head>
97
+ <body class="text-slate-800">
98
+
99
+ <div class="container mx-auto p-2 sm:p-4 max-w-full">
100
+ <header class="mb-4 text-center">
101
+ <h1 class="text-3xl sm:text-4xl font-extrabold text-blue-600">MLP Interactive Visualization</h1>
102
+ <p class="text-md text-slate-600 mt-1">Build, train, and visualize a Multi-Layer Perceptron.</p>
103
+ </header>
104
+
105
+ <!-- Main layout with reduced gap -->
106
+ <div class="flex flex-col lg:flex-row gap-2.5">
107
+ <!-- Control Panel -->
108
+ <div class="lg:w-6/12 bg-white p-3.5 rounded-2xl shadow-xl border border-slate-200 space-y-5">
109
+ <!-- Data & Training Section -->
110
+ <div class="space-y-4">
111
+ <h2 class="text-xl font-bold text-slate-700 border-b border-slate-200 pb-2 flex items-center">
112
+ 1. Data & Training <span class="title-icon">⚙️</span>
113
+ </h2>
114
+ <div class="grid grid-cols-2 sm:grid-cols-4 gap-x-4 gap-y-3">
115
+ <div class="col-span-2 sm:col-span-4">
116
+ <label class="text-sm font-medium text-slate-700 block mb-1.5">Data Input Class:</label>
117
+ <div class="flex rounded-lg shadow-sm">
118
+ <button id="class0Button" class="flex-1 px-3 py-2 border border-slate-300 rounded-l-lg bg-white text-base font-semibold text-slate-700 hover:bg-slate-50 focus:z-10 focus:outline-none focus:ring-2 focus:ring-blue-500 focus:border-blue-500 transition-all duration-150" title="Select Class 0 for the next data points">Class 0 (Blue)</button>
119
+ <button id="class1Button" class="flex-1 px-3 py-2 border-t border-b border-r border-slate-300 rounded-r-lg bg-white text-base font-semibold text-slate-700 hover:bg-slate-50 focus:z-10 focus:outline-none focus:ring-2 focus:ring-red-500 focus:border-red-500 transition-all duration-150" title="Select Class 1 for the next data points">Class 1 (Red)</button>
120
+ </div>
121
+ </div>
122
+ <!-- Other controls... -->
123
+ <div>
124
+ <label for="datasetSelect" class="text-sm font-medium text-slate-700 block mb-1">Load Dataset:</label>
125
+ <select id="datasetSelect" class="mt-0 block w-full px-2 py-1.5 border border-slate-300 bg-white rounded-lg shadow-sm focus:outline-none focus:ring-2 focus:ring-blue-500 focus:border-blue-500 text-sm" title="Choose a preset dataset.">
126
+ <option value="manual">Manual Input</option>
127
+ <option value="two_moons">Two Moons</option>
128
+ <option value="circles">Concentric Circles</option>
129
+ <option value="xor">XOR</option>
130
+ <option value="spiral">Spiral</option>
131
+ </select>
132
+ </div>
133
+ <div>
134
+ <label for="dataNoise" class="text-sm font-medium text-slate-700 block mb-1">Data Noise:</label>
135
+ <input type="number" id="dataNoise" value="0.05" min="0" max="0.5" step="0.01" class="mt-0 block w-full px-2 py-1.5 border border-slate-300 rounded-lg shadow-sm focus:outline-none focus:ring-2 focus:ring-blue-500 focus:border-blue-500 text-sm" title="Add random noise to the data points.">
136
+ </div>
137
+ <div>
138
+ <label for="learningRate" class="text-sm font-medium text-slate-700 block mb-1">Learning Rate:</label>
139
+ <input type="number" id="learningRate" value="0.01" step="0.001" min="0.00001" class="mt-0 block w-full px-2 py-1.5 border border-slate-300 rounded-lg shadow-sm focus:outline-none focus:ring-2 focus:ring-blue-500 focus:border-blue-500 text-sm" title="How fast the model learns.">
140
+ </div>
141
+ <div>
142
+ <label for="epochs" class="text-sm font-medium text-slate-700 block mb-1">Epochs:</label>
143
+ <input type="number" id="epochs" value="150" step="10" min="1" class="mt-0 block w-full px-2 py-1.5 border border-slate-300 rounded-lg shadow-sm focus:outline-none focus:ring-2 focus:ring-blue-500 focus:border-blue-500 text-sm" title="Number of training iterations.">
144
+ </div>
145
+ <div>
146
+ <label for="optimizerSelect" class="text-sm font-medium text-slate-700 block mb-1">Optimizer:</label>
147
+ <select id="optimizerSelect" class="mt-0 block w-full px-2 py-1.5 border border-slate-300 bg-white rounded-lg shadow-sm focus:outline-none focus:ring-2 focus:ring-blue-500 focus:border-blue-500 text-sm" title="Algorithm to update weights.">
148
+ <option value="adam">Adam</option>
149
+ <option value="sgd">SGD</option>
150
+ <option value="rmsprop">RMSprop</option>
151
+ </select>
152
+ </div>
153
+ <div>
154
+ <label for="batchSize" class="text-sm font-medium text-slate-700 block mb-1">Batch Size:</label>
155
+ <input type="number" id="batchSize" value="16" min="0" step="4" class="mt-0 block w-full px-2 py-1.5 border border-slate-300 rounded-lg shadow-sm focus:outline-none focus:ring-2 focus:ring-blue-500 focus:border-blue-500 text-sm" title="Samples per weight update. 0=Full.">
156
+ </div>
157
+ <div>
158
+ <label for="regularizationTypeSelect" class="text-sm font-medium text-slate-700 block mb-1">Regularization:</label>
159
+ <select id="regularizationTypeSelect" class="mt-0 block w-full px-2 py-1.5 border border-slate-300 bg-white rounded-lg shadow-sm focus:outline-none focus:ring-2 focus:ring-blue-500 focus:border-blue-500 text-sm" title="Technique to prevent overfitting.">
160
+ <option value="none">None</option>
161
+ <option value="l1">L1</option>
162
+ <option value="l2">L2</option>
163
+ </select>
164
+ </div>
165
+ <div id="regularizationRateContainer" class="hidden">
166
+ <label for="regularizationRateInput" class="text-sm font-medium text-slate-700 block mb-1">Reg. Rate (λ):</label>
167
+ <input type="number" id="regularizationRateInput" value="0.001" step="0.0001" min="0" class="mt-0 block w-full px-2 py-1.5 border border-slate-300 rounded-lg shadow-sm focus:outline-none focus:ring-2 focus:ring-blue-500 focus:border-blue-500 text-sm" title="Strength of the L1/L2 penalty.">
168
+ </div>
169
+ </div>
170
+ </div>
171
+
172
+ <!-- MLP Architecture Section -->
173
+ <div class="space-y-2">
174
+ <h2 class="text-xl font-bold text-slate-700 border-b border-slate-200 pb-2 flex items-center">
175
+ 2. MLP Architecture <span class="title-icon">🧠</span>
176
+ </h2>
177
+ <div id="networkVisualization" class="bg-slate-50 rounded-lg p-2 min-h-[120px] border border-slate-200"></div>
178
+ <h3 class="text-sm font-medium text-slate-700 pt-2">Hidden Layers:</h3>
179
+ <div id="hiddenLayersConfigContainer" class="grid grid-cols-1 md:grid-cols-2 gap-2 max-h-48 overflow-y-auto pr-1.5"></div>
180
+ <button id="addHiddenLayerButton" class="mt-1 w-full flex justify-center items-center py-2 px-3 border-2 border-dashed border-blue-400 rounded-lg shadow-sm text-sm font-semibold text-blue-600 bg-blue-50 hover:bg-blue-100 hover:border-blue-500 focus:outline-none focus:ring-2 focus:ring-offset-2 focus:ring-blue-500 transition-all duration-150" title="Add a new hidden layer.">
181
+ <span class="action-icon text-base">➕</span> Add Layer
182
+ </button>
183
+ </div>
184
+
185
+ <!-- Actions Section -->
186
+ <div class="space-y-2">
187
+ <h2 class="text-xl font-bold text-slate-700 border-b border-slate-200 pb-2 flex items-center">
188
+ 3. Actions <span class="title-icon">⚡️</span>
189
+ </h2>
190
+ <div class="grid grid-cols-2 sm:grid-cols-4 gap-2.5">
191
+ <button id="trainButton" class="col-span-2 flex justify-center items-center py-2.5 px-3 border border-transparent rounded-lg shadow-md text-base font-bold text-white bg-green-600 hover:bg-green-700 focus:outline-none focus:ring-2 focus:ring-offset-2 focus:ring-green-500 transition-all duration-150 disabled:opacity-50 disabled:cursor-not-allowed disabled:bg-green-400" title="Start or retrain the model.">
192
+ <span class="action-icon">▶️</span> <span id="trainButtonText">Train</span>
193
+ </button>
194
+ <button id="stopButton" class="hidden col-span-2 flex justify-center items-center py-2.5 px-3 border border-transparent rounded-lg shadow-md text-base font-bold text-white bg-red-600 hover:bg-red-700 focus:outline-none focus:ring-2 focus:ring-offset-2 focus:ring-red-500 transition-all" title="Stop the current training process.">
195
+ <span class="action-icon">⏹️</span> <span>Stop</span>
196
+ </button>
197
+ <button id="resetWeightsButton" class="flex justify-center items-center py-2.5 px-3 border border-slate-300 rounded-lg shadow-sm text-base font-semibold text-slate-700 bg-white hover:bg-slate-50 focus:outline-none focus:ring-2 focus:ring-offset-2 focus:ring-indigo-500 transition-all duration-150 disabled:opacity-50 disabled:cursor-not-allowed" title="Re-initialize the model's weights.">
198
+ <span class="action-icon">🔄</span> <span>Reset</span>
199
+ </button>
200
+ <button id="clearButton" class="flex justify-center items-center py-2.5 px-3 border border-slate-300 rounded-lg shadow-sm text-base font-semibold text-slate-700 bg-white hover:bg-slate-50 focus:outline-none focus:ring-2 focus:ring-offset-2 focus:ring-indigo-500 transition-all duration-150 disabled:opacity-50 disabled:cursor-not-allowed" title="Delete all data and reset settings.">
201
+ <span class="action-icon">🗑️</span> <span>Clear All</span>
202
+ </button>
203
+ </div>
204
+ </div>
205
+ </div>
206
+
207
+ <!-- Visualization Area with reduced gap -->
208
+ <div class="lg:w-6/12 flex flex-col gap-2.5">
209
+ <div class="bg-white p-2 rounded-2xl shadow-xl border border-slate-200">
210
+ <canvas id="plotCanvas" class="border border-slate-200 rounded-xl w-full"></canvas>
211
+ </div>
212
+
213
+ <div class="bg-white p-3.5 rounded-2xl shadow-xl border border-slate-200 flex flex-col flex-grow">
214
+ <h2 class="text-xl font-bold text-slate-700 mb-2 border-b border-slate-200 pb-2 flex items-center">
215
+ Training Status <span class="title-icon">📊</span>
216
+ </h2>
217
+ <div id="trainingParamsDisplay" class="text-xs mb-2 space-y-1"></div>
218
+ <div id="statusMessage" class="text-sm text-slate-600 p-2 bg-slate-50 rounded-lg min-h-[44px] mb-3 whitespace-pre-line leading-snug flex items-center justify-center border border-slate-200"></div>
219
+ <div class="flex-grow flex flex-col sm:flex-row gap-4">
220
+ <div class="w-full sm:w-2/3 relative min-h-[150px]">
221
+ <canvas id="trainingPlotCanvas" class="w-full h-full"></canvas>
222
+ </div>
223
+ <div id="confusionMatrixContainer" class="w-full sm:w-1/3">
224
+ <!-- Confusion Matrix will be rendered here -->
225
+ </div>
226
+ </div>
227
+ </div>
228
+ </div>
229
+ </div>
230
+ </div>
231
+
232
+ <script>
233
+ // --- App Namespace ---
234
+ // Encapsulate the entire application in a single object to avoid polluting the global namespace.
235
+ // This improves organization and prevents potential conflicts with other scripts.
236
+ const App = {
237
+ // --- STATE & CONFIG ---
238
+ state: {
239
+ model: null,
240
+ dataPoints: [],
241
+ currentClass: 0,
242
+ hiddenLayerConfigs: [],
243
+ trainingHistory: { loss: [], acc: [] },
244
+ isTraining: false,
245
+ },
246
+ ui: {}, // To hold DOM element references
247
+ config: {
248
+ pointRadius: 4.5,
249
+ classColors: {
250
+ 0: { point: 'rgba(59, 130, 246, 1)', boundary: 'rgba(59, 130, 246, 0.3)' },
251
+ 1: { point: 'rgba(239, 68, 68, 1)', boundary: 'rgba(239, 68, 68, 0.3)' }
252
+ },
253
+ statusIcons: {
254
+ info: 'ℹ️', success: '✅', warning: '⚠️', error: '❌', loading: '⏳'
255
+ },
256
+ },
257
+
258
+ // --- INITIALIZATION ---
259
+ async init() {
260
+ this.cacheUIElements();
261
+ this.registerEventListeners();
262
+
263
+ await tf.ready();
264
+ try {
265
+ await tf.setBackend('cpu');
266
+ console.log("TensorFlow.js backend set to CPU.");
267
+ } catch (e) { console.warn("Could not set TF.js backend to CPU.", e); }
268
+
269
+ this.methods.resizePlotCanvas();
270
+ this.initializeApplicationState();
271
+ },
272
+
273
+ cacheUIElements() {
274
+ const ids = [
275
+ 'class0Button', 'class1Button', 'datasetSelect', 'dataNoise', 'batchSize',
276
+ 'hiddenLayersConfigContainer', 'addHiddenLayerButton', 'networkVisualization',
277
+ 'optimizerSelect', 'learningRate', 'epochs', 'regularizationTypeSelect',
278
+ 'regularizationRateContainer', 'regularizationRateInput', 'trainButton',
279
+ 'stopButton', 'resetWeightsButton', 'clearButton', 'statusMessage',
280
+ 'plotCanvas', 'trainingPlotCanvas', 'trainingParamsDisplay', 'trainButtonText',
281
+ 'confusionMatrixContainer'
282
+ ];
283
+ ids.forEach(id => { this.ui[id] = document.getElementById(id); });
284
+ this.ui.canvas = this.ui.plotCanvas;
285
+ this.ui.ctx = this.ui.canvas.getContext('2d');
286
+ },
287
+
288
+ registerEventListeners() {
289
+ window.addEventListener('resize', () => {
290
+ this.methods.resizePlotCanvas();
291
+ this.methods.drawTrainingPlot();
292
+ this.methods.drawNetworkVisualization();
293
+ });
294
+ this.ui.canvas.addEventListener('click', (e) => this.methods.handleCanvasClick(e));
295
+ this.ui.class0Button.addEventListener('click', () => this.methods.setActiveClass(0));
296
+ this.ui.class1Button.addEventListener('click', () => this.methods.setActiveClass(1));
297
+ this.ui.addHiddenLayerButton.addEventListener('click', () => this.methods.addHiddenLayerUI());
298
+ this.ui.trainButton.addEventListener('click', () => this.methods.trainAndVisualize());
299
+ this.ui.stopButton.addEventListener('click', () => this.methods.stopTraining());
300
+ this.ui.resetWeightsButton.addEventListener('click', () => this.methods.resetModelWeights());
301
+ this.ui.clearButton.addEventListener('click', () => this.methods.clearAllDataAndReseed());
302
+ this.ui.datasetSelect.addEventListener('change', () => this.methods.loadSelectedDataset());
303
+ this.ui.dataNoise.addEventListener('input', () => { if (this.ui.datasetSelect.value !== 'manual') this.methods.loadSelectedDataset(); });
304
+ this.ui.regularizationTypeSelect.addEventListener('change', () => this.methods.toggleRegularizationRateInput());
305
+ },
306
+
307
+ initializeApplicationState() {
308
+ this.methods.setActiveClass(0);
309
+ this.methods.addHiddenLayerUI(8, 'relu'); // Default architecture
310
+ this.methods.addHiddenLayerUI(4, 'relu');
311
+ this.methods.toggleRegularizationRateInput();
312
+ this.methods.drawAll();
313
+ this.methods.drawTrainingPlot();
314
+ this.methods.updateTrainingParamsDisplay();
315
+ this.methods.updateButtonStates();
316
+ this.methods.renderConfusionMatrix(null);
317
+ this.methods.updateStatus('Ready. Click canvas to add points or load a dataset.', 'info');
318
+ },
319
+
320
+ // --- METHODS (Logic & Handlers) ---
321
+ methods: {
322
+ // --- UI & State Management ---
323
+ resizePlotCanvas() {
324
+ const dpr = window.devicePixelRatio || 1;
325
+ const rect = App.ui.canvas.parentElement.getBoundingClientRect();
326
+ if (rect.width === 0) return;
327
+
328
+ App.ui.canvas.width = rect.width * dpr;
329
+ const newHeight = Math.min(rect.width * 0.85, Math.max(300, window.innerHeight * 0.5));
330
+ App.ui.canvas.height = newHeight * dpr;
331
+
332
+ App.ui.ctx.scale(dpr, dpr);
333
+ App.ui.canvas.style.width = `${rect.width}px`;
334
+ App.ui.canvas.style.height = `${newHeight}px`;
335
+
336
+ App.ui.canvasWidth = rect.width;
337
+ App.ui.canvasHeight = newHeight;
338
+ this.drawAll();
339
+ },
340
+
341
+ setActiveClass(classNum) {
342
+ App.state.currentClass = classNum;
343
+ App.ui.class0Button.classList.toggle('button-active-blue', classNum === 0);
344
+ App.ui.class1Button.classList.toggle('button-active-red', classNum === 1);
345
+ App.ui.plotCanvas.className = App.ui.plotCanvas.className.replace(/cursor-class-\d/, '');
346
+ App.ui.plotCanvas.classList.add(`cursor-class-${classNum}`);
347
+ },
348
+
349
+ toggleRegularizationRateInput() {
350
+ App.ui.regularizationRateContainer.classList.toggle('hidden', App.ui.regularizationTypeSelect.value === 'none');
351
+ },
352
+
353
+ updateStatus(message, type = 'info') {
354
+ const icon = App.config.statusIcons[type] || '';
355
+ const loadingClass = type === 'loading' ? 'loading-icon' : '';
356
+ const iconHtml = `<span class="status-icon ${loadingClass}">${icon}</span>`;
357
+ App.ui.statusMessage.innerHTML = `${iconHtml}<span>${message}</span>`;
358
+
359
+ const typeToColor = {
360
+ info: 'text-slate-600 bg-slate-50 border-slate-200',
361
+ success: 'text-green-700 bg-green-50 border-green-200 font-semibold',
362
+ warning: 'text-amber-700 bg-amber-50 border-amber-200 font-semibold',
363
+ error: 'text-red-700 bg-red-50 border-red-200 font-semibold',
364
+ loading: 'text-blue-700 bg-blue-50 border-blue-200'
365
+ };
366
+ App.ui.statusMessage.className = `text-sm p-2 rounded-lg min-h-[44px] whitespace-pre-line leading-snug flex items-center justify-center border ${typeToColor[type] || ''}`;
367
+ },
368
+
369
+ setTrainingState(training) {
370
+ App.state.isTraining = training;
371
+ App.ui.trainButton.classList.toggle('hidden', training);
372
+ App.ui.stopButton.classList.toggle('hidden', !training);
373
+ this.updateButtonStates();
374
+ },
375
+
376
+ updateButtonStates() {
377
+ const hasData = App.state.dataPoints.length > 0;
378
+ const hasModel = App.state.model != null;
379
+ App.ui.trainButton.disabled = App.state.isTraining || !hasData;
380
+ App.ui.clearButton.disabled = App.state.isTraining;
381
+ App.ui.resetWeightsButton.disabled = App.state.isTraining || !hasModel;
382
+
383
+ if (hasModel && App.state.trainingHistory.loss.length > 0 && !App.state.isTraining) {
384
+ App.ui.trainButtonText.textContent = 'Retrain';
385
+ } else {
386
+ App.ui.trainButtonText.textContent = 'Train';
387
+ }
388
+ },
389
+
390
+ // --- Architecture UI ---
391
+ addHiddenLayerUI(defaultNeurons = 8, defaultActivation = 'relu') {
392
+ if (App.state.hiddenLayerConfigs.length >= 8) {
393
+ this.updateStatus("Max 8 hidden layers reached.", 'warning');
394
+ return;
395
+ }
396
+ const layerIndex = App.state.hiddenLayerConfigs.length;
397
+ App.state.hiddenLayerConfigs.push({ neurons: parseInt(defaultNeurons), activation: defaultActivation });
398
+
399
+ const layerDiv = document.createElement('div');
400
+ layerDiv.className = 'layer-config-item bg-slate-50 border border-slate-200 p-2 rounded-lg flex items-center justify-between gap-2 hover:bg-slate-100 transition-colors';
401
+ layerDiv.dataset.index = layerIndex;
402
+
403
+ layerDiv.innerHTML = `
404
+ <div class="flex-grow flex items-center gap-x-3">
405
+ <div class="flex items-center">
406
+ <label for="neurons_${layerIndex}" class="text-sm font-medium text-slate-600 mr-2 whitespace-nowrap">L${layerIndex + 1} Units:</label>
407
+ <input type="number" id="neurons_${layerIndex}" value="${defaultNeurons}" min="1" max="64" step="1" class="w-16 px-2 py-1 border border-slate-300 rounded-md shadow-sm text-sm focus:ring-blue-500 focus:border-blue-500">
408
+ </div>
409
+ <div class="flex items-center">
410
+ <label for="activation_${layerIndex}" class="text-sm font-medium text-slate-600 mr-2">Act:</label>
411
+ <select id="activation_${layerIndex}" class="flex-1 min-w-[90px] px-2 py-1 border bg-white border-slate-300 rounded-md shadow-sm text-sm focus:ring-blue-500 focus:border-blue-500">
412
+ ${['relu', 'sigmoid', 'tanh', 'leakyRelu'].map(act => `<option value="${act}" ${act === defaultActivation ? 'selected' : ''}>${act.charAt(0).toUpperCase() + act.slice(1)}</option>`).join('')}
413
+ </select>
414
+ </div>
415
+ </div>
416
+ <button title="Remove layer" class="remove-btn flex-shrink-0 p-1.5 text-red-500 hover:text-red-700 rounded-full hover:bg-red-100 focus:outline-none focus:ring-2 focus:ring-red-500 focus:ring-offset-1 transition-all duration-150">
417
+ <span class="text-base font-bold">⛔</span>
418
+ </button>
419
+ `;
420
+
421
+ App.ui.hiddenLayersConfigContainer.appendChild(layerDiv);
422
+
423
+ layerDiv.querySelector(`#neurons_${layerIndex}`).onchange = (e) => {
424
+ const value = parseInt(e.target.value);
425
+ App.state.hiddenLayerConfigs[layerIndex].neurons = !isNaN(value) && value > 0 ? value : 1;
426
+ e.target.value = App.state.hiddenLayerConfigs[layerIndex].neurons;
427
+ this.drawNetworkVisualization();
428
+ };
429
+ layerDiv.querySelector(`#activation_${layerIndex}`).onchange = (e) => {
430
+ App.state.hiddenLayerConfigs[layerIndex].activation = e.target.value;
431
+ this.drawNetworkVisualization();
432
+ };
433
+ layerDiv.querySelector('.remove-btn').onclick = () => {
434
+ App.state.hiddenLayerConfigs.splice(layerIndex, 1);
435
+ this.redrawLayerConfigs();
436
+ };
437
+ this.drawNetworkVisualization();
438
+ },
439
+
440
+ redrawLayerConfigs() {
441
+ const configs = [...App.state.hiddenLayerConfigs];
442
+ App.ui.hiddenLayersConfigContainer.innerHTML = '';
443
+ App.state.hiddenLayerConfigs = [];
444
+ configs.forEach(config => this.addHiddenLayerUI(config.neurons, config.activation));
445
+ },
446
+
447
+ // --- Data Handling ---
448
+ handleCanvasClick(event) {
449
+ if (!App.ui.canvas) return;
450
+ const rect = App.ui.canvas.getBoundingClientRect();
451
+ const x = event.clientX - rect.left;
452
+ const y = event.clientY - rect.top;
453
+
454
+ App.state.dataPoints.push({ x, y, normX: x / App.ui.canvasWidth, normY: y / App.ui.canvasHeight, label: App.state.currentClass });
455
+ this.drawPoints();
456
+ this.updateStatus(`Added Class ${App.state.currentClass} point. Total: ${App.state.dataPoints.length}.`, 'info');
457
+ App.ui.datasetSelect.value = "manual";
458
+ this.updateButtonStates();
459
+ },
460
+
461
+ loadSelectedDataset() {
462
+ if (App.state.model) { App.state.model.dispose(); App.state.model = null; }
463
+ tf.disposeVariables();
464
+ App.state.dataPoints = [];
465
+ const datasetName = App.ui.datasetSelect.value;
466
+ if (datasetName === 'manual') {
467
+ this.drawAll();
468
+ this.updateButtonStates();
469
+ return;
470
+ }
471
+ const noise = parseFloat(App.ui.dataNoise.value) || 0;
472
+ const nSamples = 150;
473
+
474
+ const generators = {
475
+ two_moons: (n, noise) => {
476
+ const n_per_moon = Math.floor(n / 2);
477
+ const radius = 0.3;
478
+ for (let i = 0; i < n_per_moon; i++) {
479
+ const angle = (i / n_per_moon) * Math.PI;
480
+ // First moon, shifted left and up
481
+ this.addDataPoint(
482
+ 0.5 + radius * Math.cos(angle) - 0.125,
483
+ 0.5 + radius * Math.sin(angle) + 0.1,
484
+ 0, noise
485
+ );
486
+ // Second moon, shifted right and down
487
+ this.addDataPoint(
488
+ 0.5 + radius * Math.cos(angle + Math.PI) + 0.125,
489
+ 0.5 + radius * Math.sin(angle + Math.PI) - 0.1,
490
+ 1, noise
491
+ );
492
+ }
493
+ },
494
+ circles: (n, noise) => { for (let i=0; i<n; i++) { const r=Math.random(), a=Math.random()*2*Math.PI, l=r<0.5?0:1, rs=l===0?0.2:0.4; this.addDataPoint(0.5+rs*Math.cos(a), 0.5+rs*Math.sin(a), l, noise); } },
495
+ xor: (n, noise) => { const q=Math.floor(n/4), s=0.3; for (let i=0; i<q; i++) { this.addDataPoint(0.5-s, 0.5-s, 0, noise*2); this.addDataPoint(0.5+s, 0.5-s, 1, noise*2); this.addDataPoint(0.5-s, 0.5+s, 1, noise*2); this.addDataPoint(0.5+s, 0.5+s, 0, noise*2); } },
496
+ spiral: (n, noise) => { const pts=Math.floor(n/2); for (let i=0; i<pts; i++) { const a=i/20*Math.PI, r=0.05+i/pts*0.4; this.addDataPoint(0.5+r*Math.cos(a), 0.5+r*Math.sin(a), 0, noise*0.5); this.addDataPoint(0.5+r*Math.cos(a+Math.PI), 0.5+r*Math.sin(a+Math.PI), 1, noise*0.5); } }
497
+ };
498
+
499
+ generators[datasetName](nSamples, noise);
500
+ App.state.dataPoints.forEach(p => { p.x = p.normX * App.ui.canvasWidth; p.y = p.normY * App.ui.canvasHeight; });
501
+ this.drawAll();
502
+ this.updateStatus(`Loaded '${datasetName}' dataset. Noise: ${noise}.`, 'info');
503
+
504
+ App.state.trainingHistory = { loss: [], acc: [] };
505
+ this.drawTrainingPlot();
506
+ this.renderConfusionMatrix(null);
507
+ this.updateButtonStates();
508
+ this.updateTrainingParamsDisplay();
509
+ },
510
+
511
+ addDataPoint(normX, normY, label, noise) {
512
+ App.state.dataPoints.push({
513
+ x: 0, y: 0,
514
+ normX: normX + (Math.random() - 0.5) * noise,
515
+ normY: normY + (Math.random() - 0.5) * noise,
516
+ label: label
517
+ });
518
+ },
519
+
520
+ // --- Drawing & Visualization ---
521
+ drawAll(boundaryGrid = null) {
522
+ if (!App.ui.ctx || !App.ui.canvasWidth || !App.ui.canvasHeight) return;
523
+ App.ui.ctx.clearRect(0, 0, App.ui.canvasWidth, App.ui.canvasHeight);
524
+ if (boundaryGrid) {
525
+ const { grid, resolution } = boundaryGrid;
526
+ for (let i = 0; i < grid.length; i++) {
527
+ for (let j = 0; j < grid[i].length; j++) {
528
+ App.ui.ctx.fillStyle = App.config.classColors[grid[i][j]].boundary;
529
+ App.ui.ctx.fillRect(j * resolution, i * resolution, resolution, resolution);
530
+ }
531
+ }
532
+ }
533
+ App.state.dataPoints.forEach(point => {
534
+ App.ui.ctx.beginPath();
535
+ App.ui.ctx.arc(point.x, point.y, App.config.pointRadius, 0, 2 * Math.PI);
536
+ App.ui.ctx.fillStyle = App.config.classColors[point.label].point;
537
+ App.ui.ctx.fill();
538
+ App.ui.ctx.strokeStyle = 'rgba(255,255,255,0.7)';
539
+ App.ui.ctx.lineWidth = 1.5;
540
+ App.ui.ctx.stroke();
541
+ });
542
+ },
543
+
544
+ drawPoints() { this.drawAll(); },
545
+
546
+ drawNetworkVisualization() {
547
+ const container = App.ui.networkVisualization;
548
+ container.innerHTML = '';
549
+ const allLayers = [
550
+ { type: 'input', neurons: 2, activation: 'Input' },
551
+ ...App.state.hiddenLayerConfigs.map((cfg, i) => ({ type: `hidden-${i % 5}`, neurons: cfg.neurons, activation: cfg.activation })),
552
+ { type: 'output', neurons: 1, activation: 'Sigmoid' }
553
+ ];
554
+
555
+ const svg = document.createElementNS("http://www.w3.org/2000/svg", "svg");
556
+ const rect = container.getBoundingClientRect();
557
+ if(rect.width === 0 || rect.height === 0) return;
558
+ svg.setAttribute('viewBox', `0 0 ${rect.width} ${rect.height}`);
559
+
560
+ const margin = { top: 25, right: 15, bottom: 20, left: 15 };
561
+ const width = rect.width - margin.left - margin.right;
562
+ const height = rect.height - margin.top - margin.bottom;
563
+ const layerSpacing = allLayers.length > 1 ? width / (allLayers.length - 1) : width;
564
+
565
+ // Connections
566
+ for (let i = 0; i < allLayers.length - 1; i++) {
567
+ const x1 = margin.left + i * layerSpacing;
568
+ const x2 = margin.left + (i + 1) * layerSpacing;
569
+ const maxNeurons = 8;
570
+ const currentNeurons = Math.min(allLayers[i].neurons, maxNeurons);
571
+ const nextNeurons = Math.min(allLayers[i+1].neurons, maxNeurons);
572
+ for (let j = 0; j < currentNeurons; j++) {
573
+ const y1 = margin.top + height * ((j + 0.5) / currentNeurons);
574
+ for (let k = 0; k < nextNeurons; k++) {
575
+ const y2 = margin.top + height * ((k + 0.5) / nextNeurons);
576
+ const line = document.createElementNS("http://www.w3.org/2000/svg", "line");
577
+ line.setAttribute('x1', x1); line.setAttribute('y1', y1);
578
+ line.setAttribute('x2', x2); line.setAttribute('y2', y2);
579
+ line.setAttribute('class', 'connection');
580
+ svg.appendChild(line);
581
+ }
582
+ }
583
+ }
584
+ // Neurons and Labels
585
+ allLayers.forEach((layer, i) => {
586
+ const x = margin.left + i * layerSpacing;
587
+ const maxNeurons = 8;
588
+ const displayNeurons = Math.min(layer.neurons, maxNeurons);
589
+ const neuronRadius = Math.max(3, Math.min(7, height / (displayNeurons * 2.5)));
590
+ for (let j = 0; j < displayNeurons; j++) {
591
+ const y = margin.top + height * ((j + 0.5) / displayNeurons);
592
+ const circle = document.createElementNS("http://www.w3.org/2000/svg", "circle");
593
+ circle.setAttribute('cx', x); circle.setAttribute('cy', y);
594
+ circle.setAttribute('r', neuronRadius);
595
+ circle.setAttribute('class', `neuron ${layer.type}`);
596
+ svg.appendChild(circle);
597
+ }
598
+ const labelText = layer.activation.charAt(0).toUpperCase() + layer.activation.slice(1);
599
+ const textLabel = document.createElementNS("http://www.w3.org/2000/svg", "text");
600
+ textLabel.setAttribute('x', x); textLabel.setAttribute('y', margin.top - 8);
601
+ textLabel.setAttribute('class', 'layer-label');
602
+ textLabel.textContent = labelText;
603
+ svg.appendChild(textLabel);
604
+
605
+ const countLabel = document.createElementNS("http://www.w3.org/2000/svg", "text");
606
+ countLabel.setAttribute('x', x); countLabel.setAttribute('y', margin.top + height + 15);
607
+ countLabel.setAttribute('class', 'neuron-count-label');
608
+ countLabel.textContent = `${layer.neurons} N`;
609
+ svg.appendChild(countLabel);
610
+ });
611
+ container.appendChild(svg);
612
+ },
613
+
614
+ drawTrainingPlot() {
615
+ const canvas = App.ui.trainingPlotCanvas;
616
+ const ctx = canvas.getContext('2d');
617
+ const dpr = window.devicePixelRatio || 1;
618
+ const rect = canvas.getBoundingClientRect();
619
+ if (rect.width === 0 || rect.height === 0) return;
620
+ canvas.width = rect.width * dpr;
621
+ canvas.height = rect.height * dpr;
622
+ ctx.scale(dpr, dpr);
623
+ const { width, height } = rect;
624
+ const padding = {top: 20, right: 15, bottom: 20, left: 30};
625
+
626
+ ctx.fillStyle = '#f8fafc';
627
+ ctx.fillRect(0,0,width,height);
628
+
629
+ if (App.state.trainingHistory.loss.length === 0) {
630
+ ctx.fillStyle = '#64748b';
631
+ ctx.textAlign = 'center';
632
+ ctx.font = '12px Inter';
633
+ ctx.fillText('Training history will be plotted here.', width / 2, height / 2);
634
+ return;
635
+ }
636
+
637
+ ctx.beginPath();
638
+ ctx.strokeStyle = '#e2e8f0';
639
+ ctx.lineWidth = 1;
640
+ for(let i = 0; i <= 4; i++){
641
+ const y = padding.top + i * (height - padding.top - padding.bottom) / 4;
642
+ ctx.moveTo(padding.left, y);
643
+ ctx.lineTo(width-padding.right, y);
644
+ }
645
+ ctx.stroke();
646
+
647
+ ctx.font = '10px Inter';
648
+ ctx.fillStyle = '#475569';
649
+ ctx.textAlign = 'right';
650
+ for(let i = 0; i <= 4; i++){
651
+ ctx.fillText((1 - i/4).toFixed(1), padding.left - 6, padding.top + 3 + i * (height - padding.top - padding.bottom) / 4);
652
+ }
653
+
654
+ const plotData = (data, color) => {
655
+ ctx.beginPath(); ctx.strokeStyle = color; ctx.lineWidth = 2; ctx.lineJoin = 'round'; ctx.lineCap = 'round';
656
+ data.forEach((val, i) => {
657
+ const x = padding.left + (i / (Math.max(1, data.length -1))) * (width - padding.left - padding.right);
658
+ const y = (height - padding.bottom) - Math.min(Math.max(val,0.0),1.0) * (height - padding.top - padding.bottom);
659
+ if (i === 0) ctx.moveTo(x, y); else ctx.lineTo(x, y);
660
+ });
661
+ ctx.stroke();
662
+ };
663
+ plotData(App.state.trainingHistory.loss, 'rgba(239, 68, 68, 0.9)');
664
+ plotData(App.state.trainingHistory.acc, 'rgba(37, 99, 235, 0.9)');
665
+
666
+ ctx.textAlign = 'left';
667
+ ctx.font = '600 11px Inter';
668
+ ctx.fillStyle = 'rgba(239, 68, 68, 1)'; ctx.fillRect(padding.left + 5, 5, 10, 3);
669
+ ctx.fillStyle = '#374151'; ctx.fillText('Loss', padding.left + 20, 10);
670
+ ctx.fillStyle = 'rgba(37, 99, 235, 1)'; ctx.fillRect(padding.left + 75, 5, 10, 3);
671
+ ctx.fillStyle = '#374151'; ctx.fillText('Accuracy', padding.left + 90, 10);
672
+ },
673
+
674
+ updateTrainingParamsDisplay() {
675
+ const container = App.ui.trainingParamsDisplay;
676
+ if (!App.state.model) {
677
+ container.innerHTML = `<div class="text-slate-500">Train a model to see parameters.</div>`;
678
+ return;
679
+ }
680
+ const lr = parseFloat(App.ui.learningRate.value);
681
+ const regType = App.ui.regularizationTypeSelect.value;
682
+ const regRate = parseFloat(App.ui.regularizationRateInput.value) || 0;
683
+ let regDesc = regType !== 'none' ? `${regType.toUpperCase()}(λ=${regRate})` : 'None';
684
+ let hiddenDesc = App.state.hiddenLayerConfigs.map(l => l.neurons).join(' → ');
685
+
686
+ container.innerHTML = `
687
+ <div class="flex flex-wrap gap-x-4 gap-y-1">
688
+ <span><span class="font-semibold text-slate-500">Opt:</span> <span class="font-medium text-slate-800">${App.ui.optimizerSelect.value.toUpperCase()}</span></span>
689
+ <span><span class="font-semibold text-slate-500">LR:</span> <span class="font-medium text-slate-800">${lr.toExponential(1)}</span></span>
690
+ <span><span class="font-semibold text-slate-500">Batch:</span> <span class="font-medium text-slate-800">${parseInt(App.ui.batchSize.value) === 0 ? 'Full' : App.ui.batchSize.value}</span></span>
691
+ <span><span class="font-semibold text-slate-500">Reg:</span> <span class="font-medium text-slate-800">${regDesc}</span></span>
692
+ </div>
693
+ <div><span class="font-semibold text-slate-500">Layers:</span> <span class="font-medium text-slate-800">2 → ${hiddenDesc || '...'} → 1</span></div>
694
+ `;
695
+ },
696
+
697
+ renderConfusionMatrix(matrix) {
698
+ const container = App.ui.confusionMatrixContainer;
699
+ if (!matrix) {
700
+ container.innerHTML = `<div class="flex items-center justify-center h-full text-sm text-center text-slate-400 bg-slate-50 rounded-lg p-2 border border-slate-200">Confusion matrix appears here after training.</div>`;
701
+ return;
702
+ }
703
+ container.innerHTML = `
704
+ <div class="h-full flex flex-col">
705
+ <h4 class="text-sm font-semibold text-center text-slate-600 mb-1.5">Confusion Matrix</h4>
706
+ <div class="grid grid-cols-2 grid-rows-2 gap-1.5 flex-grow">
707
+ <div class="cm-cell bg-blue-100 text-blue-800" title="True Negative"><span class="cm-value">${matrix.tn}</span><span class="cm-label">TN</span></div>
708
+ <div class="cm-cell bg-red-100 text-red-800" title="False Positive"><span class="cm-value">${matrix.fp}</span><span class="cm-label">FP</span></div>
709
+ <div class="cm-cell bg-red-100 text-red-800" title="False Negative"><span class="cm-value">${matrix.fn}</span><span class="cm-label">FN</span></div>
710
+ <div class="cm-cell bg-green-100 text-green-800" title="True Positive"><span class="cm-value">${matrix.tp}</span><span class="cm-label">TP</span></div>
711
+ </div>
712
+ </div>
713
+ `;
714
+ },
715
+
716
+ // --- TENSORFLOW.JS & ML ---
717
+ getSafeNumericInput(element, defaultValue, isInteger = true) {
718
+ let value = isInteger ? parseInt(element.value, 10) : parseFloat(element.value);
719
+ if (isNaN(value)) {
720
+ value = defaultValue;
721
+ element.value = defaultValue;
722
+ }
723
+ return value;
724
+ },
725
+
726
+ buildModel() {
727
+ if (App.state.model) { App.state.model.dispose(); App.state.model = null; }
728
+ tf.disposeVariables();
729
+
730
+ const learningRate = this.getSafeNumericInput(App.ui.learningRate, 0.01, false);
731
+ const regType = App.ui.regularizationTypeSelect.value;
732
+ const regRate = this.getSafeNumericInput(App.ui.regularizationRateInput, 0, false);
733
+
734
+ const kernelRegularizer = (regType !== 'none' && regRate > 0)
735
+ ? tf.regularizers[regType]({[regType]: regRate})
736
+ : null;
737
+
738
+ App.state.model = tf.sequential();
739
+ const inputShape = [2];
740
+ // Add hidden layers
741
+ App.state.hiddenLayerConfigs.forEach((layerConfig, index) => {
742
+ App.state.model.add(tf.layers.dense({
743
+ units: layerConfig.neurons,
744
+ inputShape: index === 0 ? inputShape : undefined,
745
+ activation: layerConfig.activation,
746
+ kernelRegularizer
747
+ }));
748
+ });
749
+ // Add output layer
750
+ App.state.model.add(tf.layers.dense({
751
+ units: 1,
752
+ activation: 'sigmoid',
753
+ inputShape: App.state.hiddenLayerConfigs.length === 0 ? inputShape : undefined,
754
+ }));
755
+
756
+ const optimizerInstance = tf.train[App.ui.optimizerSelect.value](learningRate);
757
+ App.state.model.compile({ optimizer: optimizerInstance, loss: 'binaryCrossentropy', metrics: ['accuracy'] });
758
+ App.state.model.stopTraining = false;
759
+ },
760
+
761
+ async trainAndVisualize() {
762
+ if (App.state.isTraining) return;
763
+ const uniqueLabels = new Set(App.state.dataPoints.map(p => p.label));
764
+ if (App.state.dataPoints.length < 4 || uniqueLabels.size < 2) {
765
+ this.updateStatus('Requires at least 4 points and data from both classes.', 'error');
766
+ return;
767
+ }
768
+
769
+ this.setTrainingState(true);
770
+ this.updateStatus('Starting training...', 'loading');
771
+ this.renderConfusionMatrix(null);
772
+ await tf.nextFrame();
773
+
774
+ this.buildModel();
775
+ this.updateTrainingParamsDisplay();
776
+
777
+ const epochs = this.getSafeNumericInput(App.ui.epochs, 150, true);
778
+ let batchSize = this.getSafeNumericInput(App.ui.batchSize, 16, true);
779
+ if (batchSize === 0) batchSize = App.state.dataPoints.length;
780
+
781
+ const [xs, ys] = tf.tidy(() => {
782
+ const normalized = App.state.dataPoints.map(p => [p.normX, p.normY]);
783
+ const labels = App.state.dataPoints.map(p => p.label);
784
+ return [tf.tensor2d(normalized), tf.tensor2d(labels, [labels.length, 1])];
785
+ });
786
+
787
+ App.state.trainingHistory = { loss: [], acc: [] };
788
+ this.drawTrainingPlot();
789
+
790
+ try {
791
+ await App.state.model.fit(xs, ys, {
792
+ epochs, batchSize,
793
+ callbacks: {
794
+ onEpochEnd: async (epoch, logs) => {
795
+ if (App.state.model.stopTraining) { App.state.model.stop(); return; }
796
+ this.updateStatus(`Epoch ${epoch + 1}/${epochs} - Loss: ${logs.loss.toFixed(4)}, Acc: ${logs.acc.toFixed(4)}`, 'loading');
797
+ App.state.trainingHistory.loss.push(logs.loss);
798
+ App.state.trainingHistory.acc.push(logs.acc);
799
+ this.drawTrainingPlot();
800
+
801
+ if ((epoch + 1) % Math.max(1, Math.floor(epochs / 25)) === 0) {
802
+ const boundaryGrid = await this.generateBoundaryGrid();
803
+ this.drawAll(boundaryGrid);
804
+ }
805
+ await tf.nextFrame();
806
+ },
807
+ onTrainEnd: async () => {
808
+ const finalAcc = App.state.trainingHistory.acc.slice(-1)[0] || 0;
809
+ if (!App.state.model.stopTraining) {
810
+ this.updateStatus(`Training complete! Final Accuracy: <b>${(finalAcc*100).toFixed(2)}%</b>`, 'success');
811
+ }
812
+ const boundaryGrid = await this.generateBoundaryGrid();
813
+ this.drawAll(boundaryGrid);
814
+
815
+ const confusionMatrix = await this.calculateConfusionMatrix(xs, ys);
816
+ this.renderConfusionMatrix(confusionMatrix);
817
+ this.setTrainingState(false);
818
+ }
819
+ }
820
+ });
821
+ } catch (error) {
822
+ console.error("Training error:", error);
823
+ this.updateStatus(`Training error: ${error.message}`, 'error');
824
+ this.setTrainingState(false);
825
+ } finally {
826
+ xs.dispose();
827
+ ys.dispose();
828
+ }
829
+ },
830
+
831
+ async calculateConfusionMatrix(xs, ys) {
832
+ if (!App.state.model) return null;
833
+ return tf.tidy(() => {
834
+ const predictions = App.state.model.predict(xs).round();
835
+ const tp = ys.mul(predictions).sum().dataSync()[0];
836
+ const tn = ys.sub(1).mul(-1).mul(predictions.sub(1).mul(-1)).sum().dataSync()[0];
837
+ const fp = predictions.sub(ys).relu().sum().dataSync()[0];
838
+ const fn = ys.sub(predictions).relu().sum().dataSync()[0];
839
+ return { tp, tn, fp, fn };
840
+ });
841
+ },
842
+
843
+ async generateBoundaryGrid() {
844
+ if (!App.state.model || !App.ui.canvasWidth || !App.ui.canvasHeight) return null;
845
+ const resolution = Math.max(5, Math.floor(App.ui.canvasWidth / 80));
846
+ const numCols = Math.floor(App.ui.canvasWidth / resolution);
847
+ const numRows = Math.floor(App.ui.canvasHeight / resolution);
848
+ if (numCols <= 0 || numRows <= 0) return null;
849
+
850
+ const boundaryData = tf.tidy(() => {
851
+ const gridPoints = [];
852
+ for (let i = 0; i < numRows; i++) {
853
+ for (let j = 0; j < numCols; j++) {
854
+ gridPoints.push([(j * resolution) / App.ui.canvasWidth, (i * resolution) / App.ui.canvasHeight]);
855
+ }
856
+ }
857
+ const predsTensor = App.state.model.predict(tf.tensor2d(gridPoints));
858
+ return predsTensor.dataSync(); // Use dataSync inside tidy
859
+ });
860
+
861
+ const grid = []; let k = 0;
862
+ for (let i = 0; i < numRows; i++) {
863
+ const row = [];
864
+ for (let j = 0; j < numCols; j++) row.push(boundaryData[k++] > 0.5 ? 1 : 0);
865
+ grid.push(row);
866
+ }
867
+ return { grid, resolution };
868
+ },
869
+
870
+ stopTraining() {
871
+ if (App.state.model) {
872
+ App.state.model.stopTraining = true;
873
+ this.updateStatus("Training stopped by user.", 'warning');
874
+ }
875
+ },
876
+
877
+ resetModelWeights() {
878
+ if (App.state.isTraining) return;
879
+ this.updateStatus("Model weights reset. Ready to train again.", 'info');
880
+ this.buildModel();
881
+ this.drawAll();
882
+ App.state.trainingHistory = { loss: [], acc: [] };
883
+ this.drawTrainingPlot();
884
+ this.renderConfusionMatrix(null);
885
+ this.updateButtonStates();
886
+ },
887
+
888
+ clearAllDataAndReseed() {
889
+ if (App.state.isTraining) return;
890
+ App.state.dataPoints = [];
891
+ if(App.state.model) { App.state.model.dispose(); App.state.model = null; }
892
+ tf.disposeVariables();
893
+ App.state.hiddenLayerConfigs = [];
894
+ App.ui.hiddenLayersConfigContainer.innerHTML = '';
895
+
896
+ App.initializeApplicationState();
897
+ this.updateStatus('Cleared all data and reset configuration.', 'info');
898
+ }
899
+ }
900
+ };
901
+
902
+ // --- Entry Point ---
903
+ document.addEventListener('DOMContentLoaded', () => App.init());
904
+ </script>
905
+ </body>
906
  </html>