<!--
  duongtruongbinh's picture
  Update index.html
  439a5cd verified
-->
<!DOCTYPE html>
<html lang="vi">
<head>
  <meta charset="UTF-8">
  <meta name="viewport" content="width=device-width, initial-scale=1.0">
  <title>MLP Interactive Visualization</title>
  <!-- Tailwind Play CDN: generates the utility classes used below at runtime in the browser. -->
  <script src="https://cdn.tailwindcss.com"></script>
  <!-- TensorFlow.js, pinned to major version 4 instead of @latest so that a future major
       release cannot silently change the tf.* APIs used by the app script. -->
  <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@4/dist/tf.min.js"></script>
  <link rel="preconnect" href="https://fonts.googleapis.com">
  <link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
  <link href="https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700;800&display=swap" rel="stylesheet">
  <style>
    body { font-family: 'Inter', sans-serif; background-color: #f1f5f9; }
    /* Active state for the class selection buttons (toggled by setActiveClass). */
    .button-active-blue {
      background-color: #2563eb !important;
      color: white !important;
      border-color: #2563eb !important;
      box-shadow: 0 4px 6px -1px rgb(0 0 0 / 0.1), 0 2px 4px -2px rgb(0 0 0 / 0.1);
    }
    .button-active-red {
      background-color: #dc2626 !important;
      color: white !important;
      border-color: #dc2626 !important;
      box-shadow: 0 4px 6px -1px rgb(0 0 0 / 0.1), 0 2px 4px -2px rgb(0 0 0 / 0.1);
    }
    /* Data-input cursor: a marker in the colour of the currently selected class. */
    #plotCanvas { touch-action: none; }
    #plotCanvas.cursor-class-0 { cursor: url('data:image/svg+xml;utf8,<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="%233b82f6" stroke="white" stroke-width="2"><circle cx="12" cy="12" r="8"/><path d="M12 2v4M12 18v4M22 12h-4M6 12H2"/></svg>') 12 12, crosshair; }
    #plotCanvas.cursor-class-1 { cursor: url('data:image/svg+xml;utf8,<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="%23ef4444" stroke="white" stroke-width="2"><circle cx="12" cy="12" r="8"/><path d="M12 2v4M12 18v4M22 12h-4M6 12H2"/></svg>') 12 12, crosshair; }
    /* Scrollbar styling for the hidden-layer list (WebKit/Blink only). */
    #hiddenLayersConfigContainer::-webkit-scrollbar { width: 6px; }
    #hiddenLayersConfigContainer::-webkit-scrollbar-track { background: #e2e8f0; border-radius: 8px; }
    #hiddenLayersConfigContainer::-webkit-scrollbar-thumb { background: #94a3b8; border-radius: 8px; }
    #hiddenLayersConfigContainer::-webkit-scrollbar-thumb:hover { background: #64748b; }
    /* Network visualization SVG styles. */
    .neuron {
      stroke-width: 1.5;
      transition: stroke-width 0.2s ease-in-out;
    }
    .neuron:hover {
      stroke-width: 4;
    }
    .neuron.input { fill: #60a5fa; stroke: #2563eb; }
    .neuron.hidden-0 { fill: #818cf8; stroke: #4f46e5; }
    .neuron.hidden-1 { fill: #a78bfa; stroke: #7c3aed; }
    .neuron.hidden-2 { fill: #c084fc; stroke: #9333ea; }
    .neuron.hidden-3 { fill: #e879f9; stroke: #c026d3; }
    /* The diagram may tag deeper layers "hidden-4" or "hidden-other"; both share this colour
       so no neuron falls back to the SVG default (black) fill. */
    .neuron.hidden-4,
    .neuron.hidden-other { fill: #f472b6; stroke: #db2777; }
    .neuron.output { fill: #f87171; stroke: #dc2626; }
    .connection {
      stroke: #cbd5e1;
      stroke-width: 0.75;
      transition: stroke-opacity 0.2s;
    }
    .layer-label {
      font-size: 11px;
      font-weight: 600;
      fill: #475569;
      text-anchor: middle;
    }
    .neuron-count-label {
      font-size: 10px;
      font-weight: 500;
      fill: #64748b;
      text-anchor: middle;
    }
    /* Utility icon styles. */
    .title-icon { margin-left: 0.6rem; font-size: 1.1rem; }
    .action-icon { margin-right: 0.35rem; }
    .status-icon { margin-right: 0.5rem; flex-shrink: 0; font-size: 1.1rem; }
    .loading-icon { animation: spin 1.5s linear infinite; display: inline-block; }
    @keyframes spin { from { transform: rotate(0deg); } to { transform: rotate(360deg); } }
    /* Confusion matrix cells. */
    .cm-cell {
      display: flex;
      align-items: center;
      justify-content: center;
      flex-direction: column;
      line-height: 1.2;
      padding: 0.5rem;
      border-radius: 0.25rem;
      transition: all 0.2s ease;
    }
    .cm-value { font-size: 1.125rem; font-weight: 800; }
    .cm-label { font-size: 0.65rem; text-transform: uppercase; letter-spacing: 0.05em; font-weight: 600; }
  </style>
</head>
<body class="text-slate-800">
<div class="container mx-auto p-2 sm:p-4 max-w-full">
  <header class="mb-4 text-center">
    <h1 class="text-3xl sm:text-4xl font-extrabold text-blue-600">MLP Interactive Visualization</h1>
    <p class="text-md text-slate-600 mt-1">Build, train, and visualize a Multi-Layer Perceptron.</p>
  </header>
  <!-- Main layout: control panel and visualization area, side by side on large screens -->
  <main class="flex flex-col lg:flex-row gap-2.5">
    <!-- Control Panel -->
    <div class="lg:w-6/12 bg-white p-3.5 rounded-2xl shadow-xl border border-slate-200 space-y-5">
      <!-- Data & Training Section -->
      <section class="space-y-4">
        <h2 class="text-xl font-bold text-slate-700 border-b border-slate-200 pb-2 flex items-center">
          1. Data &amp; Training <span class="title-icon" aria-hidden="true">⚙️</span>
        </h2>
        <div class="grid grid-cols-2 sm:grid-cols-4 gap-x-4 gap-y-3">
          <div class="col-span-2 sm:col-span-4">
            <!-- A group caption, not a <label>: it names a pair of buttons rather than one form control. -->
            <span id="dataClassLabel" class="text-sm font-medium text-slate-700 block mb-1.5">Data Input Class:</span>
            <div class="flex rounded-lg shadow-sm" role="group" aria-labelledby="dataClassLabel">
              <button type="button" id="class0Button" class="flex-1 px-3 py-2 border border-slate-300 rounded-l-lg bg-white text-base font-semibold text-slate-700 hover:bg-slate-50 focus:z-10 focus:outline-none focus:ring-2 focus:ring-blue-500 focus:border-blue-500 transition-all duration-150" title="Select Class 0 for the next data points">Class 0 (Blue)</button>
              <button type="button" id="class1Button" class="flex-1 px-3 py-2 border-t border-b border-r border-slate-300 rounded-r-lg bg-white text-base font-semibold text-slate-700 hover:bg-slate-50 focus:z-10 focus:outline-none focus:ring-2 focus:ring-red-500 focus:border-red-500 transition-all duration-150" title="Select Class 1 for the next data points">Class 1 (Red)</button>
            </div>
          </div>
          <!-- Dataset and hyper-parameter inputs -->
          <div>
            <label for="datasetSelect" class="text-sm font-medium text-slate-700 block mb-1">Load Dataset:</label>
            <select id="datasetSelect" class="mt-0 block w-full px-2 py-1.5 border border-slate-300 bg-white rounded-lg shadow-sm focus:outline-none focus:ring-2 focus:ring-blue-500 focus:border-blue-500 text-sm" title="Choose a preset dataset.">
              <option value="manual">Manual Input</option>
              <option value="two_moons">Two Moons</option>
              <option value="circles">Concentric Circles</option>
              <option value="xor">XOR</option>
              <option value="spiral">Spiral</option>
            </select>
          </div>
          <div>
            <label for="dataNoise" class="text-sm font-medium text-slate-700 block mb-1">Data Noise:</label>
            <input type="number" id="dataNoise" value="0.05" min="0" max="0.5" step="0.01" class="mt-0 block w-full px-2 py-1.5 border border-slate-300 rounded-lg shadow-sm focus:outline-none focus:ring-2 focus:ring-blue-500 focus:border-blue-500 text-sm" title="Add random noise to the data points.">
          </div>
          <div>
            <label for="learningRate" class="text-sm font-medium text-slate-700 block mb-1">Learning Rate:</label>
            <!-- The step grid starts at min, so (value minus min) must be a multiple of step;
                 min 0.001 keeps the default 0.01 on the spinner grid. -->
            <input type="number" id="learningRate" value="0.01" step="0.001" min="0.001" class="mt-0 block w-full px-2 py-1.5 border border-slate-300 rounded-lg shadow-sm focus:outline-none focus:ring-2 focus:ring-blue-500 focus:border-blue-500 text-sm" title="How fast the model learns.">
          </div>
          <div>
            <label for="epochs" class="text-sm font-medium text-slate-700 block mb-1">Epochs:</label>
            <!-- min 10 aligns the step-10 grid with the default 150 (min 1 made the spinner produce 141/151). -->
            <input type="number" id="epochs" value="150" step="10" min="10" class="mt-0 block w-full px-2 py-1.5 border border-slate-300 rounded-lg shadow-sm focus:outline-none focus:ring-2 focus:ring-blue-500 focus:border-blue-500 text-sm" title="Number of training iterations.">
          </div>
          <div>
            <label for="optimizerSelect" class="text-sm font-medium text-slate-700 block mb-1">Optimizer:</label>
            <select id="optimizerSelect" class="mt-0 block w-full px-2 py-1.5 border border-slate-300 bg-white rounded-lg shadow-sm focus:outline-none focus:ring-2 focus:ring-blue-500 focus:border-blue-500 text-sm" title="Algorithm to update weights.">
              <option value="adam">Adam</option>
              <option value="sgd">SGD</option>
              <option value="rmsprop">RMSprop</option>
            </select>
          </div>
          <div>
            <label for="batchSize" class="text-sm font-medium text-slate-700 block mb-1">Batch Size:</label>
            <input type="number" id="batchSize" value="16" min="0" step="4" class="mt-0 block w-full px-2 py-1.5 border border-slate-300 rounded-lg shadow-sm focus:outline-none focus:ring-2 focus:ring-blue-500 focus:border-blue-500 text-sm" title="Samples per weight update. 0=Full.">
          </div>
          <div>
            <label for="regularizationTypeSelect" class="text-sm font-medium text-slate-700 block mb-1">Regularization:</label>
            <select id="regularizationTypeSelect" class="mt-0 block w-full px-2 py-1.5 border border-slate-300 bg-white rounded-lg shadow-sm focus:outline-none focus:ring-2 focus:ring-blue-500 focus:border-blue-500 text-sm" title="Technique to prevent overfitting.">
              <option value="none">None</option>
              <option value="l1">L1</option>
              <option value="l2">L2</option>
            </select>
          </div>
          <div id="regularizationRateContainer" class="hidden">
            <label for="regularizationRateInput" class="text-sm font-medium text-slate-700 block mb-1">Reg. Rate (λ):</label>
            <input type="number" id="regularizationRateInput" value="0.001" step="0.0001" min="0" class="mt-0 block w-full px-2 py-1.5 border border-slate-300 rounded-lg shadow-sm focus:outline-none focus:ring-2 focus:ring-blue-500 focus:border-blue-500 text-sm" title="Strength of the L1/L2 penalty.">
          </div>
        </div>
      </section>
      <!-- MLP Architecture Section -->
      <section class="space-y-2">
        <h2 class="text-xl font-bold text-slate-700 border-b border-slate-200 pb-2 flex items-center">
          2. MLP Architecture <span class="title-icon" aria-hidden="true">🧠</span>
        </h2>
        <div id="networkVisualization" class="bg-slate-50 rounded-lg p-2 min-h-[120px] border border-slate-200"></div>
        <h3 class="text-sm font-medium text-slate-700 pt-2">Hidden Layers:</h3>
        <div id="hiddenLayersConfigContainer" class="grid grid-cols-1 md:grid-cols-2 gap-2 max-h-48 overflow-y-auto pr-1.5"></div>
        <button type="button" id="addHiddenLayerButton" class="mt-1 w-full flex justify-center items-center py-2 px-3 border-2 border-dashed border-blue-400 rounded-lg shadow-sm text-sm font-semibold text-blue-600 bg-blue-50 hover:bg-blue-100 hover:border-blue-500 focus:outline-none focus:ring-2 focus:ring-offset-2 focus:ring-blue-500 transition-all duration-150" title="Add a new hidden layer.">
          <span class="action-icon text-base" aria-hidden="true">➕</span> Add Layer
        </button>
      </section>
      <!-- Actions Section -->
      <section class="space-y-2">
        <h2 class="text-xl font-bold text-slate-700 border-b border-slate-200 pb-2 flex items-center">
          3. Actions <span class="title-icon" aria-hidden="true">⚡️</span>
        </h2>
        <div class="grid grid-cols-2 sm:grid-cols-4 gap-2.5">
          <button type="button" id="trainButton" class="col-span-2 flex justify-center items-center py-2.5 px-3 border border-transparent rounded-lg shadow-md text-base font-bold text-white bg-green-600 hover:bg-green-700 focus:outline-none focus:ring-2 focus:ring-offset-2 focus:ring-green-500 transition-all duration-150 disabled:opacity-50 disabled:cursor-not-allowed disabled:bg-green-400" title="Start or retrain the model.">
            <span class="action-icon" aria-hidden="true">▶️</span> <span id="trainButtonText">Train</span>
          </button>
          <button type="button" id="stopButton" class="hidden col-span-2 flex justify-center items-center py-2.5 px-3 border border-transparent rounded-lg shadow-md text-base font-bold text-white bg-red-600 hover:bg-red-700 focus:outline-none focus:ring-2 focus:ring-offset-2 focus:ring-red-500 transition-all" title="Stop the current training process.">
            <span class="action-icon" aria-hidden="true">⏹️</span> <span>Stop</span>
          </button>
          <button type="button" id="resetWeightsButton" class="flex justify-center items-center py-2.5 px-3 border border-slate-300 rounded-lg shadow-sm text-base font-semibold text-slate-700 bg-white hover:bg-slate-50 focus:outline-none focus:ring-2 focus:ring-offset-2 focus:ring-indigo-500 transition-all duration-150 disabled:opacity-50 disabled:cursor-not-allowed" title="Re-initialize the model's weights.">
            <span class="action-icon" aria-hidden="true">🔄</span> <span>Reset</span>
          </button>
          <button type="button" id="clearButton" class="flex justify-center items-center py-2.5 px-3 border border-slate-300 rounded-lg shadow-sm text-base font-semibold text-slate-700 bg-white hover:bg-slate-50 focus:outline-none focus:ring-2 focus:ring-offset-2 focus:ring-indigo-500 transition-all duration-150 disabled:opacity-50 disabled:cursor-not-allowed" title="Delete all data and reset settings.">
            <span class="action-icon" aria-hidden="true">🗑️</span> <span>Clear All</span>
          </button>
        </div>
      </section>
    </div>
    <!-- Visualization Area -->
    <div class="lg:w-6/12 flex flex-col gap-2.5">
      <div class="bg-white p-2 rounded-2xl shadow-xl border border-slate-200">
        <canvas id="plotCanvas" class="border border-slate-200 rounded-xl w-full" role="img" aria-label="Data plot. Click to add a point of the selected class."></canvas>
      </div>
      <div class="bg-white p-3.5 rounded-2xl shadow-xl border border-slate-200 flex flex-col flex-grow">
        <h2 class="text-xl font-bold text-slate-700 mb-2 border-b border-slate-200 pb-2 flex items-center">
          Training Status <span class="title-icon" aria-hidden="true">📊</span>
        </h2>
        <div id="trainingParamsDisplay" class="text-xs mb-2 space-y-1"></div>
        <!-- Live region: status text injected by script is announced by screen readers. -->
        <div id="statusMessage" role="status" class="text-sm text-slate-600 p-2 bg-slate-50 rounded-lg min-h-[44px] mb-3 whitespace-pre-line leading-snug flex items-center justify-center border border-slate-200"></div>
        <div class="flex-grow flex flex-col sm:flex-row gap-4">
          <div class="w-full sm:w-2/3 relative min-h-[150px]">
            <canvas id="trainingPlotCanvas" class="w-full h-full" role="img" aria-label="Training loss and accuracy history"></canvas>
          </div>
          <div id="confusionMatrixContainer" class="w-full sm:w-1/3">
            <!-- Confusion Matrix will be rendered here -->
          </div>
        </div>
      </div>
    </div>
  </main>
</div>
<script>
// --- App Namespace ---
// Encapsulate the entire application in a single object to avoid polluting the global namespace.
// This improves organization and prevents potential conflicts with other scripts.
const App = {
// --- STATE & CONFIG ---
state: {
model: null,
dataPoints: [],
currentClass: 0,
hiddenLayerConfigs: [],
trainingHistory: { loss: [], acc: [] },
isTraining: false,
},
ui: {}, // To hold DOM element references
config: {
pointRadius: 4.5,
classColors: {
0: { point: 'rgba(59, 130, 246, 1)', boundary: 'rgba(59, 130, 246, 0.3)' },
1: { point: 'rgba(239, 68, 68, 1)', boundary: 'rgba(239, 68, 68, 0.3)' }
},
statusIcons: {
info: 'ℹ️', success: '✅', warning: '⚠️', error: '❌', loading: '⏳'
},
},
// --- INITIALIZATION ---
async init() {
this.cacheUIElements();
this.registerEventListeners();
await tf.ready();
try {
await tf.setBackend('cpu');
console.log("TensorFlow.js backend set to CPU.");
} catch (e) { console.warn("Could not set TF.js backend to CPU.", e); }
this.methods.resizePlotCanvas();
this.initializeApplicationState();
},
cacheUIElements() {
const ids = [
'class0Button', 'class1Button', 'datasetSelect', 'dataNoise', 'batchSize',
'hiddenLayersConfigContainer', 'addHiddenLayerButton', 'networkVisualization',
'optimizerSelect', 'learningRate', 'epochs', 'regularizationTypeSelect',
'regularizationRateContainer', 'regularizationRateInput', 'trainButton',
'stopButton', 'resetWeightsButton', 'clearButton', 'statusMessage',
'plotCanvas', 'trainingPlotCanvas', 'trainingParamsDisplay', 'trainButtonText',
'confusionMatrixContainer'
];
ids.forEach(id => { this.ui[id] = document.getElementById(id); });
this.ui.canvas = this.ui.plotCanvas;
this.ui.ctx = this.ui.canvas.getContext('2d');
},
registerEventListeners() {
window.addEventListener('resize', () => {
this.methods.resizePlotCanvas();
this.methods.drawTrainingPlot();
this.methods.drawNetworkVisualization();
});
this.ui.canvas.addEventListener('click', (e) => this.methods.handleCanvasClick(e));
this.ui.class0Button.addEventListener('click', () => this.methods.setActiveClass(0));
this.ui.class1Button.addEventListener('click', () => this.methods.setActiveClass(1));
this.ui.addHiddenLayerButton.addEventListener('click', () => this.methods.addHiddenLayerUI());
this.ui.trainButton.addEventListener('click', () => this.methods.trainAndVisualize());
this.ui.stopButton.addEventListener('click', () => this.methods.stopTraining());
this.ui.resetWeightsButton.addEventListener('click', () => this.methods.resetModelWeights());
this.ui.clearButton.addEventListener('click', () => this.methods.clearAllDataAndReseed());
this.ui.datasetSelect.addEventListener('change', () => this.methods.loadSelectedDataset());
this.ui.dataNoise.addEventListener('input', () => { if (this.ui.datasetSelect.value !== 'manual') this.methods.loadSelectedDataset(); });
this.ui.regularizationTypeSelect.addEventListener('change', () => this.methods.toggleRegularizationRateInput());
},
initializeApplicationState() {
this.methods.setActiveClass(0);
this.methods.addHiddenLayerUI(8, 'relu'); // Default architecture
this.methods.addHiddenLayerUI(4, 'relu');
this.methods.toggleRegularizationRateInput();
this.methods.drawAll();
this.methods.drawTrainingPlot();
this.methods.updateTrainingParamsDisplay();
this.methods.updateButtonStates();
this.methods.renderConfusionMatrix(null);
this.methods.updateStatus('Ready. Click canvas to add points or load a dataset.', 'info');
},
// --- METHODS (Logic & Handlers) ---
methods: {
// --- UI & State Management ---
resizePlotCanvas() {
const dpr = window.devicePixelRatio || 1;
const rect = App.ui.canvas.parentElement.getBoundingClientRect();
if (rect.width === 0) return;
App.ui.canvas.width = rect.width * dpr;
const newHeight = Math.min(rect.width * 0.85, Math.max(300, window.innerHeight * 0.5));
App.ui.canvas.height = newHeight * dpr;
App.ui.ctx.scale(dpr, dpr);
App.ui.canvas.style.width = `${rect.width}px`;
App.ui.canvas.style.height = `${newHeight}px`;
App.ui.canvasWidth = rect.width;
App.ui.canvasHeight = newHeight;
this.drawAll();
},
setActiveClass(classNum) {
App.state.currentClass = classNum;
App.ui.class0Button.classList.toggle('button-active-blue', classNum === 0);
App.ui.class1Button.classList.toggle('button-active-red', classNum === 1);
App.ui.plotCanvas.className = App.ui.plotCanvas.className.replace(/cursor-class-\d/, '');
App.ui.plotCanvas.classList.add(`cursor-class-${classNum}`);
},
toggleRegularizationRateInput() {
App.ui.regularizationRateContainer.classList.toggle('hidden', App.ui.regularizationTypeSelect.value === 'none');
},
updateStatus(message, type = 'info') {
const icon = App.config.statusIcons[type] || '';
const loadingClass = type === 'loading' ? 'loading-icon' : '';
const iconHtml = `<span class="status-icon ${loadingClass}">${icon}</span>`;
App.ui.statusMessage.innerHTML = `${iconHtml}<span>${message}</span>`;
const typeToColor = {
info: 'text-slate-600 bg-slate-50 border-slate-200',
success: 'text-green-700 bg-green-50 border-green-200 font-semibold',
warning: 'text-amber-700 bg-amber-50 border-amber-200 font-semibold',
error: 'text-red-700 bg-red-50 border-red-200 font-semibold',
loading: 'text-blue-700 bg-blue-50 border-blue-200'
};
App.ui.statusMessage.className = `text-sm p-2 rounded-lg min-h-[44px] whitespace-pre-line leading-snug flex items-center justify-center border ${typeToColor[type] || ''}`;
},
setTrainingState(training) {
App.state.isTraining = training;
App.ui.trainButton.classList.toggle('hidden', training);
App.ui.stopButton.classList.toggle('hidden', !training);
this.updateButtonStates();
},
updateButtonStates() {
const hasData = App.state.dataPoints.length > 0;
const hasModel = App.state.model != null;
App.ui.trainButton.disabled = App.state.isTraining || !hasData;
App.ui.clearButton.disabled = App.state.isTraining;
App.ui.resetWeightsButton.disabled = App.state.isTraining || !hasModel;
if (hasModel && App.state.trainingHistory.loss.length > 0 && !App.state.isTraining) {
App.ui.trainButtonText.textContent = 'Retrain';
} else {
App.ui.trainButtonText.textContent = 'Train';
}
},
// --- Architecture UI ---
addHiddenLayerUI(defaultNeurons = 8, defaultActivation = 'relu') {
if (App.state.hiddenLayerConfigs.length >= 8) {
this.updateStatus("Max 8 hidden layers reached.", 'warning');
return;
}
const layerIndex = App.state.hiddenLayerConfigs.length;
App.state.hiddenLayerConfigs.push({ neurons: parseInt(defaultNeurons), activation: defaultActivation });
const layerDiv = document.createElement('div');
layerDiv.className = 'layer-config-item bg-slate-50 border border-slate-200 p-2 rounded-lg flex items-center justify-between gap-2 hover:bg-slate-100 transition-colors';
layerDiv.dataset.index = layerIndex;
layerDiv.innerHTML = `
<div class="flex-grow flex items-center gap-x-3">
<div class="flex items-center">
<label for="neurons_${layerIndex}" class="text-sm font-medium text-slate-600 mr-2 whitespace-nowrap">L${layerIndex + 1} Units:</label>
<input type="number" id="neurons_${layerIndex}" value="${defaultNeurons}" min="1" max="64" step="1" class="w-16 px-2 py-1 border border-slate-300 rounded-md shadow-sm text-sm focus:ring-blue-500 focus:border-blue-500">
</div>
<div class="flex items-center">
<label for="activation_${layerIndex}" class="text-sm font-medium text-slate-600 mr-2">Act:</label>
<select id="activation_${layerIndex}" class="flex-1 min-w-[90px] px-2 py-1 border bg-white border-slate-300 rounded-md shadow-sm text-sm focus:ring-blue-500 focus:border-blue-500">
${['relu', 'sigmoid', 'tanh', 'leakyRelu'].map(act => `<option value="${act}" ${act === defaultActivation ? 'selected' : ''}>${act.charAt(0).toUpperCase() + act.slice(1)}</option>`).join('')}
</select>
</div>
</div>
<button title="Remove layer" class="remove-btn flex-shrink-0 p-1.5 text-red-500 hover:text-red-700 rounded-full hover:bg-red-100 focus:outline-none focus:ring-2 focus:ring-red-500 focus:ring-offset-1 transition-all duration-150">
<span class="text-base font-bold">⛔</span>
</button>
`;
App.ui.hiddenLayersConfigContainer.appendChild(layerDiv);
layerDiv.querySelector(`#neurons_${layerIndex}`).onchange = (e) => {
const value = parseInt(e.target.value);
App.state.hiddenLayerConfigs[layerIndex].neurons = !isNaN(value) && value > 0 ? value : 1;
e.target.value = App.state.hiddenLayerConfigs[layerIndex].neurons;
this.drawNetworkVisualization();
};
layerDiv.querySelector(`#activation_${layerIndex}`).onchange = (e) => {
App.state.hiddenLayerConfigs[layerIndex].activation = e.target.value;
this.drawNetworkVisualization();
};
layerDiv.querySelector('.remove-btn').onclick = () => {
App.state.hiddenLayerConfigs.splice(layerIndex, 1);
this.redrawLayerConfigs();
};
this.drawNetworkVisualization();
},
redrawLayerConfigs() {
const configs = [...App.state.hiddenLayerConfigs];
App.ui.hiddenLayersConfigContainer.innerHTML = '';
App.state.hiddenLayerConfigs = [];
configs.forEach(config => this.addHiddenLayerUI(config.neurons, config.activation));
},
// --- Data Handling ---
handleCanvasClick(event) {
if (!App.ui.canvas) return;
const rect = App.ui.canvas.getBoundingClientRect();
const x = event.clientX - rect.left;
const y = event.clientY - rect.top;
App.state.dataPoints.push({ x, y, normX: x / App.ui.canvasWidth, normY: y / App.ui.canvasHeight, label: App.state.currentClass });
this.drawPoints();
this.updateStatus(`Added Class ${App.state.currentClass} point. Total: ${App.state.dataPoints.length}.`, 'info');
App.ui.datasetSelect.value = "manual";
this.updateButtonStates();
},
loadSelectedDataset() {
if (App.state.model) { App.state.model.dispose(); App.state.model = null; }
tf.disposeVariables();
App.state.dataPoints = [];
const datasetName = App.ui.datasetSelect.value;
if (datasetName === 'manual') {
this.drawAll();
this.updateButtonStates();
return;
}
const noise = parseFloat(App.ui.dataNoise.value) || 0;
const nSamples = 150;
const generators = {
two_moons: (n, noise) => {
const n_per_moon = Math.floor(n / 2);
const radius = 0.3;
for (let i = 0; i < n_per_moon; i++) {
const angle = (i / n_per_moon) * Math.PI;
// First moon, shifted left and up
this.addDataPoint(
0.5 + radius * Math.cos(angle) - 0.125,
0.5 + radius * Math.sin(angle) + 0.1,
0, noise
);
// Second moon, shifted right and down
this.addDataPoint(
0.5 + radius * Math.cos(angle + Math.PI) + 0.125,
0.5 + radius * Math.sin(angle + Math.PI) - 0.1,
1, noise
);
}
},
circles: (n, noise) => { for (let i=0; i<n; i++) { const r=Math.random(), a=Math.random()*2*Math.PI, l=r<0.5?0:1, rs=l===0?0.2:0.4; this.addDataPoint(0.5+rs*Math.cos(a), 0.5+rs*Math.sin(a), l, noise); } },
xor: (n, noise) => { const q=Math.floor(n/4), s=0.3; for (let i=0; i<q; i++) { this.addDataPoint(0.5-s, 0.5-s, 0, noise*2); this.addDataPoint(0.5+s, 0.5-s, 1, noise*2); this.addDataPoint(0.5-s, 0.5+s, 1, noise*2); this.addDataPoint(0.5+s, 0.5+s, 0, noise*2); } },
spiral: (n, noise) => { const pts=Math.floor(n/2); for (let i=0; i<pts; i++) { const a=i/20*Math.PI, r=0.05+i/pts*0.4; this.addDataPoint(0.5+r*Math.cos(a), 0.5+r*Math.sin(a), 0, noise*0.5); this.addDataPoint(0.5+r*Math.cos(a+Math.PI), 0.5+r*Math.sin(a+Math.PI), 1, noise*0.5); } }
};
generators[datasetName](nSamples, noise);
App.state.dataPoints.forEach(p => { p.x = p.normX * App.ui.canvasWidth; p.y = p.normY * App.ui.canvasHeight; });
this.drawAll();
this.updateStatus(`Loaded '${datasetName}' dataset. Noise: ${noise}.`, 'info');
App.state.trainingHistory = { loss: [], acc: [] };
this.drawTrainingPlot();
this.renderConfusionMatrix(null);
this.updateButtonStates();
this.updateTrainingParamsDisplay();
},
addDataPoint(normX, normY, label, noise) {
App.state.dataPoints.push({
x: 0, y: 0,
normX: normX + (Math.random() - 0.5) * noise,
normY: normY + (Math.random() - 0.5) * noise,
label: label
});
},
// --- Drawing & Visualization ---
drawAll(boundaryGrid = null) {
if (!App.ui.ctx || !App.ui.canvasWidth || !App.ui.canvasHeight) return;
App.ui.ctx.clearRect(0, 0, App.ui.canvasWidth, App.ui.canvasHeight);
if (boundaryGrid) {
const { grid, resolution } = boundaryGrid;
for (let i = 0; i < grid.length; i++) {
for (let j = 0; j < grid[i].length; j++) {
App.ui.ctx.fillStyle = App.config.classColors[grid[i][j]].boundary;
App.ui.ctx.fillRect(j * resolution, i * resolution, resolution, resolution);
}
}
}
App.state.dataPoints.forEach(point => {
App.ui.ctx.beginPath();
App.ui.ctx.arc(point.x, point.y, App.config.pointRadius, 0, 2 * Math.PI);
App.ui.ctx.fillStyle = App.config.classColors[point.label].point;
App.ui.ctx.fill();
App.ui.ctx.strokeStyle = 'rgba(255,255,255,0.7)';
App.ui.ctx.lineWidth = 1.5;
App.ui.ctx.stroke();
});
},
drawPoints() { this.drawAll(); },
drawNetworkVisualization() {
const container = App.ui.networkVisualization;
container.innerHTML = '';
const allLayers = [
{ type: 'input', neurons: 2, activation: 'Input' },
...App.state.hiddenLayerConfigs.map((cfg, i) => ({ type: `hidden-${i % 5}`, neurons: cfg.neurons, activation: cfg.activation })),
{ type: 'output', neurons: 1, activation: 'Sigmoid' }
];
const svg = document.createElementNS("http://www.w3.org/2000/svg", "svg");
const rect = container.getBoundingClientRect();
if(rect.width === 0 || rect.height === 0) return;
svg.setAttribute('viewBox', `0 0 ${rect.width} ${rect.height}`);
const margin = { top: 25, right: 15, bottom: 20, left: 15 };
const width = rect.width - margin.left - margin.right;
const height = rect.height - margin.top - margin.bottom;
const layerSpacing = allLayers.length > 1 ? width / (allLayers.length - 1) : width;
// Connections
for (let i = 0; i < allLayers.length - 1; i++) {
const x1 = margin.left + i * layerSpacing;
const x2 = margin.left + (i + 1) * layerSpacing;
const maxNeurons = 8;
const currentNeurons = Math.min(allLayers[i].neurons, maxNeurons);
const nextNeurons = Math.min(allLayers[i+1].neurons, maxNeurons);
for (let j = 0; j < currentNeurons; j++) {
const y1 = margin.top + height * ((j + 0.5) / currentNeurons);
for (let k = 0; k < nextNeurons; k++) {
const y2 = margin.top + height * ((k + 0.5) / nextNeurons);
const line = document.createElementNS("http://www.w3.org/2000/svg", "line");
line.setAttribute('x1', x1); line.setAttribute('y1', y1);
line.setAttribute('x2', x2); line.setAttribute('y2', y2);
line.setAttribute('class', 'connection');
svg.appendChild(line);
}
}
}
// Neurons and Labels
allLayers.forEach((layer, i) => {
const x = margin.left + i * layerSpacing;
const maxNeurons = 8;
const displayNeurons = Math.min(layer.neurons, maxNeurons);
const neuronRadius = Math.max(3, Math.min(7, height / (displayNeurons * 2.5)));
for (let j = 0; j < displayNeurons; j++) {
const y = margin.top + height * ((j + 0.5) / displayNeurons);
const circle = document.createElementNS("http://www.w3.org/2000/svg", "circle");
circle.setAttribute('cx', x); circle.setAttribute('cy', y);
circle.setAttribute('r', neuronRadius);
circle.setAttribute('class', `neuron ${layer.type}`);
svg.appendChild(circle);
}
const labelText = layer.activation.charAt(0).toUpperCase() + layer.activation.slice(1);
const textLabel = document.createElementNS("http://www.w3.org/2000/svg", "text");
textLabel.setAttribute('x', x); textLabel.setAttribute('y', margin.top - 8);
textLabel.setAttribute('class', 'layer-label');
textLabel.textContent = labelText;
svg.appendChild(textLabel);
const countLabel = document.createElementNS("http://www.w3.org/2000/svg", "text");
countLabel.setAttribute('x', x); countLabel.setAttribute('y', margin.top + height + 15);
countLabel.setAttribute('class', 'neuron-count-label');
countLabel.textContent = `${layer.neurons} N`;
svg.appendChild(countLabel);
});
container.appendChild(svg);
},
drawTrainingPlot() {
const canvas = App.ui.trainingPlotCanvas;
const ctx = canvas.getContext('2d');
const dpr = window.devicePixelRatio || 1;
const rect = canvas.getBoundingClientRect();
if (rect.width === 0 || rect.height === 0) return;
canvas.width = rect.width * dpr;
canvas.height = rect.height * dpr;
ctx.scale(dpr, dpr);
const { width, height } = rect;
const padding = {top: 20, right: 15, bottom: 20, left: 30};
ctx.fillStyle = '#f8fafc';
ctx.fillRect(0,0,width,height);
if (App.state.trainingHistory.loss.length === 0) {
ctx.fillStyle = '#64748b';
ctx.textAlign = 'center';
ctx.font = '12px Inter';
ctx.fillText('Training history will be plotted here.', width / 2, height / 2);
return;
}
ctx.beginPath();
ctx.strokeStyle = '#e2e8f0';
ctx.lineWidth = 1;
for(let i = 0; i <= 4; i++){
const y = padding.top + i * (height - padding.top - padding.bottom) / 4;
ctx.moveTo(padding.left, y);
ctx.lineTo(width-padding.right, y);
}
ctx.stroke();
ctx.font = '10px Inter';
ctx.fillStyle = '#475569';
ctx.textAlign = 'right';
for(let i = 0; i <= 4; i++){
ctx.fillText((1 - i/4).toFixed(1), padding.left - 6, padding.top + 3 + i * (height - padding.top - padding.bottom) / 4);
}
const plotData = (data, color) => {
ctx.beginPath(); ctx.strokeStyle = color; ctx.lineWidth = 2; ctx.lineJoin = 'round'; ctx.lineCap = 'round';
data.forEach((val, i) => {
const x = padding.left + (i / (Math.max(1, data.length -1))) * (width - padding.left - padding.right);
const y = (height - padding.bottom) - Math.min(Math.max(val,0.0),1.0) * (height - padding.top - padding.bottom);
if (i === 0) ctx.moveTo(x, y); else ctx.lineTo(x, y);
});
ctx.stroke();
};
plotData(App.state.trainingHistory.loss, 'rgba(239, 68, 68, 0.9)');
plotData(App.state.trainingHistory.acc, 'rgba(37, 99, 235, 0.9)');
ctx.textAlign = 'left';
ctx.font = '600 11px Inter';
ctx.fillStyle = 'rgba(239, 68, 68, 1)'; ctx.fillRect(padding.left + 5, 5, 10, 3);
ctx.fillStyle = '#374151'; ctx.fillText('Loss', padding.left + 20, 10);
ctx.fillStyle = 'rgba(37, 99, 235, 1)'; ctx.fillRect(padding.left + 75, 5, 10, 3);
ctx.fillStyle = '#374151'; ctx.fillText('Accuracy', padding.left + 90, 10);
},
updateTrainingParamsDisplay() {
const container = App.ui.trainingParamsDisplay;
if (!App.state.model) {
container.innerHTML = `<div class="text-slate-500">Train a model to see parameters.</div>`;
return;
}
const lr = parseFloat(App.ui.learningRate.value);
const regType = App.ui.regularizationTypeSelect.value;
const regRate = parseFloat(App.ui.regularizationRateInput.value) || 0;
let regDesc = regType !== 'none' ? `${regType.toUpperCase()}(λ=${regRate})` : 'None';
let hiddenDesc = App.state.hiddenLayerConfigs.map(l => l.neurons).join(' → ');
container.innerHTML = `
<div class="flex flex-wrap gap-x-4 gap-y-1">
<span><span class="font-semibold text-slate-500">Opt:</span> <span class="font-medium text-slate-800">${App.ui.optimizerSelect.value.toUpperCase()}</span></span>
<span><span class="font-semibold text-slate-500">LR:</span> <span class="font-medium text-slate-800">${lr.toExponential(1)}</span></span>
<span><span class="font-semibold text-slate-500">Batch:</span> <span class="font-medium text-slate-800">${parseInt(App.ui.batchSize.value) === 0 ? 'Full' : App.ui.batchSize.value}</span></span>
<span><span class="font-semibold text-slate-500">Reg:</span> <span class="font-medium text-slate-800">${regDesc}</span></span>
</div>
<div><span class="font-semibold text-slate-500">Layers:</span> <span class="font-medium text-slate-800">2 → ${hiddenDesc || '...'} → 1</span></div>
`;
},
renderConfusionMatrix(matrix) {
const container = App.ui.confusionMatrixContainer;
if (!matrix) {
container.innerHTML = `<div class="flex items-center justify-center h-full text-sm text-center text-slate-400 bg-slate-50 rounded-lg p-2 border border-slate-200">Confusion matrix appears here after training.</div>`;
return;
}
container.innerHTML = `
<div class="h-full flex flex-col">
<h4 class="text-sm font-semibold text-center text-slate-600 mb-1.5">Confusion Matrix</h4>
<div class="grid grid-cols-2 grid-rows-2 gap-1.5 flex-grow">
<div class="cm-cell bg-blue-100 text-blue-800" title="True Negative"><span class="cm-value">${matrix.tn}</span><span class="cm-label">TN</span></div>
<div class="cm-cell bg-red-100 text-red-800" title="False Positive"><span class="cm-value">${matrix.fp}</span><span class="cm-label">FP</span></div>
<div class="cm-cell bg-red-100 text-red-800" title="False Negative"><span class="cm-value">${matrix.fn}</span><span class="cm-label">FN</span></div>
<div class="cm-cell bg-green-100 text-green-800" title="True Positive"><span class="cm-value">${matrix.tp}</span><span class="cm-label">TP</span></div>
</div>
</div>
`;
},
// --- TENSORFLOW.JS & ML ---
getSafeNumericInput(element, defaultValue, isInteger = true) {
let value = isInteger ? parseInt(element.value, 10) : parseFloat(element.value);
if (isNaN(value)) {
value = defaultValue;
element.value = defaultValue;
}
return value;
},
buildModel() {
if (App.state.model) { App.state.model.dispose(); App.state.model = null; }
tf.disposeVariables();
const learningRate = this.getSafeNumericInput(App.ui.learningRate, 0.01, false);
const regType = App.ui.regularizationTypeSelect.value;
const regRate = this.getSafeNumericInput(App.ui.regularizationRateInput, 0, false);
const kernelRegularizer = (regType !== 'none' && regRate > 0)
? tf.regularizers[regType]({[regType]: regRate})
: null;
App.state.model = tf.sequential();
const inputShape = [2];
// Add hidden layers
App.state.hiddenLayerConfigs.forEach((layerConfig, index) => {
App.state.model.add(tf.layers.dense({
units: layerConfig.neurons,
inputShape: index === 0 ? inputShape : undefined,
activation: layerConfig.activation,
kernelRegularizer
}));
});
// Add output layer
App.state.model.add(tf.layers.dense({
units: 1,
activation: 'sigmoid',
inputShape: App.state.hiddenLayerConfigs.length === 0 ? inputShape : undefined,
}));
const optimizerInstance = tf.train[App.ui.optimizerSelect.value](learningRate);
App.state.model.compile({ optimizer: optimizerInstance, loss: 'binaryCrossentropy', metrics: ['accuracy'] });
App.state.model.stopTraining = false;
},
async trainAndVisualize() {
if (App.state.isTraining) return;
const uniqueLabels = new Set(App.state.dataPoints.map(p => p.label));
if (App.state.dataPoints.length < 4 || uniqueLabels.size < 2) {
this.updateStatus('Requires at least 4 points and data from both classes.', 'error');
return;
}
this.setTrainingState(true);
this.updateStatus('Starting training...', 'loading');
this.renderConfusionMatrix(null);
await tf.nextFrame();
this.buildModel();
this.updateTrainingParamsDisplay();
const epochs = this.getSafeNumericInput(App.ui.epochs, 150, true);
let batchSize = this.getSafeNumericInput(App.ui.batchSize, 16, true);
if (batchSize === 0) batchSize = App.state.dataPoints.length;
const [xs, ys] = tf.tidy(() => {
const normalized = App.state.dataPoints.map(p => [p.normX, p.normY]);
const labels = App.state.dataPoints.map(p => p.label);
return [tf.tensor2d(normalized), tf.tensor2d(labels, [labels.length, 1])];
});
App.state.trainingHistory = { loss: [], acc: [] };
this.drawTrainingPlot();
try {
await App.state.model.fit(xs, ys, {
epochs, batchSize,
callbacks: {
onEpochEnd: async (epoch, logs) => {
if (App.state.model.stopTraining) { App.state.model.stop(); return; }
this.updateStatus(`Epoch ${epoch + 1}/${epochs} - Loss: ${logs.loss.toFixed(4)}, Acc: ${logs.acc.toFixed(4)}`, 'loading');
App.state.trainingHistory.loss.push(logs.loss);
App.state.trainingHistory.acc.push(logs.acc);
this.drawTrainingPlot();
if ((epoch + 1) % Math.max(1, Math.floor(epochs / 25)) === 0) {
const boundaryGrid = await this.generateBoundaryGrid();
this.drawAll(boundaryGrid);
}
await tf.nextFrame();
},
onTrainEnd: async () => {
const finalAcc = App.state.trainingHistory.acc.slice(-1)[0] || 0;
if (!App.state.model.stopTraining) {
this.updateStatus(`Training complete! Final Accuracy: <b>${(finalAcc*100).toFixed(2)}%</b>`, 'success');
}
const boundaryGrid = await this.generateBoundaryGrid();
this.drawAll(boundaryGrid);
const confusionMatrix = await this.calculateConfusionMatrix(xs, ys);
this.renderConfusionMatrix(confusionMatrix);
this.setTrainingState(false);
}
}
});
} catch (error) {
console.error("Training error:", error);
this.updateStatus(`Training error: ${error.message}`, 'error');
this.setTrainingState(false);
} finally {
xs.dispose();
ys.dispose();
}
},
async calculateConfusionMatrix(xs, ys) {
if (!App.state.model) return null;
return tf.tidy(() => {
const predictions = App.state.model.predict(xs).round();
const tp = ys.mul(predictions).sum().dataSync()[0];
const tn = ys.sub(1).mul(-1).mul(predictions.sub(1).mul(-1)).sum().dataSync()[0];
const fp = predictions.sub(ys).relu().sum().dataSync()[0];
const fn = ys.sub(predictions).relu().sum().dataSync()[0];
return { tp, tn, fp, fn };
});
},
async generateBoundaryGrid() {
if (!App.state.model || !App.ui.canvasWidth || !App.ui.canvasHeight) return null;
const resolution = Math.max(5, Math.floor(App.ui.canvasWidth / 80));
const numCols = Math.floor(App.ui.canvasWidth / resolution);
const numRows = Math.floor(App.ui.canvasHeight / resolution);
if (numCols <= 0 || numRows <= 0) return null;
const boundaryData = tf.tidy(() => {
const gridPoints = [];
for (let i = 0; i < numRows; i++) {
for (let j = 0; j < numCols; j++) {
gridPoints.push([(j * resolution) / App.ui.canvasWidth, (i * resolution) / App.ui.canvasHeight]);
}
}
const predsTensor = App.state.model.predict(tf.tensor2d(gridPoints));
return predsTensor.dataSync(); // Use dataSync inside tidy
});
const grid = []; let k = 0;
for (let i = 0; i < numRows; i++) {
const row = [];
for (let j = 0; j < numCols; j++) row.push(boundaryData[k++] > 0.5 ? 1 : 0);
grid.push(row);
}
return { grid, resolution };
},
stopTraining() {
if (App.state.model) {
App.state.model.stopTraining = true;
this.updateStatus("Training stopped by user.", 'warning');
}
},
resetModelWeights() {
if (App.state.isTraining) return;
this.updateStatus("Model weights reset. Ready to train again.", 'info');
this.buildModel();
this.drawAll();
App.state.trainingHistory = { loss: [], acc: [] };
this.drawTrainingPlot();
this.renderConfusionMatrix(null);
this.updateButtonStates();
},
clearAllDataAndReseed() {
if (App.state.isTraining) return;
App.state.dataPoints = [];
if(App.state.model) { App.state.model.dispose(); App.state.model = null; }
tf.disposeVariables();
App.state.hiddenLayerConfigs = [];
App.ui.hiddenLayersConfigContainer.innerHTML = '';
App.initializeApplicationState();
this.updateStatus('Cleared all data and reset configuration.', 'info');
}
}
};
// --- Entry Point ---
// Initialize the application once the document has been parsed.
document.addEventListener('DOMContentLoaded', function onDocumentReady() {
App.init();
});
</script>
</body>
</html>