# Source: pytorch/pages/2_LinearRegression.py
# Last commit: de81b4d "Update pages/2_LinearRegression.py" by eaglelandsonce (verified), 1.42 kB
import streamlit as st
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
# Seed PyTorch's global random number generator so that the model's randomly
# initialised weight and bias come out the same every time this script runs.
torch.random.manual_seed(59)
# Define the Linear Model
class LinearModel(nn.Module):
def __init__(self, in_features, out_features):
super(LinearModel, self).__init__()
self.linear = nn.Linear(in_features, out_features)
def forward(self, x):
return self.linear(x)
# Instantiate the model
model = LinearModel(1, 1)
# Print model weight and bias
print(f'Model weight: {model.linear.weight.item()}')
print(f'Model bias: {model.linear.bias.item()}')
# Streamlit app title
st.title('Interactive Scatter Plot with Noise and Number of Data Points')
# Sidebar sliders for noise and number of data points
noise_level = st.sidebar.slider('Noise Level', 0.0, 1.0, 0.1, step=0.01)
num_points = st.sidebar.slider('Number of Data Points', 10, 100, 50, step=5)
# Generate data
np.random.seed(59)
x = np.linspace(0, 10, num_points).reshape(-1, 1).astype(np.float32)
with torch.no_grad():
x_tensor = torch.tensor(x)
y_tensor = model(x_tensor)
y = y_tensor.numpy().flatten() + noise_level * np.random.randn(num_points)
# Create scatter plot
fig, ax = plt.subplots()
ax.scatter(x, y, alpha=0.6)
ax.set_title('Scatter Plot with Noise and Number of Data Points')
ax.set_xlabel('X-axis')
ax.set_ylabel('Y-axis')
# Display plot in Streamlit
st.pyplot(fig)