Update index.html
Browse files- index.html +905 -18
index.html
CHANGED
@@ -1,19 +1,906 @@
|
|
1 |
-
<!
|
2 |
-
<html>
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
</html>
|
|
|
1 |
+
<!DOCTYPE html>
|
2 |
+
<html lang="vi">
|
3 |
+
<head>
|
4 |
+
<meta charset="UTF-8">
|
5 |
+
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
6 |
+
<title>MLP Interactive Visualization</title>
|
7 |
+
<script src="https://cdn.tailwindcss.com"></script>
|
8 |
+
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@latest/dist/tf.min.js"></script>
|
9 |
+
<link rel="preconnect" href="https://fonts.googleapis.com">
|
10 |
+
<link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
|
11 |
+
<link href="https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700;800&display=swap" rel="stylesheet">
|
12 |
+
<style>
|
13 |
+
body { font-family: 'Inter', sans-serif; background-color: #f1f5f9; }
|
14 |
+
|
15 |
+
/* Custom class for active class selection buttons */
|
16 |
+
.button-active-blue {
|
17 |
+
background-color: #2563eb !important;
|
18 |
+
color: white !important;
|
19 |
+
border-color: #2563eb !important;
|
20 |
+
box-shadow: 0 4px 6px -1px rgb(0 0 0 / 0.1), 0 2px 4px -2px rgb(0 0 0 / 0.1);
|
21 |
+
}
|
22 |
+
.button-active-red {
|
23 |
+
background-color: #dc2626 !important;
|
24 |
+
color: white !important;
|
25 |
+
border-color: #dc2626 !important;
|
26 |
+
box-shadow: 0 4px 6px -1px rgb(0 0 0 / 0.1), 0 2px 4px -2px rgb(0 0 0 / 0.1);
|
27 |
+
}
|
28 |
+
|
29 |
+
/* Cursor styles for data input */
|
30 |
+
#plotCanvas { touch-action: none; }
|
31 |
+
#plotCanvas.cursor-class-0 { cursor: url('data:image/svg+xml;utf8,<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="%233b82f6" stroke="white" stroke-width="2"><circle cx="12" cy="12" r="8"/><path d="M12 2v4M12 18v4M22 12h-4M6 12H2"/></svg>') 12 12, crosshair; }
|
32 |
+
#plotCanvas.cursor-class-1 { cursor: url('data:image/svg+xml;utf8,<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="%23ef4444" stroke="white" stroke-width="2"><circle cx="12" cy="12" r="8"/><path d="M12 2v4M12 18v4M22 12h-4M6 12H2"/></svg>') 12 12, crosshair; }
|
33 |
+
|
34 |
+
/* Scrollbar styling */
|
35 |
+
#hiddenLayersConfigContainer::-webkit-scrollbar { width: 6px; }
|
36 |
+
#hiddenLayersConfigContainer::-webkit-scrollbar-track { background: #e2e8f0; border-radius: 8px; }
|
37 |
+
#hiddenLayersConfigContainer::-webkit-scrollbar-thumb { background: #94a3b8; border-radius: 8px; }
|
38 |
+
#hiddenLayersConfigContainer::-webkit-scrollbar-thumb:hover { background: #64748b; }
|
39 |
+
|
40 |
+
/* Network Visualization SVG styles */
|
41 |
+
.neuron {
|
42 |
+
stroke-width: 1.5;
|
43 |
+
transition: stroke-width 0.2s ease-in-out;
|
44 |
+
}
|
45 |
+
.neuron:hover {
|
46 |
+
stroke-width: 4;
|
47 |
+
}
|
48 |
+
.neuron.input { fill: #60a5fa; stroke: #2563eb; }
|
49 |
+
.neuron.hidden-0 { fill: #818cf8; stroke: #4f46e5; }
|
50 |
+
.neuron.hidden-1 { fill: #a78bfa; stroke: #7c3aed; }
|
51 |
+
.neuron.hidden-2 { fill: #c084fc; stroke: #9333ea; }
|
52 |
+
.neuron.hidden-3 { fill: #e879f9; stroke: #c026d3; }
|
53 |
+
.neuron.hidden-other { fill: #f472b6; stroke: #db2777; }
|
54 |
+
.neuron.output { fill: #f87171; stroke: #dc2626; }
|
55 |
+
|
56 |
+
.connection {
|
57 |
+
stroke: #cbd5e1;
|
58 |
+
stroke-width: 0.75;
|
59 |
+
transition: stroke-opacity 0.2s;
|
60 |
+
}
|
61 |
+
|
62 |
+
.layer-label {
|
63 |
+
font-size: 11px;
|
64 |
+
font-weight: 600;
|
65 |
+
fill: #475569;
|
66 |
+
text-anchor: middle;
|
67 |
+
}
|
68 |
+
.neuron-count-label {
|
69 |
+
font-size: 10px;
|
70 |
+
font-weight: 500;
|
71 |
+
fill: #64748b;
|
72 |
+
text-anchor: middle;
|
73 |
+
}
|
74 |
+
|
75 |
+
/* Utility icon styles */
|
76 |
+
.title-icon { margin-left: 0.6rem; font-size: 1.1rem; }
|
77 |
+
.action-icon { margin-right: 0.35rem; }
|
78 |
+
.status-icon { margin-right: 0.5rem; flex-shrink: 0; font-size: 1.1rem; }
|
79 |
+
.loading-icon { animation: spin 1.5s linear infinite; display: inline-block; }
|
80 |
+
@keyframes spin { from { transform: rotate(0deg); } to { transform: rotate(360deg); } }
|
81 |
+
|
82 |
+
/* Confusion Matrix styles */
|
83 |
+
.cm-cell {
|
84 |
+
display: flex;
|
85 |
+
align-items: center;
|
86 |
+
justify-content: center;
|
87 |
+
flex-direction: column;
|
88 |
+
line-height: 1.2;
|
89 |
+
padding: 0.5rem;
|
90 |
+
border-radius: 0.25rem;
|
91 |
+
transition: all 0.2s ease;
|
92 |
+
}
|
93 |
+
.cm-value { font-size: 1.125rem; font-weight: 800; }
|
94 |
+
.cm-label { font-size: 0.65rem; text-transform: uppercase; letter-spacing: 0.05em; font-weight: 600; }
|
95 |
+
</style>
|
96 |
+
</head>
|
97 |
+
<body class="text-slate-800">
|
98 |
+
|
99 |
+
<div class="container mx-auto p-2 sm:p-4 max-w-full">
|
100 |
+
<header class="mb-4 text-center">
|
101 |
+
<h1 class="text-3xl sm:text-4xl font-extrabold text-blue-600">MLP Interactive Visualization</h1>
|
102 |
+
<p class="text-md text-slate-600 mt-1">Build, train, and visualize a Multi-Layer Perceptron.</p>
|
103 |
+
</header>
|
104 |
+
|
105 |
+
<!-- Main layout with reduced gap -->
|
106 |
+
<div class="flex flex-col lg:flex-row gap-2.5">
|
107 |
+
<!-- Control Panel -->
|
108 |
+
<div class="lg:w-6/12 bg-white p-3.5 rounded-2xl shadow-xl border border-slate-200 space-y-5">
|
109 |
+
<!-- Data & Training Section -->
|
110 |
+
<div class="space-y-4">
|
111 |
+
<h2 class="text-xl font-bold text-slate-700 border-b border-slate-200 pb-2 flex items-center">
|
112 |
+
1. Data & Training <span class="title-icon">⚙️</span>
|
113 |
+
</h2>
|
114 |
+
<div class="grid grid-cols-2 sm:grid-cols-4 gap-x-4 gap-y-3">
|
115 |
+
<div class="col-span-2 sm:col-span-4">
|
116 |
+
<label class="text-sm font-medium text-slate-700 block mb-1.5">Data Input Class:</label>
|
117 |
+
<div class="flex rounded-lg shadow-sm">
|
118 |
+
<button id="class0Button" class="flex-1 px-3 py-2 border border-slate-300 rounded-l-lg bg-white text-base font-semibold text-slate-700 hover:bg-slate-50 focus:z-10 focus:outline-none focus:ring-2 focus:ring-blue-500 focus:border-blue-500 transition-all duration-150" title="Select Class 0 for the next data points">Class 0 (Blue)</button>
|
119 |
+
<button id="class1Button" class="flex-1 px-3 py-2 border-t border-b border-r border-slate-300 rounded-r-lg bg-white text-base font-semibold text-slate-700 hover:bg-slate-50 focus:z-10 focus:outline-none focus:ring-2 focus:ring-red-500 focus:border-red-500 transition-all duration-150" title="Select Class 1 for the next data points">Class 1 (Red)</button>
|
120 |
+
</div>
|
121 |
+
</div>
|
122 |
+
<!-- Other controls... -->
|
123 |
+
<div>
|
124 |
+
<label for="datasetSelect" class="text-sm font-medium text-slate-700 block mb-1">Load Dataset:</label>
|
125 |
+
<select id="datasetSelect" class="mt-0 block w-full px-2 py-1.5 border border-slate-300 bg-white rounded-lg shadow-sm focus:outline-none focus:ring-2 focus:ring-blue-500 focus:border-blue-500 text-sm" title="Choose a preset dataset.">
|
126 |
+
<option value="manual">Manual Input</option>
|
127 |
+
<option value="two_moons">Two Moons</option>
|
128 |
+
<option value="circles">Concentric Circles</option>
|
129 |
+
<option value="xor">XOR</option>
|
130 |
+
<option value="spiral">Spiral</option>
|
131 |
+
</select>
|
132 |
+
</div>
|
133 |
+
<div>
|
134 |
+
<label for="dataNoise" class="text-sm font-medium text-slate-700 block mb-1">Data Noise:</label>
|
135 |
+
<input type="number" id="dataNoise" value="0.05" min="0" max="0.5" step="0.01" class="mt-0 block w-full px-2 py-1.5 border border-slate-300 rounded-lg shadow-sm focus:outline-none focus:ring-2 focus:ring-blue-500 focus:border-blue-500 text-sm" title="Add random noise to the data points.">
|
136 |
+
</div>
|
137 |
+
<div>
|
138 |
+
<label for="learningRate" class="text-sm font-medium text-slate-700 block mb-1">Learning Rate:</label>
|
139 |
+
<input type="number" id="learningRate" value="0.01" step="0.001" min="0.00001" class="mt-0 block w-full px-2 py-1.5 border border-slate-300 rounded-lg shadow-sm focus:outline-none focus:ring-2 focus:ring-blue-500 focus:border-blue-500 text-sm" title="How fast the model learns.">
|
140 |
+
</div>
|
141 |
+
<div>
|
142 |
+
<label for="epochs" class="text-sm font-medium text-slate-700 block mb-1">Epochs:</label>
|
143 |
+
<input type="number" id="epochs" value="150" step="10" min="1" class="mt-0 block w-full px-2 py-1.5 border border-slate-300 rounded-lg shadow-sm focus:outline-none focus:ring-2 focus:ring-blue-500 focus:border-blue-500 text-sm" title="Number of training iterations.">
|
144 |
+
</div>
|
145 |
+
<div>
|
146 |
+
<label for="optimizerSelect" class="text-sm font-medium text-slate-700 block mb-1">Optimizer:</label>
|
147 |
+
<select id="optimizerSelect" class="mt-0 block w-full px-2 py-1.5 border border-slate-300 bg-white rounded-lg shadow-sm focus:outline-none focus:ring-2 focus:ring-blue-500 focus:border-blue-500 text-sm" title="Algorithm to update weights.">
|
148 |
+
<option value="adam">Adam</option>
|
149 |
+
<option value="sgd">SGD</option>
|
150 |
+
<option value="rmsprop">RMSprop</option>
|
151 |
+
</select>
|
152 |
+
</div>
|
153 |
+
<div>
|
154 |
+
<label for="batchSize" class="text-sm font-medium text-slate-700 block mb-1">Batch Size:</label>
|
155 |
+
<input type="number" id="batchSize" value="16" min="0" step="4" class="mt-0 block w-full px-2 py-1.5 border border-slate-300 rounded-lg shadow-sm focus:outline-none focus:ring-2 focus:ring-blue-500 focus:border-blue-500 text-sm" title="Samples per weight update. 0=Full.">
|
156 |
+
</div>
|
157 |
+
<div>
|
158 |
+
<label for="regularizationTypeSelect" class="text-sm font-medium text-slate-700 block mb-1">Regularization:</label>
|
159 |
+
<select id="regularizationTypeSelect" class="mt-0 block w-full px-2 py-1.5 border border-slate-300 bg-white rounded-lg shadow-sm focus:outline-none focus:ring-2 focus:ring-blue-500 focus:border-blue-500 text-sm" title="Technique to prevent overfitting.">
|
160 |
+
<option value="none">None</option>
|
161 |
+
<option value="l1">L1</option>
|
162 |
+
<option value="l2">L2</option>
|
163 |
+
</select>
|
164 |
+
</div>
|
165 |
+
<div id="regularizationRateContainer" class="hidden">
|
166 |
+
<label for="regularizationRateInput" class="text-sm font-medium text-slate-700 block mb-1">Reg. Rate (λ):</label>
|
167 |
+
<input type="number" id="regularizationRateInput" value="0.001" step="0.0001" min="0" class="mt-0 block w-full px-2 py-1.5 border border-slate-300 rounded-lg shadow-sm focus:outline-none focus:ring-2 focus:ring-blue-500 focus:border-blue-500 text-sm" title="Strength of the L1/L2 penalty.">
|
168 |
+
</div>
|
169 |
+
</div>
|
170 |
+
</div>
|
171 |
+
|
172 |
+
<!-- MLP Architecture Section -->
|
173 |
+
<div class="space-y-2">
|
174 |
+
<h2 class="text-xl font-bold text-slate-700 border-b border-slate-200 pb-2 flex items-center">
|
175 |
+
2. MLP Architecture <span class="title-icon">🧠</span>
|
176 |
+
</h2>
|
177 |
+
<div id="networkVisualization" class="bg-slate-50 rounded-lg p-2 min-h-[120px] border border-slate-200"></div>
|
178 |
+
<h3 class="text-sm font-medium text-slate-700 pt-2">Hidden Layers:</h3>
|
179 |
+
<div id="hiddenLayersConfigContainer" class="grid grid-cols-1 md:grid-cols-2 gap-2 max-h-48 overflow-y-auto pr-1.5"></div>
|
180 |
+
<button id="addHiddenLayerButton" class="mt-1 w-full flex justify-center items-center py-2 px-3 border-2 border-dashed border-blue-400 rounded-lg shadow-sm text-sm font-semibold text-blue-600 bg-blue-50 hover:bg-blue-100 hover:border-blue-500 focus:outline-none focus:ring-2 focus:ring-offset-2 focus:ring-blue-500 transition-all duration-150" title="Add a new hidden layer.">
|
181 |
+
<span class="action-icon text-base">➕</span> Add Layer
|
182 |
+
</button>
|
183 |
+
</div>
|
184 |
+
|
185 |
+
<!-- Actions Section -->
|
186 |
+
<div class="space-y-2">
|
187 |
+
<h2 class="text-xl font-bold text-slate-700 border-b border-slate-200 pb-2 flex items-center">
|
188 |
+
3. Actions <span class="title-icon">⚡️</span>
|
189 |
+
</h2>
|
190 |
+
<div class="grid grid-cols-2 sm:grid-cols-4 gap-2.5">
|
191 |
+
<button id="trainButton" class="col-span-2 flex justify-center items-center py-2.5 px-3 border border-transparent rounded-lg shadow-md text-base font-bold text-white bg-green-600 hover:bg-green-700 focus:outline-none focus:ring-2 focus:ring-offset-2 focus:ring-green-500 transition-all duration-150 disabled:opacity-50 disabled:cursor-not-allowed disabled:bg-green-400" title="Start or retrain the model.">
|
192 |
+
<span class="action-icon">▶️</span> <span id="trainButtonText">Train</span>
|
193 |
+
</button>
|
194 |
+
<button id="stopButton" class="hidden col-span-2 flex justify-center items-center py-2.5 px-3 border border-transparent rounded-lg shadow-md text-base font-bold text-white bg-red-600 hover:bg-red-700 focus:outline-none focus:ring-2 focus:ring-offset-2 focus:ring-red-500 transition-all" title="Stop the current training process.">
|
195 |
+
<span class="action-icon">⏹️</span> <span>Stop</span>
|
196 |
+
</button>
|
197 |
+
<button id="resetWeightsButton" class="flex justify-center items-center py-2.5 px-3 border border-slate-300 rounded-lg shadow-sm text-base font-semibold text-slate-700 bg-white hover:bg-slate-50 focus:outline-none focus:ring-2 focus:ring-offset-2 focus:ring-indigo-500 transition-all duration-150 disabled:opacity-50 disabled:cursor-not-allowed" title="Re-initialize the model's weights.">
|
198 |
+
<span class="action-icon">🔄</span> <span>Reset</span>
|
199 |
+
</button>
|
200 |
+
<button id="clearButton" class="flex justify-center items-center py-2.5 px-3 border border-slate-300 rounded-lg shadow-sm text-base font-semibold text-slate-700 bg-white hover:bg-slate-50 focus:outline-none focus:ring-2 focus:ring-offset-2 focus:ring-indigo-500 transition-all duration-150 disabled:opacity-50 disabled:cursor-not-allowed" title="Delete all data and reset settings.">
|
201 |
+
<span class="action-icon">🗑️</span> <span>Clear All</span>
|
202 |
+
</button>
|
203 |
+
</div>
|
204 |
+
</div>
|
205 |
+
</div>
|
206 |
+
|
207 |
+
<!-- Visualization Area with reduced gap -->
|
208 |
+
<div class="lg:w-6/12 flex flex-col gap-2.5">
|
209 |
+
<div class="bg-white p-2 rounded-2xl shadow-xl border border-slate-200">
|
210 |
+
<canvas id="plotCanvas" class="border border-slate-200 rounded-xl w-full"></canvas>
|
211 |
+
</div>
|
212 |
+
|
213 |
+
<div class="bg-white p-3.5 rounded-2xl shadow-xl border border-slate-200 flex flex-col flex-grow">
|
214 |
+
<h2 class="text-xl font-bold text-slate-700 mb-2 border-b border-slate-200 pb-2 flex items-center">
|
215 |
+
Training Status <span class="title-icon">📊</span>
|
216 |
+
</h2>
|
217 |
+
<div id="trainingParamsDisplay" class="text-xs mb-2 space-y-1"></div>
|
218 |
+
<div id="statusMessage" class="text-sm text-slate-600 p-2 bg-slate-50 rounded-lg min-h-[44px] mb-3 whitespace-pre-line leading-snug flex items-center justify-center border border-slate-200"></div>
|
219 |
+
<div class="flex-grow flex flex-col sm:flex-row gap-4">
|
220 |
+
<div class="w-full sm:w-2/3 relative min-h-[150px]">
|
221 |
+
<canvas id="trainingPlotCanvas" class="w-full h-full"></canvas>
|
222 |
+
</div>
|
223 |
+
<div id="confusionMatrixContainer" class="w-full sm:w-1/3">
|
224 |
+
<!-- Confusion Matrix will be rendered here -->
|
225 |
+
</div>
|
226 |
+
</div>
|
227 |
+
</div>
|
228 |
+
</div>
|
229 |
+
</div>
|
230 |
+
</div>
|
231 |
+
|
232 |
+
<script>
|
233 |
+
// --- App Namespace ---
|
234 |
+
// Encapsulate the entire application in a single object to avoid polluting the global namespace.
|
235 |
+
// This improves organization and prevents potential conflicts with other scripts.
|
236 |
+
const App = {
|
237 |
+
// --- STATE & CONFIG ---
|
238 |
+
state: {
|
239 |
+
model: null,
|
240 |
+
dataPoints: [],
|
241 |
+
currentClass: 0,
|
242 |
+
hiddenLayerConfigs: [],
|
243 |
+
trainingHistory: { loss: [], acc: [] },
|
244 |
+
isTraining: false,
|
245 |
+
},
|
246 |
+
ui: {}, // To hold DOM element references
|
247 |
+
config: {
|
248 |
+
pointRadius: 4.5,
|
249 |
+
classColors: {
|
250 |
+
0: { point: 'rgba(59, 130, 246, 1)', boundary: 'rgba(59, 130, 246, 0.3)' },
|
251 |
+
1: { point: 'rgba(239, 68, 68, 1)', boundary: 'rgba(239, 68, 68, 0.3)' }
|
252 |
+
},
|
253 |
+
statusIcons: {
|
254 |
+
info: 'ℹ️', success: '✅', warning: '⚠️', error: '❌', loading: '⏳'
|
255 |
+
},
|
256 |
+
},
|
257 |
+
|
258 |
+
// --- INITIALIZATION ---
|
259 |
+
async init() {
|
260 |
+
this.cacheUIElements();
|
261 |
+
this.registerEventListeners();
|
262 |
+
|
263 |
+
await tf.ready();
|
264 |
+
try {
|
265 |
+
await tf.setBackend('cpu');
|
266 |
+
console.log("TensorFlow.js backend set to CPU.");
|
267 |
+
} catch (e) { console.warn("Could not set TF.js backend to CPU.", e); }
|
268 |
+
|
269 |
+
this.methods.resizePlotCanvas();
|
270 |
+
this.initializeApplicationState();
|
271 |
+
},
|
272 |
+
|
273 |
+
cacheUIElements() {
|
274 |
+
const ids = [
|
275 |
+
'class0Button', 'class1Button', 'datasetSelect', 'dataNoise', 'batchSize',
|
276 |
+
'hiddenLayersConfigContainer', 'addHiddenLayerButton', 'networkVisualization',
|
277 |
+
'optimizerSelect', 'learningRate', 'epochs', 'regularizationTypeSelect',
|
278 |
+
'regularizationRateContainer', 'regularizationRateInput', 'trainButton',
|
279 |
+
'stopButton', 'resetWeightsButton', 'clearButton', 'statusMessage',
|
280 |
+
'plotCanvas', 'trainingPlotCanvas', 'trainingParamsDisplay', 'trainButtonText',
|
281 |
+
'confusionMatrixContainer'
|
282 |
+
];
|
283 |
+
ids.forEach(id => { this.ui[id] = document.getElementById(id); });
|
284 |
+
this.ui.canvas = this.ui.plotCanvas;
|
285 |
+
this.ui.ctx = this.ui.canvas.getContext('2d');
|
286 |
+
},
|
287 |
+
|
288 |
+
registerEventListeners() {
|
289 |
+
window.addEventListener('resize', () => {
|
290 |
+
this.methods.resizePlotCanvas();
|
291 |
+
this.methods.drawTrainingPlot();
|
292 |
+
this.methods.drawNetworkVisualization();
|
293 |
+
});
|
294 |
+
this.ui.canvas.addEventListener('click', (e) => this.methods.handleCanvasClick(e));
|
295 |
+
this.ui.class0Button.addEventListener('click', () => this.methods.setActiveClass(0));
|
296 |
+
this.ui.class1Button.addEventListener('click', () => this.methods.setActiveClass(1));
|
297 |
+
this.ui.addHiddenLayerButton.addEventListener('click', () => this.methods.addHiddenLayerUI());
|
298 |
+
this.ui.trainButton.addEventListener('click', () => this.methods.trainAndVisualize());
|
299 |
+
this.ui.stopButton.addEventListener('click', () => this.methods.stopTraining());
|
300 |
+
this.ui.resetWeightsButton.addEventListener('click', () => this.methods.resetModelWeights());
|
301 |
+
this.ui.clearButton.addEventListener('click', () => this.methods.clearAllDataAndReseed());
|
302 |
+
this.ui.datasetSelect.addEventListener('change', () => this.methods.loadSelectedDataset());
|
303 |
+
this.ui.dataNoise.addEventListener('input', () => { if (this.ui.datasetSelect.value !== 'manual') this.methods.loadSelectedDataset(); });
|
304 |
+
this.ui.regularizationTypeSelect.addEventListener('change', () => this.methods.toggleRegularizationRateInput());
|
305 |
+
},
|
306 |
+
|
307 |
+
initializeApplicationState() {
|
308 |
+
this.methods.setActiveClass(0);
|
309 |
+
this.methods.addHiddenLayerUI(8, 'relu'); // Default architecture
|
310 |
+
this.methods.addHiddenLayerUI(4, 'relu');
|
311 |
+
this.methods.toggleRegularizationRateInput();
|
312 |
+
this.methods.drawAll();
|
313 |
+
this.methods.drawTrainingPlot();
|
314 |
+
this.methods.updateTrainingParamsDisplay();
|
315 |
+
this.methods.updateButtonStates();
|
316 |
+
this.methods.renderConfusionMatrix(null);
|
317 |
+
this.methods.updateStatus('Ready. Click canvas to add points or load a dataset.', 'info');
|
318 |
+
},
|
319 |
+
|
320 |
+
// --- METHODS (Logic & Handlers) ---
|
321 |
+
methods: {
|
322 |
+
// --- UI & State Management ---
|
323 |
+
resizePlotCanvas() {
|
324 |
+
const dpr = window.devicePixelRatio || 1;
|
325 |
+
const rect = App.ui.canvas.parentElement.getBoundingClientRect();
|
326 |
+
if (rect.width === 0) return;
|
327 |
+
|
328 |
+
App.ui.canvas.width = rect.width * dpr;
|
329 |
+
const newHeight = Math.min(rect.width * 0.85, Math.max(300, window.innerHeight * 0.5));
|
330 |
+
App.ui.canvas.height = newHeight * dpr;
|
331 |
+
|
332 |
+
App.ui.ctx.scale(dpr, dpr);
|
333 |
+
App.ui.canvas.style.width = `${rect.width}px`;
|
334 |
+
App.ui.canvas.style.height = `${newHeight}px`;
|
335 |
+
|
336 |
+
App.ui.canvasWidth = rect.width;
|
337 |
+
App.ui.canvasHeight = newHeight;
|
338 |
+
this.drawAll();
|
339 |
+
},
|
340 |
+
|
341 |
+
setActiveClass(classNum) {
|
342 |
+
App.state.currentClass = classNum;
|
343 |
+
App.ui.class0Button.classList.toggle('button-active-blue', classNum === 0);
|
344 |
+
App.ui.class1Button.classList.toggle('button-active-red', classNum === 1);
|
345 |
+
App.ui.plotCanvas.className = App.ui.plotCanvas.className.replace(/cursor-class-\d/, '');
|
346 |
+
App.ui.plotCanvas.classList.add(`cursor-class-${classNum}`);
|
347 |
+
},
|
348 |
+
|
349 |
+
toggleRegularizationRateInput() {
|
350 |
+
App.ui.regularizationRateContainer.classList.toggle('hidden', App.ui.regularizationTypeSelect.value === 'none');
|
351 |
+
},
|
352 |
+
|
353 |
+
updateStatus(message, type = 'info') {
|
354 |
+
const icon = App.config.statusIcons[type] || '';
|
355 |
+
const loadingClass = type === 'loading' ? 'loading-icon' : '';
|
356 |
+
const iconHtml = `<span class="status-icon ${loadingClass}">${icon}</span>`;
|
357 |
+
App.ui.statusMessage.innerHTML = `${iconHtml}<span>${message}</span>`;
|
358 |
+
|
359 |
+
const typeToColor = {
|
360 |
+
info: 'text-slate-600 bg-slate-50 border-slate-200',
|
361 |
+
success: 'text-green-700 bg-green-50 border-green-200 font-semibold',
|
362 |
+
warning: 'text-amber-700 bg-amber-50 border-amber-200 font-semibold',
|
363 |
+
error: 'text-red-700 bg-red-50 border-red-200 font-semibold',
|
364 |
+
loading: 'text-blue-700 bg-blue-50 border-blue-200'
|
365 |
+
};
|
366 |
+
App.ui.statusMessage.className = `text-sm p-2 rounded-lg min-h-[44px] whitespace-pre-line leading-snug flex items-center justify-center border ${typeToColor[type] || ''}`;
|
367 |
+
},
|
368 |
+
|
369 |
+
setTrainingState(training) {
|
370 |
+
App.state.isTraining = training;
|
371 |
+
App.ui.trainButton.classList.toggle('hidden', training);
|
372 |
+
App.ui.stopButton.classList.toggle('hidden', !training);
|
373 |
+
this.updateButtonStates();
|
374 |
+
},
|
375 |
+
|
376 |
+
updateButtonStates() {
|
377 |
+
const hasData = App.state.dataPoints.length > 0;
|
378 |
+
const hasModel = App.state.model != null;
|
379 |
+
App.ui.trainButton.disabled = App.state.isTraining || !hasData;
|
380 |
+
App.ui.clearButton.disabled = App.state.isTraining;
|
381 |
+
App.ui.resetWeightsButton.disabled = App.state.isTraining || !hasModel;
|
382 |
+
|
383 |
+
if (hasModel && App.state.trainingHistory.loss.length > 0 && !App.state.isTraining) {
|
384 |
+
App.ui.trainButtonText.textContent = 'Retrain';
|
385 |
+
} else {
|
386 |
+
App.ui.trainButtonText.textContent = 'Train';
|
387 |
+
}
|
388 |
+
},
|
389 |
+
|
390 |
+
// --- Architecture UI ---
|
391 |
+
addHiddenLayerUI(defaultNeurons = 8, defaultActivation = 'relu') {
|
392 |
+
if (App.state.hiddenLayerConfigs.length >= 8) {
|
393 |
+
this.updateStatus("Max 8 hidden layers reached.", 'warning');
|
394 |
+
return;
|
395 |
+
}
|
396 |
+
const layerIndex = App.state.hiddenLayerConfigs.length;
|
397 |
+
App.state.hiddenLayerConfigs.push({ neurons: parseInt(defaultNeurons), activation: defaultActivation });
|
398 |
+
|
399 |
+
const layerDiv = document.createElement('div');
|
400 |
+
layerDiv.className = 'layer-config-item bg-slate-50 border border-slate-200 p-2 rounded-lg flex items-center justify-between gap-2 hover:bg-slate-100 transition-colors';
|
401 |
+
layerDiv.dataset.index = layerIndex;
|
402 |
+
|
403 |
+
layerDiv.innerHTML = `
|
404 |
+
<div class="flex-grow flex items-center gap-x-3">
|
405 |
+
<div class="flex items-center">
|
406 |
+
<label for="neurons_${layerIndex}" class="text-sm font-medium text-slate-600 mr-2 whitespace-nowrap">L${layerIndex + 1} Units:</label>
|
407 |
+
<input type="number" id="neurons_${layerIndex}" value="${defaultNeurons}" min="1" max="64" step="1" class="w-16 px-2 py-1 border border-slate-300 rounded-md shadow-sm text-sm focus:ring-blue-500 focus:border-blue-500">
|
408 |
+
</div>
|
409 |
+
<div class="flex items-center">
|
410 |
+
<label for="activation_${layerIndex}" class="text-sm font-medium text-slate-600 mr-2">Act:</label>
|
411 |
+
<select id="activation_${layerIndex}" class="flex-1 min-w-[90px] px-2 py-1 border bg-white border-slate-300 rounded-md shadow-sm text-sm focus:ring-blue-500 focus:border-blue-500">
|
412 |
+
${['relu', 'sigmoid', 'tanh', 'leakyRelu'].map(act => `<option value="${act}" ${act === defaultActivation ? 'selected' : ''}>${act.charAt(0).toUpperCase() + act.slice(1)}</option>`).join('')}
|
413 |
+
</select>
|
414 |
+
</div>
|
415 |
+
</div>
|
416 |
+
<button title="Remove layer" class="remove-btn flex-shrink-0 p-1.5 text-red-500 hover:text-red-700 rounded-full hover:bg-red-100 focus:outline-none focus:ring-2 focus:ring-red-500 focus:ring-offset-1 transition-all duration-150">
|
417 |
+
<span class="text-base font-bold">⛔</span>
|
418 |
+
</button>
|
419 |
+
`;
|
420 |
+
|
421 |
+
App.ui.hiddenLayersConfigContainer.appendChild(layerDiv);
|
422 |
+
|
423 |
+
layerDiv.querySelector(`#neurons_${layerIndex}`).onchange = (e) => {
|
424 |
+
const value = parseInt(e.target.value);
|
425 |
+
App.state.hiddenLayerConfigs[layerIndex].neurons = !isNaN(value) && value > 0 ? value : 1;
|
426 |
+
e.target.value = App.state.hiddenLayerConfigs[layerIndex].neurons;
|
427 |
+
this.drawNetworkVisualization();
|
428 |
+
};
|
429 |
+
layerDiv.querySelector(`#activation_${layerIndex}`).onchange = (e) => {
|
430 |
+
App.state.hiddenLayerConfigs[layerIndex].activation = e.target.value;
|
431 |
+
this.drawNetworkVisualization();
|
432 |
+
};
|
433 |
+
layerDiv.querySelector('.remove-btn').onclick = () => {
|
434 |
+
App.state.hiddenLayerConfigs.splice(layerIndex, 1);
|
435 |
+
this.redrawLayerConfigs();
|
436 |
+
};
|
437 |
+
this.drawNetworkVisualization();
|
438 |
+
},
|
439 |
+
|
440 |
+
redrawLayerConfigs() {
|
441 |
+
const configs = [...App.state.hiddenLayerConfigs];
|
442 |
+
App.ui.hiddenLayersConfigContainer.innerHTML = '';
|
443 |
+
App.state.hiddenLayerConfigs = [];
|
444 |
+
configs.forEach(config => this.addHiddenLayerUI(config.neurons, config.activation));
|
445 |
+
},
|
446 |
+
|
447 |
+
// --- Data Handling ---
|
448 |
+
handleCanvasClick(event) {
|
449 |
+
if (!App.ui.canvas) return;
|
450 |
+
const rect = App.ui.canvas.getBoundingClientRect();
|
451 |
+
const x = event.clientX - rect.left;
|
452 |
+
const y = event.clientY - rect.top;
|
453 |
+
|
454 |
+
App.state.dataPoints.push({ x, y, normX: x / App.ui.canvasWidth, normY: y / App.ui.canvasHeight, label: App.state.currentClass });
|
455 |
+
this.drawPoints();
|
456 |
+
this.updateStatus(`Added Class ${App.state.currentClass} point. Total: ${App.state.dataPoints.length}.`, 'info');
|
457 |
+
App.ui.datasetSelect.value = "manual";
|
458 |
+
this.updateButtonStates();
|
459 |
+
},
|
460 |
+
|
461 |
+
loadSelectedDataset() {
|
462 |
+
if (App.state.model) { App.state.model.dispose(); App.state.model = null; }
|
463 |
+
tf.disposeVariables();
|
464 |
+
App.state.dataPoints = [];
|
465 |
+
const datasetName = App.ui.datasetSelect.value;
|
466 |
+
if (datasetName === 'manual') {
|
467 |
+
this.drawAll();
|
468 |
+
this.updateButtonStates();
|
469 |
+
return;
|
470 |
+
}
|
471 |
+
const noise = parseFloat(App.ui.dataNoise.value) || 0;
|
472 |
+
const nSamples = 150;
|
473 |
+
|
474 |
+
const generators = {
|
475 |
+
two_moons: (n, noise) => {
|
476 |
+
const n_per_moon = Math.floor(n / 2);
|
477 |
+
const radius = 0.3;
|
478 |
+
for (let i = 0; i < n_per_moon; i++) {
|
479 |
+
const angle = (i / n_per_moon) * Math.PI;
|
480 |
+
// First moon, shifted left and up
|
481 |
+
this.addDataPoint(
|
482 |
+
0.5 + radius * Math.cos(angle) - 0.125,
|
483 |
+
0.5 + radius * Math.sin(angle) + 0.1,
|
484 |
+
0, noise
|
485 |
+
);
|
486 |
+
// Second moon, shifted right and down
|
487 |
+
this.addDataPoint(
|
488 |
+
0.5 + radius * Math.cos(angle + Math.PI) + 0.125,
|
489 |
+
0.5 + radius * Math.sin(angle + Math.PI) - 0.1,
|
490 |
+
1, noise
|
491 |
+
);
|
492 |
+
}
|
493 |
+
},
|
494 |
+
circles: (n, noise) => { for (let i=0; i<n; i++) { const r=Math.random(), a=Math.random()*2*Math.PI, l=r<0.5?0:1, rs=l===0?0.2:0.4; this.addDataPoint(0.5+rs*Math.cos(a), 0.5+rs*Math.sin(a), l, noise); } },
|
495 |
+
xor: (n, noise) => { const q=Math.floor(n/4), s=0.3; for (let i=0; i<q; i++) { this.addDataPoint(0.5-s, 0.5-s, 0, noise*2); this.addDataPoint(0.5+s, 0.5-s, 1, noise*2); this.addDataPoint(0.5-s, 0.5+s, 1, noise*2); this.addDataPoint(0.5+s, 0.5+s, 0, noise*2); } },
|
496 |
+
spiral: (n, noise) => { const pts=Math.floor(n/2); for (let i=0; i<pts; i++) { const a=i/20*Math.PI, r=0.05+i/pts*0.4; this.addDataPoint(0.5+r*Math.cos(a), 0.5+r*Math.sin(a), 0, noise*0.5); this.addDataPoint(0.5+r*Math.cos(a+Math.PI), 0.5+r*Math.sin(a+Math.PI), 1, noise*0.5); } }
|
497 |
+
};
|
498 |
+
|
499 |
+
generators[datasetName](nSamples, noise);
|
500 |
+
App.state.dataPoints.forEach(p => { p.x = p.normX * App.ui.canvasWidth; p.y = p.normY * App.ui.canvasHeight; });
|
501 |
+
this.drawAll();
|
502 |
+
this.updateStatus(`Loaded '${datasetName}' dataset. Noise: ${noise}.`, 'info');
|
503 |
+
|
504 |
+
App.state.trainingHistory = { loss: [], acc: [] };
|
505 |
+
this.drawTrainingPlot();
|
506 |
+
this.renderConfusionMatrix(null);
|
507 |
+
this.updateButtonStates();
|
508 |
+
this.updateTrainingParamsDisplay();
|
509 |
+
},
|
510 |
+
|
511 |
+
addDataPoint(normX, normY, label, noise) {
|
512 |
+
App.state.dataPoints.push({
|
513 |
+
x: 0, y: 0,
|
514 |
+
normX: normX + (Math.random() - 0.5) * noise,
|
515 |
+
normY: normY + (Math.random() - 0.5) * noise,
|
516 |
+
label: label
|
517 |
+
});
|
518 |
+
},
|
519 |
+
|
520 |
+
// --- Drawing & Visualization ---
|
521 |
+
drawAll(boundaryGrid = null) {
|
522 |
+
if (!App.ui.ctx || !App.ui.canvasWidth || !App.ui.canvasHeight) return;
|
523 |
+
App.ui.ctx.clearRect(0, 0, App.ui.canvasWidth, App.ui.canvasHeight);
|
524 |
+
if (boundaryGrid) {
|
525 |
+
const { grid, resolution } = boundaryGrid;
|
526 |
+
for (let i = 0; i < grid.length; i++) {
|
527 |
+
for (let j = 0; j < grid[i].length; j++) {
|
528 |
+
App.ui.ctx.fillStyle = App.config.classColors[grid[i][j]].boundary;
|
529 |
+
App.ui.ctx.fillRect(j * resolution, i * resolution, resolution, resolution);
|
530 |
+
}
|
531 |
+
}
|
532 |
+
}
|
533 |
+
App.state.dataPoints.forEach(point => {
|
534 |
+
App.ui.ctx.beginPath();
|
535 |
+
App.ui.ctx.arc(point.x, point.y, App.config.pointRadius, 0, 2 * Math.PI);
|
536 |
+
App.ui.ctx.fillStyle = App.config.classColors[point.label].point;
|
537 |
+
App.ui.ctx.fill();
|
538 |
+
App.ui.ctx.strokeStyle = 'rgba(255,255,255,0.7)';
|
539 |
+
App.ui.ctx.lineWidth = 1.5;
|
540 |
+
App.ui.ctx.stroke();
|
541 |
+
});
|
542 |
+
},
|
543 |
+
|
544 |
+
drawPoints() { this.drawAll(); },
|
545 |
+
|
546 |
+
drawNetworkVisualization() {
|
547 |
+
const container = App.ui.networkVisualization;
|
548 |
+
container.innerHTML = '';
|
549 |
+
const allLayers = [
|
550 |
+
{ type: 'input', neurons: 2, activation: 'Input' },
|
551 |
+
...App.state.hiddenLayerConfigs.map((cfg, i) => ({ type: `hidden-${i % 5}`, neurons: cfg.neurons, activation: cfg.activation })),
|
552 |
+
{ type: 'output', neurons: 1, activation: 'Sigmoid' }
|
553 |
+
];
|
554 |
+
|
555 |
+
const svg = document.createElementNS("http://www.w3.org/2000/svg", "svg");
|
556 |
+
const rect = container.getBoundingClientRect();
|
557 |
+
if(rect.width === 0 || rect.height === 0) return;
|
558 |
+
svg.setAttribute('viewBox', `0 0 ${rect.width} ${rect.height}`);
|
559 |
+
|
560 |
+
const margin = { top: 25, right: 15, bottom: 20, left: 15 };
|
561 |
+
const width = rect.width - margin.left - margin.right;
|
562 |
+
const height = rect.height - margin.top - margin.bottom;
|
563 |
+
const layerSpacing = allLayers.length > 1 ? width / (allLayers.length - 1) : width;
|
564 |
+
|
565 |
+
// Connections
|
566 |
+
for (let i = 0; i < allLayers.length - 1; i++) {
|
567 |
+
const x1 = margin.left + i * layerSpacing;
|
568 |
+
const x2 = margin.left + (i + 1) * layerSpacing;
|
569 |
+
const maxNeurons = 8;
|
570 |
+
const currentNeurons = Math.min(allLayers[i].neurons, maxNeurons);
|
571 |
+
const nextNeurons = Math.min(allLayers[i+1].neurons, maxNeurons);
|
572 |
+
for (let j = 0; j < currentNeurons; j++) {
|
573 |
+
const y1 = margin.top + height * ((j + 0.5) / currentNeurons);
|
574 |
+
for (let k = 0; k < nextNeurons; k++) {
|
575 |
+
const y2 = margin.top + height * ((k + 0.5) / nextNeurons);
|
576 |
+
const line = document.createElementNS("http://www.w3.org/2000/svg", "line");
|
577 |
+
line.setAttribute('x1', x1); line.setAttribute('y1', y1);
|
578 |
+
line.setAttribute('x2', x2); line.setAttribute('y2', y2);
|
579 |
+
line.setAttribute('class', 'connection');
|
580 |
+
svg.appendChild(line);
|
581 |
+
}
|
582 |
+
}
|
583 |
+
}
|
584 |
+
// Neurons and Labels
|
585 |
+
allLayers.forEach((layer, i) => {
|
586 |
+
const x = margin.left + i * layerSpacing;
|
587 |
+
const maxNeurons = 8;
|
588 |
+
const displayNeurons = Math.min(layer.neurons, maxNeurons);
|
589 |
+
const neuronRadius = Math.max(3, Math.min(7, height / (displayNeurons * 2.5)));
|
590 |
+
for (let j = 0; j < displayNeurons; j++) {
|
591 |
+
const y = margin.top + height * ((j + 0.5) / displayNeurons);
|
592 |
+
const circle = document.createElementNS("http://www.w3.org/2000/svg", "circle");
|
593 |
+
circle.setAttribute('cx', x); circle.setAttribute('cy', y);
|
594 |
+
circle.setAttribute('r', neuronRadius);
|
595 |
+
circle.setAttribute('class', `neuron ${layer.type}`);
|
596 |
+
svg.appendChild(circle);
|
597 |
+
}
|
598 |
+
const labelText = layer.activation.charAt(0).toUpperCase() + layer.activation.slice(1);
|
599 |
+
const textLabel = document.createElementNS("http://www.w3.org/2000/svg", "text");
|
600 |
+
textLabel.setAttribute('x', x); textLabel.setAttribute('y', margin.top - 8);
|
601 |
+
textLabel.setAttribute('class', 'layer-label');
|
602 |
+
textLabel.textContent = labelText;
|
603 |
+
svg.appendChild(textLabel);
|
604 |
+
|
605 |
+
const countLabel = document.createElementNS("http://www.w3.org/2000/svg", "text");
|
606 |
+
countLabel.setAttribute('x', x); countLabel.setAttribute('y', margin.top + height + 15);
|
607 |
+
countLabel.setAttribute('class', 'neuron-count-label');
|
608 |
+
countLabel.textContent = `${layer.neurons} N`;
|
609 |
+
svg.appendChild(countLabel);
|
610 |
+
});
|
611 |
+
container.appendChild(svg);
|
612 |
+
},
|
613 |
+
|
614 |
+
drawTrainingPlot() {
|
615 |
+
const canvas = App.ui.trainingPlotCanvas;
|
616 |
+
const ctx = canvas.getContext('2d');
|
617 |
+
const dpr = window.devicePixelRatio || 1;
|
618 |
+
const rect = canvas.getBoundingClientRect();
|
619 |
+
if (rect.width === 0 || rect.height === 0) return;
|
620 |
+
canvas.width = rect.width * dpr;
|
621 |
+
canvas.height = rect.height * dpr;
|
622 |
+
ctx.scale(dpr, dpr);
|
623 |
+
const { width, height } = rect;
|
624 |
+
const padding = {top: 20, right: 15, bottom: 20, left: 30};
|
625 |
+
|
626 |
+
ctx.fillStyle = '#f8fafc';
|
627 |
+
ctx.fillRect(0,0,width,height);
|
628 |
+
|
629 |
+
if (App.state.trainingHistory.loss.length === 0) {
|
630 |
+
ctx.fillStyle = '#64748b';
|
631 |
+
ctx.textAlign = 'center';
|
632 |
+
ctx.font = '12px Inter';
|
633 |
+
ctx.fillText('Training history will be plotted here.', width / 2, height / 2);
|
634 |
+
return;
|
635 |
+
}
|
636 |
+
|
637 |
+
ctx.beginPath();
|
638 |
+
ctx.strokeStyle = '#e2e8f0';
|
639 |
+
ctx.lineWidth = 1;
|
640 |
+
for(let i = 0; i <= 4; i++){
|
641 |
+
const y = padding.top + i * (height - padding.top - padding.bottom) / 4;
|
642 |
+
ctx.moveTo(padding.left, y);
|
643 |
+
ctx.lineTo(width-padding.right, y);
|
644 |
+
}
|
645 |
+
ctx.stroke();
|
646 |
+
|
647 |
+
ctx.font = '10px Inter';
|
648 |
+
ctx.fillStyle = '#475569';
|
649 |
+
ctx.textAlign = 'right';
|
650 |
+
for(let i = 0; i <= 4; i++){
|
651 |
+
ctx.fillText((1 - i/4).toFixed(1), padding.left - 6, padding.top + 3 + i * (height - padding.top - padding.bottom) / 4);
|
652 |
+
}
|
653 |
+
|
654 |
+
const plotData = (data, color) => {
|
655 |
+
ctx.beginPath(); ctx.strokeStyle = color; ctx.lineWidth = 2; ctx.lineJoin = 'round'; ctx.lineCap = 'round';
|
656 |
+
data.forEach((val, i) => {
|
657 |
+
const x = padding.left + (i / (Math.max(1, data.length -1))) * (width - padding.left - padding.right);
|
658 |
+
const y = (height - padding.bottom) - Math.min(Math.max(val,0.0),1.0) * (height - padding.top - padding.bottom);
|
659 |
+
if (i === 0) ctx.moveTo(x, y); else ctx.lineTo(x, y);
|
660 |
+
});
|
661 |
+
ctx.stroke();
|
662 |
+
};
|
663 |
+
plotData(App.state.trainingHistory.loss, 'rgba(239, 68, 68, 0.9)');
|
664 |
+
plotData(App.state.trainingHistory.acc, 'rgba(37, 99, 235, 0.9)');
|
665 |
+
|
666 |
+
ctx.textAlign = 'left';
|
667 |
+
ctx.font = '600 11px Inter';
|
668 |
+
ctx.fillStyle = 'rgba(239, 68, 68, 1)'; ctx.fillRect(padding.left + 5, 5, 10, 3);
|
669 |
+
ctx.fillStyle = '#374151'; ctx.fillText('Loss', padding.left + 20, 10);
|
670 |
+
ctx.fillStyle = 'rgba(37, 99, 235, 1)'; ctx.fillRect(padding.left + 75, 5, 10, 3);
|
671 |
+
ctx.fillStyle = '#374151'; ctx.fillText('Accuracy', padding.left + 90, 10);
|
672 |
+
},
|
673 |
+
|
674 |
+
updateTrainingParamsDisplay() {
|
675 |
+
const container = App.ui.trainingParamsDisplay;
|
676 |
+
if (!App.state.model) {
|
677 |
+
container.innerHTML = `<div class="text-slate-500">Train a model to see parameters.</div>`;
|
678 |
+
return;
|
679 |
+
}
|
680 |
+
const lr = parseFloat(App.ui.learningRate.value);
|
681 |
+
const regType = App.ui.regularizationTypeSelect.value;
|
682 |
+
const regRate = parseFloat(App.ui.regularizationRateInput.value) || 0;
|
683 |
+
let regDesc = regType !== 'none' ? `${regType.toUpperCase()}(λ=${regRate})` : 'None';
|
684 |
+
let hiddenDesc = App.state.hiddenLayerConfigs.map(l => l.neurons).join(' → ');
|
685 |
+
|
686 |
+
container.innerHTML = `
|
687 |
+
<div class="flex flex-wrap gap-x-4 gap-y-1">
|
688 |
+
<span><span class="font-semibold text-slate-500">Opt:</span> <span class="font-medium text-slate-800">${App.ui.optimizerSelect.value.toUpperCase()}</span></span>
|
689 |
+
<span><span class="font-semibold text-slate-500">LR:</span> <span class="font-medium text-slate-800">${lr.toExponential(1)}</span></span>
|
690 |
+
<span><span class="font-semibold text-slate-500">Batch:</span> <span class="font-medium text-slate-800">${parseInt(App.ui.batchSize.value) === 0 ? 'Full' : App.ui.batchSize.value}</span></span>
|
691 |
+
<span><span class="font-semibold text-slate-500">Reg:</span> <span class="font-medium text-slate-800">${regDesc}</span></span>
|
692 |
+
</div>
|
693 |
+
<div><span class="font-semibold text-slate-500">Layers:</span> <span class="font-medium text-slate-800">2 → ${hiddenDesc || '...'} → 1</span></div>
|
694 |
+
`;
|
695 |
+
},
|
696 |
+
|
697 |
+
renderConfusionMatrix(matrix) {
|
698 |
+
const container = App.ui.confusionMatrixContainer;
|
699 |
+
if (!matrix) {
|
700 |
+
container.innerHTML = `<div class="flex items-center justify-center h-full text-sm text-center text-slate-400 bg-slate-50 rounded-lg p-2 border border-slate-200">Confusion matrix appears here after training.</div>`;
|
701 |
+
return;
|
702 |
+
}
|
703 |
+
container.innerHTML = `
|
704 |
+
<div class="h-full flex flex-col">
|
705 |
+
<h4 class="text-sm font-semibold text-center text-slate-600 mb-1.5">Confusion Matrix</h4>
|
706 |
+
<div class="grid grid-cols-2 grid-rows-2 gap-1.5 flex-grow">
|
707 |
+
<div class="cm-cell bg-blue-100 text-blue-800" title="True Negative"><span class="cm-value">${matrix.tn}</span><span class="cm-label">TN</span></div>
|
708 |
+
<div class="cm-cell bg-red-100 text-red-800" title="False Positive"><span class="cm-value">${matrix.fp}</span><span class="cm-label">FP</span></div>
|
709 |
+
<div class="cm-cell bg-red-100 text-red-800" title="False Negative"><span class="cm-value">${matrix.fn}</span><span class="cm-label">FN</span></div>
|
710 |
+
<div class="cm-cell bg-green-100 text-green-800" title="True Positive"><span class="cm-value">${matrix.tp}</span><span class="cm-label">TP</span></div>
|
711 |
+
</div>
|
712 |
+
</div>
|
713 |
+
`;
|
714 |
+
},
|
715 |
+
|
716 |
+
// --- TENSORFLOW.JS & ML ---
|
717 |
+
getSafeNumericInput(element, defaultValue, isInteger = true) {
|
718 |
+
let value = isInteger ? parseInt(element.value, 10) : parseFloat(element.value);
|
719 |
+
if (isNaN(value)) {
|
720 |
+
value = defaultValue;
|
721 |
+
element.value = defaultValue;
|
722 |
+
}
|
723 |
+
return value;
|
724 |
+
},
|
725 |
+
|
726 |
+
buildModel() {
|
727 |
+
if (App.state.model) { App.state.model.dispose(); App.state.model = null; }
|
728 |
+
tf.disposeVariables();
|
729 |
+
|
730 |
+
const learningRate = this.getSafeNumericInput(App.ui.learningRate, 0.01, false);
|
731 |
+
const regType = App.ui.regularizationTypeSelect.value;
|
732 |
+
const regRate = this.getSafeNumericInput(App.ui.regularizationRateInput, 0, false);
|
733 |
+
|
734 |
+
const kernelRegularizer = (regType !== 'none' && regRate > 0)
|
735 |
+
? tf.regularizers[regType]({[regType]: regRate})
|
736 |
+
: null;
|
737 |
+
|
738 |
+
App.state.model = tf.sequential();
|
739 |
+
const inputShape = [2];
|
740 |
+
// Add hidden layers
|
741 |
+
App.state.hiddenLayerConfigs.forEach((layerConfig, index) => {
|
742 |
+
App.state.model.add(tf.layers.dense({
|
743 |
+
units: layerConfig.neurons,
|
744 |
+
inputShape: index === 0 ? inputShape : undefined,
|
745 |
+
activation: layerConfig.activation,
|
746 |
+
kernelRegularizer
|
747 |
+
}));
|
748 |
+
});
|
749 |
+
// Add output layer
|
750 |
+
App.state.model.add(tf.layers.dense({
|
751 |
+
units: 1,
|
752 |
+
activation: 'sigmoid',
|
753 |
+
inputShape: App.state.hiddenLayerConfigs.length === 0 ? inputShape : undefined,
|
754 |
+
}));
|
755 |
+
|
756 |
+
const optimizerInstance = tf.train[App.ui.optimizerSelect.value](learningRate);
|
757 |
+
App.state.model.compile({ optimizer: optimizerInstance, loss: 'binaryCrossentropy', metrics: ['accuracy'] });
|
758 |
+
App.state.model.stopTraining = false;
|
759 |
+
},
|
760 |
+
|
761 |
+
async trainAndVisualize() {
|
762 |
+
if (App.state.isTraining) return;
|
763 |
+
const uniqueLabels = new Set(App.state.dataPoints.map(p => p.label));
|
764 |
+
if (App.state.dataPoints.length < 4 || uniqueLabels.size < 2) {
|
765 |
+
this.updateStatus('Requires at least 4 points and data from both classes.', 'error');
|
766 |
+
return;
|
767 |
+
}
|
768 |
+
|
769 |
+
this.setTrainingState(true);
|
770 |
+
this.updateStatus('Starting training...', 'loading');
|
771 |
+
this.renderConfusionMatrix(null);
|
772 |
+
await tf.nextFrame();
|
773 |
+
|
774 |
+
this.buildModel();
|
775 |
+
this.updateTrainingParamsDisplay();
|
776 |
+
|
777 |
+
const epochs = this.getSafeNumericInput(App.ui.epochs, 150, true);
|
778 |
+
let batchSize = this.getSafeNumericInput(App.ui.batchSize, 16, true);
|
779 |
+
if (batchSize === 0) batchSize = App.state.dataPoints.length;
|
780 |
+
|
781 |
+
const [xs, ys] = tf.tidy(() => {
|
782 |
+
const normalized = App.state.dataPoints.map(p => [p.normX, p.normY]);
|
783 |
+
const labels = App.state.dataPoints.map(p => p.label);
|
784 |
+
return [tf.tensor2d(normalized), tf.tensor2d(labels, [labels.length, 1])];
|
785 |
+
});
|
786 |
+
|
787 |
+
App.state.trainingHistory = { loss: [], acc: [] };
|
788 |
+
this.drawTrainingPlot();
|
789 |
+
|
790 |
+
try {
|
791 |
+
await App.state.model.fit(xs, ys, {
|
792 |
+
epochs, batchSize,
|
793 |
+
callbacks: {
|
794 |
+
onEpochEnd: async (epoch, logs) => {
|
795 |
+
if (App.state.model.stopTraining) { App.state.model.stop(); return; }
|
796 |
+
this.updateStatus(`Epoch ${epoch + 1}/${epochs} - Loss: ${logs.loss.toFixed(4)}, Acc: ${logs.acc.toFixed(4)}`, 'loading');
|
797 |
+
App.state.trainingHistory.loss.push(logs.loss);
|
798 |
+
App.state.trainingHistory.acc.push(logs.acc);
|
799 |
+
this.drawTrainingPlot();
|
800 |
+
|
801 |
+
if ((epoch + 1) % Math.max(1, Math.floor(epochs / 25)) === 0) {
|
802 |
+
const boundaryGrid = await this.generateBoundaryGrid();
|
803 |
+
this.drawAll(boundaryGrid);
|
804 |
+
}
|
805 |
+
await tf.nextFrame();
|
806 |
+
},
|
807 |
+
onTrainEnd: async () => {
|
808 |
+
const finalAcc = App.state.trainingHistory.acc.slice(-1)[0] || 0;
|
809 |
+
if (!App.state.model.stopTraining) {
|
810 |
+
this.updateStatus(`Training complete! Final Accuracy: <b>${(finalAcc*100).toFixed(2)}%</b>`, 'success');
|
811 |
+
}
|
812 |
+
const boundaryGrid = await this.generateBoundaryGrid();
|
813 |
+
this.drawAll(boundaryGrid);
|
814 |
+
|
815 |
+
const confusionMatrix = await this.calculateConfusionMatrix(xs, ys);
|
816 |
+
this.renderConfusionMatrix(confusionMatrix);
|
817 |
+
this.setTrainingState(false);
|
818 |
+
}
|
819 |
+
}
|
820 |
+
});
|
821 |
+
} catch (error) {
|
822 |
+
console.error("Training error:", error);
|
823 |
+
this.updateStatus(`Training error: ${error.message}`, 'error');
|
824 |
+
this.setTrainingState(false);
|
825 |
+
} finally {
|
826 |
+
xs.dispose();
|
827 |
+
ys.dispose();
|
828 |
+
}
|
829 |
+
},
|
830 |
+
|
831 |
+
async calculateConfusionMatrix(xs, ys) {
|
832 |
+
if (!App.state.model) return null;
|
833 |
+
return tf.tidy(() => {
|
834 |
+
const predictions = App.state.model.predict(xs).round();
|
835 |
+
const tp = ys.mul(predictions).sum().dataSync()[0];
|
836 |
+
const tn = ys.sub(1).mul(-1).mul(predictions.sub(1).mul(-1)).sum().dataSync()[0];
|
837 |
+
const fp = predictions.sub(ys).relu().sum().dataSync()[0];
|
838 |
+
const fn = ys.sub(predictions).relu().sum().dataSync()[0];
|
839 |
+
return { tp, tn, fp, fn };
|
840 |
+
});
|
841 |
+
},
|
842 |
+
|
843 |
+
async generateBoundaryGrid() {
|
844 |
+
if (!App.state.model || !App.ui.canvasWidth || !App.ui.canvasHeight) return null;
|
845 |
+
const resolution = Math.max(5, Math.floor(App.ui.canvasWidth / 80));
|
846 |
+
const numCols = Math.floor(App.ui.canvasWidth / resolution);
|
847 |
+
const numRows = Math.floor(App.ui.canvasHeight / resolution);
|
848 |
+
if (numCols <= 0 || numRows <= 0) return null;
|
849 |
+
|
850 |
+
const boundaryData = tf.tidy(() => {
|
851 |
+
const gridPoints = [];
|
852 |
+
for (let i = 0; i < numRows; i++) {
|
853 |
+
for (let j = 0; j < numCols; j++) {
|
854 |
+
gridPoints.push([(j * resolution) / App.ui.canvasWidth, (i * resolution) / App.ui.canvasHeight]);
|
855 |
+
}
|
856 |
+
}
|
857 |
+
const predsTensor = App.state.model.predict(tf.tensor2d(gridPoints));
|
858 |
+
return predsTensor.dataSync(); // Use dataSync inside tidy
|
859 |
+
});
|
860 |
+
|
861 |
+
const grid = []; let k = 0;
|
862 |
+
for (let i = 0; i < numRows; i++) {
|
863 |
+
const row = [];
|
864 |
+
for (let j = 0; j < numCols; j++) row.push(boundaryData[k++] > 0.5 ? 1 : 0);
|
865 |
+
grid.push(row);
|
866 |
+
}
|
867 |
+
return { grid, resolution };
|
868 |
+
},
|
869 |
+
|
870 |
+
stopTraining() {
|
871 |
+
if (App.state.model) {
|
872 |
+
App.state.model.stopTraining = true;
|
873 |
+
this.updateStatus("Training stopped by user.", 'warning');
|
874 |
+
}
|
875 |
+
},
|
876 |
+
|
877 |
+
resetModelWeights() {
|
878 |
+
if (App.state.isTraining) return;
|
879 |
+
this.updateStatus("Model weights reset. Ready to train again.", 'info');
|
880 |
+
this.buildModel();
|
881 |
+
this.drawAll();
|
882 |
+
App.state.trainingHistory = { loss: [], acc: [] };
|
883 |
+
this.drawTrainingPlot();
|
884 |
+
this.renderConfusionMatrix(null);
|
885 |
+
this.updateButtonStates();
|
886 |
+
},
|
887 |
+
|
888 |
+
clearAllDataAndReseed() {
|
889 |
+
if (App.state.isTraining) return;
|
890 |
+
App.state.dataPoints = [];
|
891 |
+
if(App.state.model) { App.state.model.dispose(); App.state.model = null; }
|
892 |
+
tf.disposeVariables();
|
893 |
+
App.state.hiddenLayerConfigs = [];
|
894 |
+
App.ui.hiddenLayersConfigContainer.innerHTML = '';
|
895 |
+
|
896 |
+
App.initializeApplicationState();
|
897 |
+
this.updateStatus('Cleared all data and reset configuration.', 'info');
|
898 |
+
}
|
899 |
+
}
|
900 |
+
};
|
901 |
+
|
902 |
+
// --- Entry Point ---
|
903 |
+
document.addEventListener('DOMContentLoaded', () => App.init());
|
904 |
+
</script>
|
905 |
+
</body>
|
906 |
</html>
|