Spaces:
Sleeping
Sleeping
import streamlit as st | |
import webbrowser | |
import pandas as pd | |
import seaborn as sns | |
import matplotlib.pyplot as plt | |
import plotly.express as px | |
import pickle | |
import base64 | |
import io | |
import plotly.graph_objects as go | |
#import viz_report | |
import viz_ai_img | |
import word_cloud | |
import notepad_lite | |
import calculator | |
# Import ML libraries | |
from sklearn.model_selection import train_test_split | |
from sklearn.preprocessing import LabelEncoder, StandardScaler | |
from sklearn.linear_model import LinearRegression, LogisticRegression | |
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor | |
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor, GradientBoostingClassifier, GradientBoostingRegressor | |
from sklearn.svm import SVC, SVR | |
from sklearn.neighbors import KNeighborsClassifier, KNeighborsRegressor | |
from sklearn.naive_bayes import GaussianNB # For classification | |
from sklearn.metrics import mean_squared_error, r2_score, accuracy_score, precision_score, recall_score, f1_score, confusion_matrix | |
import numpy as np # For numerical operations, especially with metrics | |
st.set_page_config("Visio AI", page_icon="images/favicon.png", layout='wide') | |
st.markdown("<h1 style='text-align: center; color: #4A90E2;'>📊 VISIO AI</h1>", unsafe_allow_html=True) | |
st.markdown("<h4 style='text-align: center; color: orange;'>Machine Learning and Data Analysis Platform</h4>", unsafe_allow_html=True) | |
st.markdown("<hr>", unsafe_allow_html=True) | |
#-------------------------------------------------# | |
# --- Session State Initialization --- | |
if 'updated_df' not in st.session_state: | |
st.session_state.updated_df = None | |
if 'original_df_uploaded' not in st.session_state: | |
st.session_state.original_df_uploaded = False | |
if 'last_uploaded_file_name' not in st.session_state: | |
st.session_state.last_uploaded_file_name = None | |
if 'X_train' not in st.session_state: | |
st.session_state.X_train = None | |
if 'X_test' not in st.session_state: | |
st.session_state.X_test = None | |
if 'y_train' not in st.session_state: | |
st.session_state.y_train = None | |
if 'y_test' not in st.session_state: | |
st.session_state.y_test = None | |
if 'target_column' not in st.session_state: | |
st.session_state.target_column = None | |
if 'feature_columns' not in st.session_state: | |
st.session_state.feature_columns = None | |
if 'problem_type' not in st.session_state: | |
st.session_state.problem_type = None # 'classification' or 'regression' | |
if 'trained_model' not in st.session_state: | |
st.session_state.trained_model = None | |
if 'model_metrics' not in st.session_state: | |
st.session_state.model_metrics = None | |
if 'scaler' not in st.session_state: | |
st.session_state.scaler = None | |
# Navigation Bar | |
col1, col2, col3, col4, col5 = st.columns((1, 1, 1, 1, 1)) | |
with col1: | |
about_url = "https://jaiho-labs.onrender.com/pages/products_resources/docs/visio_ai_docs/visio_about.html" | |
if st.button('About'): | |
st.markdown("check out this [link](%s)" % about_url) | |
#webbrowser.open_new_tab(about_url) | |
with col2: | |
guide_url = "https://jaiho-labs.onrender.com/pages/products_resources/docs/visio_ai_docs/visio_helper.html" | |
if st.button('Guide'): | |
st.markdown("check out this [link](%s)" % guide_url) | |
with col3: | |
docs_url = "https://jaiho-labs.onrender.com/pages/products_resources/docs/visio_ai_docs/visio_docs.html" | |
if st.button('Docs'): | |
st.markdown("check out this [link](%s)" % docs_url) | |
with col4: | |
joinus_url = "https://jaiho-labs.onrender.com/pages/products_resources/docs/visio_ai_docs/visio_join.html" | |
if st.button('Join Us'): | |
st.markdown("check out this [link](%s)" % joinus_url) | |
with col5: | |
elite_access = "https://jaiho-labs.onrender.com/pages/products_resources/docs/visio_ai_docs/get_elite_access.html" | |
if st.button('Get Elite Access'): | |
st.markdown("check out this [link](%s)" % elite_access) | |
#-------------------------------------------------# | |
# Top Expander Columns (Data Operations & Algorithms, Select Plot Type, Pre Analysis) | |
col11, col12, col13 = st.columns([1, 1, 1]) | |
# --- Data Operations & Algorithms Expander --- | |
with col11: | |
with st.expander("⚙️ Data Operations & Algorithms", expanded=False): | |
if st.session_state.updated_df is not None: | |
st.markdown("#### 1. Define Target Variable and Problem Type") | |
all_columns = st.session_state.updated_df.columns.tolist() | |
target_column = st.selectbox("Select your **Target Column (Y)**:", ["--- Select ---"] + all_columns, key="target_col_select") | |
if target_column != "--- Select ---": | |
st.session_state.target_column = target_column | |
# Heuristic to guess problem type | |
if st.session_state.updated_df[target_column].dtype in ['int64', 'float64']: | |
if st.session_state.updated_df[target_column].nunique() < 20 and st.session_state.updated_df[target_column].dtype == 'int64': | |
st.session_state.problem_type = 'classification' | |
st.info(f"Detected **Classification** problem based on target column '{target_column}' (integer with few unique values).") | |
else: | |
st.session_state.problem_type = 'regression' | |
st.info(f"Detected **Regression** problem based on target column '{target_column}' (numerical).") | |
elif st.session_state.updated_df[target_column].dtype == 'object' or st.session_state.updated_df[target_column].dtype == 'bool': | |
st.session_state.problem_type = 'classification' | |
st.info(f"Detected **Classification** problem based on target column '{target_column}' (categorical).") | |
else: | |
st.session_state.problem_type = None | |
st.warning("Could not definitively determine problem type. Please proceed with caution.") | |
st.markdown("---") | |
st.markdown("#### 2. Select Independent Variables (Features)") | |
available_features = [col for col in all_columns if col != target_column] | |
feature_columns = st.multiselect("Select your **Independent Variables (X)**:", available_features, default=available_features, key="feature_select") | |
if feature_columns: | |
st.session_state.feature_columns = feature_columns | |
st.markdown("---") | |
st.markdown("#### 3. Split Data into Train and Test Sets") | |
test_size = st.slider("Select Test Set Size:", min_value=0.1, max_value=0.5, value=0.2, step=0.05, key="test_size_slider") | |
random_state = st.number_input("Random State (for reproducibility):", value=42, step=1, key="random_state_input") | |
# Use only selected features | |
features = st.session_state.updated_df[feature_columns] | |
target = st.session_state.updated_df[target_column] | |
# Handle categorical features by encoding | |
for col in features.select_dtypes(include=['object', 'bool']).columns: | |
le = LabelEncoder() | |
features[col] = le.fit_transform(features[col].astype(str)) | |
# Handle numerical features by scaling | |
numerical_cols = features.select_dtypes(include=['number']).columns | |
if not numerical_cols.empty: | |
scaler = StandardScaler() | |
features[numerical_cols] = scaler.fit_transform(features[numerical_cols]) | |
st.session_state.scaler = scaler # Save the scaler | |
try: | |
X_train, X_test, y_train, y_test = train_test_split( | |
features, target, test_size=test_size, random_state=random_state, | |
stratify=target if st.session_state.problem_type == 'classification' else None | |
) | |
st.session_state.X_train = X_train | |
st.session_state.X_test = X_test | |
st.session_state.y_train = y_train | |
st.session_state.y_test = y_test | |
st.success(f"Data split successfully! Training: {len(X_train)} samples, Testing: {len(X_test)} samples.") | |
st.markdown("---") | |
st.markdown("#### 4. Select Machine Learning Algorithm") | |
if st.session_state.problem_type == 'classification': | |
algo_options = { | |
"Logistic Regression": LogisticRegression(random_state=random_state), | |
"Decision Tree Classifier": DecisionTreeClassifier(random_state=random_state), | |
"Random Forest Classifier": RandomForestClassifier(random_state=random_state), | |
"Support Vector Classifier (SVC)": SVC(random_state=random_state), | |
"K-Nearest Neighbors Classifier": KNeighborsClassifier(), | |
"Gaussian Naive Bayes": GaussianNB() | |
} | |
algo_name = st.selectbox("Choose a Classification Algorithm:", list(algo_options.keys()), key="classification_algo_select") | |
selected_algo = algo_options.get(algo_name) | |
elif st.session_state.problem_type == 'regression': | |
algo_options = { | |
"Linear Regression": LinearRegression(), | |
"Decision Tree Regressor": DecisionTreeRegressor(random_state=random_state), | |
"Random Forest Regressor": RandomForestRegressor(random_state=random_state), | |
"Support Vector Regressor (SVR)": SVR(), | |
"K-Nearest Neighbors Regressor": KNeighborsRegressor() | |
} | |
algo_name = st.selectbox("Choose a Regression Algorithm:", list(algo_options.keys()), key="regression_algo_select") | |
selected_algo = algo_options.get(algo_name) | |
else: | |
st.warning("Please define target column and problem type to select an algorithm.") | |
selected_algo = None | |
if selected_algo: | |
st.info(f"Selected Algorithm: **{algo_name}**") | |
st.session_state.selected_algo = selected_algo | |
st.session_state.selected_algo_name = algo_name | |
st.markdown("---") | |
if st.button("🚀 Train Model"): | |
if st.session_state.X_train is not None and st.session_state.y_train is not None: | |
try: | |
with st.spinner(f"Training {st.session_state.selected_algo_name}..."): | |
st.session_state.selected_algo.fit(st.session_state.X_train, st.session_state.y_train) | |
st.session_state.trained_model = st.session_state.selected_algo | |
st.success(f"Model **{st.session_state.selected_algo_name}** trained successfully!") | |
y_pred = st.session_state.trained_model.predict(st.session_state.X_test) | |
metrics = {} | |
if st.session_state.problem_type == 'classification': | |
metrics['Accuracy'] = accuracy_score(st.session_state.y_test, y_pred) | |
metrics['Precision'] = precision_score(st.session_state.y_test, y_pred, average='weighted', zero_division=0) | |
metrics['Recall'] = recall_score(st.session_state.y_test, y_pred, average='weighted', zero_division=0) | |
metrics['F1 Score'] = f1_score(st.session_state.y_test, y_pred, average='weighted', zero_division=0) | |
metrics['Confusion Matrix'] = confusion_matrix(st.session_state.y_test, y_pred) | |
elif st.session_state.problem_type == 'regression': | |
metrics['Mean Squared Error'] = mean_squared_error(st.session_state.y_test, y_pred) | |
metrics['R2 Score'] = r2_score(st.session_state.y_test, y_pred) | |
st.session_state.model_metrics = metrics | |
st.rerun() | |
except Exception as e: | |
st.error(f"Error training model: {e}") | |
else: | |
st.warning("Please split the data first before training the model.") | |
else: | |
st.warning("Please select a target column and problem type to enable algorithm selection.") | |
except Exception as e: | |
st.error(f"Error splitting data or preparing features: {e}") | |
st.info("Ensure your data is clean and suitable for splitting (e.g., no remaining NaN values after imputation).") | |
else: | |
st.warning("Please select at least one independent variable.") | |
else: | |
st.info("Please select a target column to proceed with data operations.") | |
else: | |
st.info("Please upload a dataset first to access Data Operations & Algorithms.") | |
with col12: | |
with st.expander("🎨 Select Plot Type", expanded=False): | |
if st.session_state.updated_df is not None: | |
df = st.session_state.updated_df | |
numerical_cols = df.select_dtypes(include=np.number).columns.tolist() | |
categorical_cols = df.select_dtypes(include='object').columns.tolist() | |
plot_type = st.selectbox("Select a plot type", ["---Select---", "Bar Chart", "Histogram", "Scatter Plot", "Box Plot", "Heatmap", | |
"Line Chart", "Pie Chart", "Violin Plot", "Pair Plot", | |
"3D Scatter Plot", "3D Surface Plot"]) | |
if plot_type == "Bar Chart": | |
st.info("A bar chart shows counts of categories within a column.") | |
selected_col = st.selectbox("Select a categorical column", categorical_cols) | |
if st.button("Generate Bar Chart"): | |
if selected_col: | |
fig = px.bar(df, x=selected_col, title=f'Bar Chart of {selected_col}', color=selected_col) | |
st.plotly_chart(fig, use_container_width=True) | |
elif plot_type == "Histogram": | |
st.info("A histogram shows the distribution of a numerical column.") | |
selected_col = st.selectbox("Select a numerical column", numerical_cols) | |
if st.button("Generate Histogram"): | |
if selected_col: | |
fig = px.histogram(df, x=selected_col, title=f'Histogram of {selected_col}') | |
st.plotly_chart(fig, use_container_width=True) | |
elif plot_type == "Scatter Plot": | |
st.info("A scatter plot shows the relationship between two numerical columns.") | |
x_col = st.selectbox("Select X-axis column", numerical_cols, key='scatter_x') | |
y_col = st.selectbox("Select Y-axis column", numerical_cols, key='scatter_y') | |
if st.button("Generate Scatter Plot"): | |
if x_col and y_col: | |
fig = px.scatter(df, x=x_col, y=y_col, title=f'Scatter Plot of {x_col} vs {y_col}') | |
st.plotly_chart(fig, use_container_width=True) | |
elif plot_type == "Box Plot": | |
st.info("A box plot shows the distribution of a numerical column grouped by a categorical column.") | |
num_col = st.selectbox("Select a numerical column", numerical_cols, key='box_num') | |
cat_col = st.selectbox("Select a categorical column for grouping", categorical_cols, key='box_cat') | |
if st.button("Generate Box Plot"): | |
if num_col and cat_col: | |
fig = px.box(df, x=cat_col, y=num_col, title=f'Box Plot of {num_col} by {cat_col}', color=cat_col) | |
st.plotly_chart(fig, use_container_width=True) | |
elif plot_type == "Heatmap": | |
st.info("A heatmap shows the correlation between all numerical columns.") | |
if st.button("Generate Heatmap"): | |
corr = df[numerical_cols].corr() | |
fig = px.imshow(corr, text_auto=True, title='Correlation Heatmap') | |
st.plotly_chart(fig, use_container_width=True) | |
elif plot_type == "Line Chart": | |
st.info("A line chart shows trends over time or ordered categories.") | |
x_col = st.selectbox("Select X-axis column", df.columns, key='line_x') | |
y_col = st.selectbox("Select Y-axis (numerical) column", numerical_cols, key='line_y') | |
if st.button("Generate Line Chart"): | |
if x_col and y_col: | |
fig = px.line(df, x=x_col, y=y_col, title=f'Line Chart of {y_col} over {x_col}') | |
st.plotly_chart(fig, use_container_width=True) | |
elif plot_type == "Pie Chart": | |
st.info("A pie chart shows proportions of categories within a column.") | |
selected_col = st.selectbox("Select a categorical column for Pie Chart", categorical_cols, key='pie_col') | |
if st.button("Generate Pie Chart"): | |
if selected_col: | |
pie_data = df[selected_col].value_counts().reset_index() | |
pie_data.columns = [selected_col, 'Count'] | |
fig = px.pie(pie_data, names=selected_col, values='Count', title=f'Pie Chart of {selected_col}') | |
st.plotly_chart(fig, use_container_width=True) | |
elif plot_type == "Violin Plot": | |
st.info("A violin plot shows the distribution of a numerical column by categories.") | |
num_col = st.selectbox("Select a numerical column", numerical_cols, key='violin_num') | |
cat_col = st.selectbox("Select a categorical column for grouping", categorical_cols, key='violin_cat') | |
if st.button("Generate Violin Plot"): | |
if num_col and cat_col: | |
fig = px.violin(df, x=cat_col, y=num_col, box=True, points="all", title=f'Violin Plot of {num_col} by {cat_col}') | |
st.plotly_chart(fig, use_container_width=True) | |
elif plot_type == "Pair Plot": | |
st.info("A pair plot shows scatter plots for all combinations of numerical columns.") | |
if st.button("Generate Pair Plot"): | |
fig = px.scatter_matrix(df[numerical_cols], dimensions=numerical_cols, title='Pair Plot of Numerical Features') | |
st.plotly_chart(fig, use_container_width=True) | |
elif plot_type == "3D Scatter Plot": | |
st.info("A 3D scatter plot shows the relationship between three numerical columns.") | |
x_col = st.selectbox("Select X-axis column", numerical_cols, key='3d_scatter_x') | |
y_col = st.selectbox("Select Y-axis column", numerical_cols, key='3d_scatter_y') | |
z_col = st.selectbox("Select Z-axis column", numerical_cols, key='3d_scatter_z') | |
color_col = st.selectbox("Optional: Select a column for color grouping (optional)", df.columns, key='3d_scatter_color') | |
if st.button("Generate 3D Scatter Plot"): | |
if x_col and y_col and z_col: | |
fig = px.scatter_3d(df, x=x_col, y=y_col, z=z_col, color=color_col if color_col else None, | |
title=f'3D Scatter Plot: {x_col} vs {y_col} vs {z_col}') | |
st.plotly_chart(fig, use_container_width=True) | |
elif plot_type == "3D Surface Plot": | |
st.info("A 3D surface plot shows a continuous surface over two variables.") | |
x_col = st.selectbox("Select X-axis column", numerical_cols, key='3d_surface_x') | |
y_col = st.selectbox("Select Y-axis column", numerical_cols, key='3d_surface_y') | |
z_col = st.selectbox("Select Z-axis column", numerical_cols, key='3d_surface_z') | |
if st.button("Generate 3D Surface Plot"): | |
if x_col and y_col and z_col: | |
try: | |
pivot_table = df.pivot_table(index=y_col, columns=x_col, values=z_col, aggfunc='mean') | |
fig = go.Figure(data=[go.Surface(z=pivot_table.values, | |
x=pivot_table.columns, | |
y=pivot_table.index)]) | |
fig.update_layout(title=f'3D Surface Plot of {z_col} over {x_col} and {y_col}', | |
scene=dict( | |
xaxis_title=x_col, | |
yaxis_title=y_col, | |
zaxis_title=z_col | |
)) | |
st.plotly_chart(fig, use_container_width=True) | |
except Exception as e: | |
st.error(f"Error generating surface plot: {e}") | |
else: | |
st.info("Please upload a dataset first to generate plots.") | |
with col13: | |
with st.expander("📈 Pre Analysis", expanded=False): | |
if st.session_state.updated_df is not None: | |
# Create tabs for different analyses | |
tab1, tab2 = st.tabs(["Statistical Summary", "Dataset Info"]) | |
with tab1: | |
st.subheader("Statistical Summary (describe)") | |
numeric_df = st.session_state.updated_df.select_dtypes(include=['float64', 'int64']) | |
if not numeric_df.empty: | |
# Display statistical summary | |
st.dataframe(numeric_df.describe()) | |
else: | |
st.warning("No numerical columns found in the dataset") | |
if st.checkbox("Show additional statistics"): | |
st.write("Skewness:") | |
st.dataframe(numeric_df.skew()) | |
st.write("Kurtosis:") | |
st.dataframe(numeric_df.kurtosis()) | |
with tab2: | |
st.subheader("Dataset Information (info)") | |
# Get DataFrame info | |
buffer = io.StringIO() | |
st.session_state.updated_df.info(buf=buffer) | |
info_str = buffer.getvalue() | |
# Display formatted info | |
st.text(info_str) | |
st.write("Quick Facts:") | |
col1, col2, col3 = st.columns(3) | |
with col1: | |
st.metric("Total Rows", st.session_state.updated_df.shape[0]) | |
with col2: | |
st.metric("Total Columns", st.session_state.updated_df.shape[1]) | |
with col3: | |
st.metric("Missing Values", st.session_state.updated_df.isna().sum().sum()) | |
# Display column types | |
st.write("Column Data Types:") | |
dtypes_df = pd.DataFrame(st.session_state.updated_df.dtypes, columns=['Data Type']) | |
st.dataframe(dtypes_df) | |
else: | |
st.info("Please upload a dataset first.") | |
#----------------------------------------------------# | |
# Sidebar (Keep as is if you are simulating pages in a single file) | |
with st.sidebar: | |
st.markdown('<b>🛠️ Tools</b>', unsafe_allow_html=True) | |
# Store the active page in session state | |
if 'current_page' not in st.session_state: | |
st.session_state.current_page = "main" | |
if st.button("🏠 Home"): | |
st.session_state.current_page = "main" | |
st.rerun() | |
if st.button("📝 Note -- Lite"): | |
st.session_state.current_page = "note_lite" | |
st.rerun() | |
if st.button("😶🌫️ WordCloud"): | |
st.session_state.current_page = "word_cloud" | |
st.rerun() | |
if st.button("🤖 Viz AI (img)"): | |
st.session_state.current_page = "viz_ai_img" | |
st.rerun() | |
if st.button("🧮 Calculator"): | |
st.session_state.current_page = "calculator" | |
st.rerun() | |
if st.button("⚙️ Viz Editor"): | |
st.session_state.current_page = "note_edit" | |
# No rerun here — handled differently maybe? | |
if st.button("📄 Viz Report"): | |
st.session_state.current_page = "generate_report" | |
st.rerun() | |
st.markdown("<hr>",unsafe_allow_html=True) | |
st.markdown("### <center>Other Products</center>", unsafe_allow_html=True) | |
#---------------------------------------------------------------# | |
#---------------------------------------------------------------# | |
# Main content columns | |
col_main_left, col_main_right = st.columns([0.6, 0.4]) # Adjusted column widths for better layout | |
with col_main_left: | |
st.markdown("<b style='font-size:20px;'>📂 Upload Your Dataset</b>", unsafe_allow_html=True) | |
dataset = st.file_uploader("Choose a dataset file", type=["csv", "xlsx", "txt"], key="file_uploader_main") # Added key | |
if dataset is not None: | |
if 'last_uploaded_file_object' not in st.session_state or st.session_state.last_uploaded_file_object != dataset: | |
st.session_state.last_uploaded_file_object = dataset | |
st.session_state.original_df_uploaded = False | |
st.session_state.updated_df = None | |
st.session_state.X_train = st.session_state.X_test = st.session_state.y_train = st.session_state.y_test = None | |
st.session_state.target_column = None | |
st.session_state.feature_columns = None | |
st.session_state.problem_type = None | |
st.session_state.trained_model = None | |
st.session_state.model_metrics = None | |
st.session_state.scaler = None | |
st.success("✅ File uploaded successfully!") | |
st.write(f"File name: **{dataset.name}**") | |
try: | |
if dataset.name.endswith(".csv"): | |
df = pd.read_csv(dataset) | |
elif dataset.name.endswith(".xlsx"): | |
df = pd.read_excel(dataset) | |
elif dataset.name.endswith(".txt"): | |
df = pd.read_csv(dataset, delimiter="\t") | |
else: | |
st.error("Unsupported file type. Please upload a CSV, XLSX, or TXT (tab-separated) file.") | |
df = None | |
if df is not None: | |
st.session_state.updated_df = df.copy() | |
st.session_state.original_df_uploaded = True | |
st.rerun() | |
except Exception as e: | |
st.error(f"Error reading file: {e}. Please ensure it's a valid CSV, XLSX, or tab-separated TXT.") | |
st.session_state.original_df_uploaded = False | |
st.session_state.updated_df = None | |
# Original Dataset Preview | |
if st.session_state.original_df_uploaded and st.session_state.updated_df is not None: | |
st.markdown('<div class="dataset-preview">', unsafe_allow_html=True) | |
st.subheader("🔍 Original Dataset Preview") | |
st.dataframe(st.session_state.updated_df, use_container_width=True) | |
st.markdown('</div>', unsafe_allow_html=True) | |
# Updated Dataset Preview (after imputation) | |
st.markdown('<div class="dataset-preview">', unsafe_allow_html=True) | |
st.subheader("🔄 Updated Dataset Preview (After Imputation)") | |
st.dataframe(st.session_state.updated_df, use_container_width=True) | |
st.markdown('</div>', unsafe_allow_html=True) | |
with col_main_right: | |
if st.session_state.updated_df is not None: | |
st.markdown('<div class="section-title">📊 Missing Values Report</div>', unsafe_allow_html=True) | |
null_counts = st.session_state.updated_df.isnull().sum() | |
total_nulls = null_counts.sum() | |
if total_nulls == 0: | |
st.success("✅ No null values found in the dataset!") | |
else: | |
st.warning(f"⚠️ Found {total_nulls} null values in the dataset.") | |
st.write(null_counts[null_counts > 0]) | |
# Automatic Missing Value Handling | |
st.markdown('<div class="section-title">🤖 Automatic Missing Value Handling</div>', unsafe_allow_html=True) | |
with st.form("auto_impute_form"): | |
st.write("Apply default handling for all missing values:") | |
auto_impute_option = st.selectbox( | |
"Choose imputation method:", | |
["None", "Mean (Numerical)", "Median (Numerical)", "Mode (All)", "Forward Fill", "Backward Fill"], | |
key="auto_impute_method" | |
) | |
auto_impute_button = st.form_submit_button("Apply Automatic Imputation") | |
if auto_impute_button and auto_impute_option != "None": | |
df_to_impute = st.session_state.updated_df.copy() | |
if auto_impute_option == "Mean (Numerical)": | |
for col in df_to_impute.select_dtypes(include=['number']).columns: | |
if df_to_impute[col].isnull().sum() > 0: | |
df_to_impute[col].fillna(df_to_impute[col].mean(), inplace=True) | |
elif auto_impute_option == "Median (Numerical)": | |
for col in df_to_impute.select_dtypes(include=['number']).columns: | |
if df_to_impute[col].isnull().sum() > 0: | |
df_to_impute[col].fillna(df_to_impute[col].median(), inplace=True) | |
elif auto_impute_option == "Mode (All)": | |
for col in df_to_impute.columns: | |
if df_to_impute[col].isnull().sum() > 0: | |
if not df_to_impute[col].mode().empty: | |
df_to_impute[col].fillna(df_to_impute[col].mode()[0], inplace=True) | |
else: | |
st.warning(f"Could not compute mode for column '{col}'. Skipping.") | |
elif auto_impute_option == "Forward Fill": | |
df_to_impute.fillna(method='ffill', inplace=True) | |
elif auto_impute_option == "Backward Fill": | |
df_to_impute.fillna(method='bfill', inplace=True) | |
st.session_state.updated_df = df_to_impute | |
st.success(f"🎉 Missing values have been handled automatically using **{auto_impute_option}**!") | |
st.rerun() | |
# Manual Missing Value Handling | |
st.markdown('<div class="section-title">🛠️ Manual Missing Value Handling</div>', unsafe_allow_html=True) | |
cols_with_missing = st.session_state.updated_df.columns[st.session_state.updated_df.isnull().any()].tolist() | |
if cols_with_missing: | |
selected_col_manual = st.selectbox( | |
"Select a column to manually handle missing values:", | |
["--- Select a Column ---"] + cols_with_missing, | |
key="manual_col_select" | |
) | |
if selected_col_manual != "--- Select a Column ---": | |
col_dtype = st.session_state.updated_df[selected_col_manual].dtype | |
num_missing = st.session_state.updated_df[selected_col_manual].isnull().sum() | |
st.write(f"Column: **{selected_col_manual}** (Missing values: **{num_missing}**)") | |
with st.form(key=f"manual_impute_form_{selected_col_manual}"): | |
fill_value_to_apply = None | |
if col_dtype == "object": | |
manual_fill_option = st.selectbox( | |
f"Choose a method for '{selected_col_manual}'", | |
["Mode", "Fill with custom value"], | |
key=f"cat_method_{selected_col_manual}" | |
) | |
if manual_fill_option == "Fill with custom value": | |
fill_value_to_apply = st.text_input(f"Enter the custom value to fill for '{selected_col_manual}'", key=f"cat_value_{selected_col_manual}") | |
elif manual_fill_option == "Mode": | |
if not st.session_state.updated_df[selected_col_manual].mode().empty: | |
fill_value_to_apply = st.session_state.updated_df[selected_col_manual].mode()[0] | |
else: | |
st.warning(f"Mode cannot be calculated for {selected_col_manual}. Please enter a custom value.") | |
else: | |
manual_fill_option = st.selectbox( | |
f"Choose a method for '{selected_col_manual}'", | |
["Mean", "Median", "Mode", "Fill with custom value"], | |
key=f"num_method_{selected_col_manual}" | |
) | |
if manual_fill_option == "Fill with custom value": | |
fill_value_to_apply = st.number_input(f"Enter the custom value to fill for '{selected_col_manual}'", value=0.0, key=f"num_value_{selected_col_manual}") | |
elif manual_fill_option == "Mean": | |
fill_value_to_apply = st.session_state.updated_df[selected_col_manual].mean() | |
elif manual_fill_option == "Median": | |
fill_value_to_apply = st.session_state.updated_df[selected_col_manual].median() | |
elif manual_fill_option == "Mode": | |
if not st.session_state.updated_df[selected_col_manual].mode().empty: | |
fill_value_to_apply = st.session_state.updated_df[selected_col_manual].mode()[0] | |
else: | |
st.warning(f"Mode cannot be calculated for {selected_col_manual}. Please enter a custom value.") | |
submit_button = st.form_submit_button(f"Apply Manual Imputation to {selected_col_manual}") | |
if submit_button and fill_value_to_apply is not None: | |
st.session_state.updated_df[selected_col_manual].fillna(fill_value_to_apply, inplace=True) | |
st.success(f"Filled '{selected_col_manual}' missing values with **'{fill_value_to_apply}'** using {manual_fill_option}!") | |
st.rerun() | |
else: | |
st.info("No columns with missing values to display for manual handling.") | |
# Pair Plot button is now below the missing values report | |
st.markdown("---") | |
if st.button("📈 Generate Pair Plot of Numerical Columns"): | |
if st.session_state.updated_df is not None: | |
numerical_data = st.session_state.updated_df.select_dtypes(include=['float64', 'int64']) | |
if not numerical_data.empty: | |
st.markdown("##### 📘 Pair Plot - Seaborn (Static)", unsafe_allow_html=True) | |
fig1 = sns.pairplot(numerical_data) | |
st.pyplot(fig1) | |
plt.clf() | |
st.markdown("##### 🧠 Pair Plot - Plotly (Interactive)", unsafe_allow_html=True) | |
fig2 = px.scatter_matrix(numerical_data, | |
dimensions=numerical_data.columns, | |
height=800, width=800) | |
st.plotly_chart(fig2, use_container_width=True) | |
else: | |
st.warning("No numerical columns found to generate a pair plot.") | |
else: | |
st.warning("Please upload and process a dataset first.") | |
# --- Machine Learning Operations Section (Full Width, below the two columns) --- | |
st.markdown("---") | |
st.markdown("<h2 style='text-align: center; color: #4A90E2;'>🧠 Machine Learning Operations</h2>", unsafe_allow_html=True) | |
if st.session_state.updated_df is not None and st.session_state.trained_model is not None: | |
st.markdown(f"### Model Training Results for **{st.session_state.selected_algo_name}**") | |
if st.session_state.model_metrics: | |
if st.session_state.problem_type == 'classification': | |
st.markdown("#### Classification Metrics:") | |
col_m1, col_m2, col_m3, col_m4 = st.columns(4) | |
with col_m1: | |
st.metric(label="Accuracy", value=f"{st.session_state.model_metrics['Accuracy']:.4f}") | |
with col_m2: | |
st.metric(label="Precision", value=f"{st.session_state.model_metrics['Precision']:.4f}") | |
with col_m3: | |
st.metric(label="Recall", value=f"{st.session_state.model_metrics['Recall']:.4f}") | |
with col_m4: | |
st.metric(label="F1 Score", value=f"{st.session_state.model_metrics['F1 Score']:.4f}") | |
st.markdown("#### Confusion Matrix:") | |
fig_cm, ax_cm = plt.subplots(figsize=(6, 5)) | |
sns.heatmap(st.session_state.model_metrics['Confusion Matrix'], annot=True, fmt='d', cmap='Blues', ax=ax_cm) | |
ax_cm.set_xlabel('Predicted') | |
ax_cm.set_ylabel('True') | |
ax_cm.set_title('Confusion Matrix') | |
st.pyplot(fig_cm) | |
plt.clf() | |
elif st.session_state.problem_type == 'regression': | |
st.markdown("#### Regression Metrics:") | |
col_r1, col_r2 = st.columns(2) | |
with col_r1: | |
st.metric(label="Mean Squared Error", value=f"{st.session_state.model_metrics['Mean Squared Error']:.4f}") | |
with col_r2: | |
st.metric(label="R2 Score", value=f"{st.session_state.model_metrics['R2 Score']:.4f}") | |
st.markdown("---") | |
# --- Test Your Own Values and Download Model --- | |
col_test, col_download = st.columns(2) | |
with col_test: | |
st.markdown("### 🧪 Test with Your Own Values") | |
if st.session_state.feature_columns: | |
input_data = {} | |
for col in st.session_state.feature_columns: | |
if st.session_state.updated_df[col].dtype == 'object': | |
unique_vals = st.session_state.updated_df[col].unique() | |
input_data[col] = st.selectbox(f"Select value for **{col}**", unique_vals) | |
else: | |
input_data[col] = st.number_input(f"Enter value for **{col}**", value=float(st.session_state.updated_df[col].mean())) | |
if st.button("Get Prediction"): | |
input_df = pd.DataFrame([input_data]) | |
# Preprocess the input data similarly to the training data | |
for col in input_df.select_dtypes(include=['object', 'bool']).columns: | |
le = LabelEncoder() | |
input_df[col] = le.fit_transform(input_df[col].astype(str)) | |
if st.session_state.scaler: | |
numerical_cols = input_df.select_dtypes(include=['number']).columns | |
if not numerical_cols.empty: | |
input_df[numerical_cols] = st.session_state.scaler.transform(input_df[numerical_cols]) | |
prediction = st.session_state.trained_model.predict(input_df) | |
st.success(f"**Prediction:** {prediction[0]}") | |
with col_download: | |
st.markdown("### 📥 Download Trained Model") | |
# Serialize the model for download | |
model_pkl = pickle.dumps(st.session_state.trained_model) | |
b64 = base64.b64encode(model_pkl).decode() | |
st.download_button( | |
label="Download Model as .pkl", | |
data=base64.b64decode(b64), | |
file_name=f"{st.session_state.selected_algo_name}_model.pkl", | |
mime="application/octet-stream" | |
) | |
else: | |
st.info("Upload a dataset and train a model to see results and test your own values.") | |
if st.session_state.current_page == "viz_ai_img": | |
viz_ai_img.analyze_image_ui() | |
elif st.session_state.current_page == "word_cloud": | |
# Make sure to import your word_cloud module if you have it | |
word_cloud.render_word_cloud_page() | |
elif st.session_state.current_page == "note_lite": | |
notepad_lite.render_notepad() | |
elif st.session_state.current_page == "calculator": | |
calculator.render_calculator() | |
elif st.session_state.current_page == "generate_report": | |
# Make sure to import your viz_report module if you have it | |
# viz_report.generate_report() | |
#viz_report.render_report_page() | |
st.write("Viz Report Page (Implement logic here)") | |
# Add custom CSS for better styling | |
st.markdown(""" | |
<style> | |
.stButton>button { | |
width: 100%; | |
border-radius: 5px; | |
border: 1px solid #4A90E2; | |
color: #4A90E2; | |
background-color: white; | |
padding: 10px; | |
font-size: 16px; | |
transition: all 0.2s ease-in-out; | |
} | |
.stButton>button:hover { | |
background-color: #4A90E2; | |
color: white; | |
} | |
.section-title { | |
color: #4A90E2; | |
font-size: 18px; | |
margin-top: 15px; | |
margin-bottom: 10px; | |
font-weight: bold; | |
} | |
.dataset-preview { | |
border: 1px solid #ddd; | |
border-radius: 5px; | |
padding: 10px; | |
margin-top: 20px; | |
background-color: #f9f9f9; | |
} | |
h1 { | |
color: #4A90E2; | |
} | |
h2 { | |
color: #4A90E2; | |
} | |
h3 { | |
color: #333; | |
} | |
h4 { | |
color: #555; | |
} | |
.st-emotion-cache-1jmvejs { # Targeting expander header for slightly different styling | |
background-color: #f0f2f6; | |
border-radius: 5px; | |
} | |
</style> | |
""", unsafe_allow_html=True) | |
st.markdown(""" | |
<div style="position: fixed; bottom: 0; left: 0; width: 100%; text-align: center; background-color: ; padding: 10px;"> | |
<p style="font-size: 12px;">Made with ❤️ by <a href = "https://avarshvir.github.io/arshvir">Arshvir</a> and <a href = "https://jaiho-labs.onrender.com">Jaiho Labs</a></p> | |
</div> | |
""", unsafe_allow_html=True) |