import streamlit as st import pandas as pd import numpy as np import re import matplotlib.pyplot as plt import seaborn as sns import nltk from nltk.corpus import stopwords from nltk.stem.snowball import SnowballStemmer import pickle from transformers import pipeline as hf_pipeline from sklearn.utils.multiclass import type_of_target import io import base64 from sklearn.model_selection import train_test_split from sklearn.feature_extraction.text import TfidfVectorizer from sklearn.pipeline import Pipeline from sklearn.multiclass import OneVsRestClassifier from sklearn.linear_model import LogisticRegression from sklearn.naive_bayes import MultinomialNB from sklearn.metrics import roc_auc_score, accuracy_score, classification_report from textblob import TextBlob import warnings from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix from sklearn.metrics import roc_curve warnings.filterwarnings('ignore') # Download required NLTK resources try: nltk.data.find('corpora/stopwords') except LookupError: nltk.download('stopwords') # Initialize the stemmer stemmer = SnowballStemmer('english') stop_words_set = set(stopwords.words('english')) # Text preprocessing functions def remove_stopwords(text): return " ".join([word for word in str(text).split() if word.lower() not in stop_words_set]) def train_lightweight_model(data, text_column, label_column): """ Train a lightweight model for toxicity detection Args: data: DataFrame containing the data text_column: Column name for text data label_column: Column name for label data Returns: Trained model and vectorizer """ from sklearn.feature_extraction.text import TfidfVectorizer from sklearn.linear_model import LogisticRegression from sklearn.pipeline import Pipeline # Preprocess text data['processed_text'] = data[text_column].apply(preprocess_text) # Create a pipeline with TF-IDF and Logistic Regression model = Pipeline([ ('tfidf', TfidfVectorizer(max_features=5000, ngram_range=(1, 2))), ('clf', LogisticRegression(random_state=42, max_iter=1000)) ]) # Train the model model.fit(data['processed_text'], data[label_column]) return model def load_bert_model(): """ Load a pre-trained BERT model for sentiment analysis Returns: Loaded model """ try: # Load sentiment analysis model from HuggingFace sentiment_analyzer = hf_pipeline("sentiment-analysis") st.success("BERT model loaded successfully!") return sentiment_analyzer except Exception as e: st.error(f"Error loading BERT model: {e}") return None def clean_text(text): text = str(text).lower() text = re.sub(r"what's", "what is ", text) text = re.sub(r"\'s", " ", text) text = re.sub(r"\'ve", " have ", text) text = re.sub(r"can't", "can not ", text) text = re.sub(r"n't", " not ", text) text = re.sub(r"i'm", "i am ", text) text = re.sub(r"\'re", " are ", text) text = re.sub(r"\'d", " would ", text) text = re.sub(r"\'ll", " will ", text) text = re.sub(r"\'scuse", " excuse ", text) text = re.sub(r'\W', ' ', text) # Remove non-word characters text = re.sub(r'\s+', ' ', text).strip() # Remove extra spaces return text def stemming(sentence): return " ".join([stemmer.stem(word) for word in str(sentence).split()]) def preprocess_text(text): text = remove_stopwords(text) text = clean_text(text) text = stemming(text) return text # Function to get sentiment def get_sentiment(text): score = TextBlob(text).sentiment.polarity if score > 0: return "Positive", score elif score < 0: return "Negative", score else: return "Neutral", score # Function to moderate text based on toxicity def moderate_text(text, predictions, threshold_moderate=0.5, threshold_delete=0.8): # If binary classification, only check the 'toxic' probability (index 1) if len(predictions) == 2: toxic_prob = predictions[1] if toxic_prob >= threshold_delete: return "*** COMMENT DELETED DUE TO HIGH TOXICITY ***", "delete" elif toxic_prob >= threshold_moderate: # List of potentially toxic words to censor toxic_words = ["stupid", "idiot", "dumb", "hate", "sucks", "terrible", "awful", "garbage", "trash", "pathetic", "ridiculous"] words = text.split() moderated_words = [] for word in words: # Clean word for comparison clean_word = re.sub(r'[^\w\s]', '', word.lower()) # Check if the word is in the toxic words list if clean_word in toxic_words: # Replace with a more neutral placeholder moderated_words.append("[inappropriate]") else: moderated_words.append(word) return " ".join(moderated_words), "moderate" else: return text, "keep" else: # Multi-label: check all classes if any(pred >= threshold_delete for pred in predictions): return "*** COMMENT DELETED DUE TO HIGH TOXICITY ***", "delete" elif any(pred >= threshold_moderate for pred in predictions): # List of potentially toxic words to censor toxic_words = ["stupid", "idiot", "dumb", "hate", "sucks", "terrible", "awful", "garbage", "trash", "pathetic", "ridiculous"] words = text.split() moderated_words = [] for word in words: # Clean word for comparison clean_word = re.sub(r'[^\w\s]', '', word.lower()) # Check if the word is in the toxic words list if clean_word in toxic_words: # Replace with a more neutral placeholder moderated_words.append("[inappropriate]") else: moderated_words.append(word) return " ".join(moderated_words), "moderate" else: return text, "keep" # Function to train and save the model def train_model(X_train, y_train, model_type='logistic_regression'): st.write("Training model...") # Ensure `y_train` has 6 columns label_columns = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate'] # Create missing columns if they don't exist for col in label_columns: if col not in y_train.columns: y_train[col] = 0 # Ensure columns are in the right order y_train = y_train[label_columns] if model_type == 'logistic_regression': pipeline = Pipeline([ ('tfidf', TfidfVectorizer(stop_words='english', max_features=50000)), ('clf', OneVsRestClassifier(LogisticRegression(max_iter=1000), n_jobs=-1)) ]) else: # Naive Bayes pipeline = Pipeline([ ('tfidf', TfidfVectorizer(stop_words='english', max_features=50000)), ('clf', OneVsRestClassifier(MultinomialNB(), n_jobs=-1)) ]) pipeline.fit(X_train, y_train) return pipeline def evaluate_model(pipeline, X_test, y_test): """ Evaluates the given trained pipeline on test data. Returns: accuracy: Accuracy score roc_auc: ROC AUC score predictions: Predicted labels pred_probs: Predicted probabilities fpr: False Positive Rate array (for ROC curve) tpr: True Positive Rate array (for ROC curve) """ # Get predictions and prediction probabilities predictions = pipeline.predict(X_test) pred_probs = pipeline.predict_proba(X_test) if isinstance(pred_probs, list) and len(pred_probs) == 1: pred_probs = pred_probs[0] # Handle list with 1 array # Ensure predictions format matches y_test y_type = type_of_target(y_test) pred_type = type_of_target(predictions) if y_type != pred_type: if y_type == "multilabel-indicator" and pred_type == "binary": # Expand binary predictions to multilabel shape predictions = np.array([[pred] * y_test.shape[1] for pred in predictions]) elif y_type == "binary" and pred_type == "multilabel-indicator": # Collapse multilabel predictions to binary predictions = predictions[:, 0] # Calculate accuracy accuracy = accuracy_score(y_test, predictions) # Calculate ROC AUC try: if len(y_test.shape) > 1 and y_test.shape[1] > 1: # Multi-label case roc_auc_sum = 0 valid_labels = 0 for i in range(y_test.shape[1]): try: roc_auc_sum += roc_auc_score(y_test.iloc[:, i], pred_probs[:, i]) valid_labels += 1 except Exception: continue roc_auc = roc_auc_sum / valid_labels if valid_labels > 0 else 0.0 else: # Binary case roc_auc = roc_auc_score(y_test, pred_probs[:, 1] if pred_probs.ndim > 1 and pred_probs.shape[1] > 1 else pred_probs) except Exception as e: print(f"Warning: Could not compute ROC AUC - {e}") roc_auc = 0.0 # Calculate FPR, TPR for ROC curve (only for binary classification) try: if len(y_test.shape) == 1 or (len(y_test.shape) > 1 and y_test.shape[1] == 1): fpr, tpr, _ = roc_curve( y_test, pred_probs[:, 1] if pred_probs.ndim > 1 and pred_probs.shape[1] > 1 else pred_probs ) else: fpr, tpr = None, None # ROC Curve not available for multilabel except Exception as e: print(f"Warning: Could not compute ROC Curve - {e}") fpr, tpr = None, None return accuracy, roc_auc, predictions, pred_probs, fpr, tpr # Function to create a download link for the trained model def get_model_download_link(model, filename): model_bytes = pickle.dumps(model) b64 = base64.b64encode(model_bytes).decode() href = f'Download Trained Model' return href # Function to plot toxicity distribution def plot_toxicity_distribution(df, toxicity_columns): fig, ax = plt.subplots(figsize=(10, 6)) x = df[toxicity_columns].sum() sns.barplot(x=x.index, y=x.values, alpha=0.8, palette='viridis', ax=ax) plt.title('Toxicity Distribution') plt.ylabel('Count') plt.xlabel('Toxicity Category') plt.xticks(rotation=45) return fig # Function to provide sample data format def show_sample_data_format(): st.subheader("Sample Data Format") # Create sample dataframe sample_data = { 'comment_text': [ "This is a normal comment.", "This is a toxic comment you idiot!", "You're all worthless and should die.", "I respectfully disagree with your point." ], 'toxic': [0, 1, 1, 0], 'severe_toxic': [0, 0, 1, 0], 'obscene': [0, 1, 0, 0], 'threat': [0, 0, 1, 0], 'insult': [0, 1, 1, 0], 'identity_hate': [0, 0, 0, 0] } sample_df = pd.DataFrame(sample_data) st.dataframe(sample_df) # Create download link for sample data csv = sample_df.to_csv(index=False) b64 = base64.b64encode(csv.encode()).decode() href = f'Download Sample CSV' st.markdown(href, unsafe_allow_html=True) st.info(""" Your CSV file should contain: 1. A column with comment text 2. One or more columns with binary values (0 or 1) for each toxicity category """) # Function to validate dataset def validate_dataset(df, comment_column, toxicity_columns): issues = [] # Check if comment column exists if comment_column not in df.columns: issues.append(f"Comment column '{comment_column}' not found in the dataset") # Check if toxicity columns exist missing_columns = [col for col in toxicity_columns if col not in df.columns] if missing_columns: issues.append(f"Missing toxicity columns: {', '.join(missing_columns)}") # Check if values in toxicity columns are valid (0 or 1) for col in toxicity_columns: if col in df.columns: # Check for non-numeric values if not pd.api.types.is_numeric_dtype(df[col]): issues.append(f"Column '{col}' contains non-numeric values") else: # Check for values other than 0 and 1 invalid_values = df[col].dropna().apply(lambda x: x not in [0, 1, 0.0, 1.0]) if invalid_values.any(): issues.append(f"Column '{col}' contains values other than 0 and 1") # Check for empty data if df.empty: issues.append("Dataset is empty") elif df[comment_column].isna().all(): issues.append("Comment column contains no data") return issues # Function to extract predictions from model output def extract_predictions(predictions_proba, toxicity_categories): """ Helper function to extract probabilities from model output, handling different output formats. """ # Debug information if st.session_state.debug_mode: st.write(f"Predictions type: {type(predictions_proba)}") st.write( f"Predictions shape/length: {np.shape(predictions_proba) if hasattr(predictions_proba, 'shape') else len(predictions_proba)}") # Case 1: List of arrays with one element per toxicity category if isinstance(predictions_proba, list) and len(predictions_proba) == len(toxicity_categories): return [pred_array[:, 1][0] if pred_array.shape[1] > 1 else pred_array[0] for pred_array in predictions_proba] # Case 2: List with a single array (common for OneVsRestClassifier) elif isinstance(predictions_proba, list) and len(predictions_proba) == 1: pred_array = predictions_proba[0] # If it's a 2D array with number of columns equal to number of categories if len(pred_array.shape) == 2 and pred_array.shape[1] == len(toxicity_categories): return pred_array[0] # Return first row, which contains all probabilities # If it's a 2D array with 2 columns per category (common binary classifier output) elif len(pred_array.shape) == 2 and pred_array.shape[1] == 2: return np.array([pred_array[0, 1]]) # Case 3: Direct numpy array elif isinstance(predictions_proba, np.ndarray): # If it's already the right shape if len(predictions_proba.shape) == 2 and predictions_proba.shape[1] == len(toxicity_categories): return predictions_proba[0] # If it's a 2D array with two columns (binary classification) elif len(predictions_proba.shape) == 2 and predictions_proba.shape[1] == 2: # For binary classification, return the probability of positive class return np.array([predictions_proba[0, 1]]) # If prediction format isn't recognized, return a repeated array of single probability # This handles the case where we only have one prediction but need to repeat it if isinstance(predictions_proba, list) and len(predictions_proba) == 1: single_prob = predictions_proba[0] if hasattr(single_prob, 'shape') and len(single_prob.shape) == 2 and single_prob.shape[1] == 2: # Take positive class probability and repeat for all categories return np.full(len(toxicity_categories), single_prob[0, 1]) # Last resort fallback st.warning(f"Unexpected prediction format. Creating default predictions.") return np.zeros(len(toxicity_categories)) def display_classification_result(result): st.subheader("Classification Result") # Show original and moderated text side by side col1, col2 = st.columns(2) with col1: st.markdown("**Original Text**") st.code(result["original_text"], language="text") with col2: st.markdown("**Moderated Text**") st.code(result["moderated_text"], language="text") # Show action with color action = result["action"] if action == "keep": st.success("✅ This comment is allowed (Non-toxic).") elif action == "moderate": st.warning("⚠️ This comment is moderated (Potentially toxic).") elif action == "delete": st.error("🚫 This comment is deleted (Highly toxic).") # Show toxicity scores st.markdown("**Toxicity Scores:**") score_cols = st.columns(len(result["toxicity_scores"])) for i, (label, score) in enumerate(result["toxicity_scores"].items()): score_cols[i].metric(label.capitalize(), f"{score:.2%}") # Show sentiment if available if "sentiment" in result: st.markdown("**Sentiment Analysis:**") st.info(f"{result['sentiment']['label']} (score: {result['sentiment']['score']:.2%})") def moderate_comment(comment, model, sentiment_model=None): """ Moderate a single comment using the trained model and optionally BERT sentiment analysis. Args: comment: The comment text to moderate model: The trained model to use for toxicity detection sentiment_model: Optional BERT model for sentiment analysis Returns: Dictionary containing moderation results """ # Preprocess the comment processed_text = preprocess_text(comment) # Get model predictions predictions = model.predict_proba([processed_text])[0] # Get sentiment if BERT model is available sentiment = None if sentiment_model: sentiment = sentiment_model(comment)[0] # Moderate the text moderated_text, action = moderate_text(comment, predictions) # Prepare result result = { "original_text": comment, "moderated_text": moderated_text, "action": action, "toxicity_scores": {} } # Handle both binary and multi-class predictions if len(predictions) == 2: # Binary classification result["toxicity_scores"] = { "toxic": float(predictions[1]), # Probability of positive class "non_toxic": float(predictions[0]) # Probability of negative class } else: # Multi-class classification # Define the toxicity categories categories = ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"] # Add scores for each category that exists in the predictions for i, category in enumerate(categories): if i < len(predictions): result["toxicity_scores"][category] = float(predictions[i]) if sentiment: result["sentiment"] = { "label": sentiment["label"], "score": float(sentiment["score"]) } return result # --- Bias Detection Module --- def detect_subgroup(text): gender_keywords = ["he", "she", "him", "her", "man", "woman", "boy", "girl", "male", "female"] ethnicity_keywords = [ "asian", "black", "white", "hispanic", "latino", "indian", "african", "european", "arab", "jewish", "muslim", "christian" ] text_lower = text.lower() subgroups = set() if any(word in text_lower for word in gender_keywords): subgroups.add("gender") if any(word in text_lower for word in ethnicity_keywords): subgroups.add("ethnicity") return list(subgroups) def bias_report(X, y_true, y_pred, text_column_name): # X: DataFrame with text column # y_pred: predicted labels (same shape as y_true) # y_true: true labels (same shape as y_pred) # text_column_name: name of the text column results = [] for idx, row in X.iterrows(): subgroups = detect_subgroup(row[text_column_name]) if subgroups: for subgroup in subgroups: results.append({ "subgroup": subgroup, "is_toxic": int(y_pred[idx].sum() > 0) if len(y_pred.shape) > 1 else int(y_pred[idx] > 0) }) if not results: return "No sensitive subgroups detected in the evaluation set." df = pd.DataFrame(results) report = "" for subgroup in df["subgroup"].unique(): total = (df["subgroup"] == subgroup).sum() toxic = df[(df["subgroup"] == subgroup) & (df["is_toxic"] == 1)].shape[0] rate = toxic / total if total > 0 else 0 report += f"- **{subgroup.capitalize()}**: {toxic}/{total} ({rate:.1%}) flagged as toxic\n" return report # Streamlit app def main(): st.set_page_config( page_title="Toxic Comment Classifier", page_icon="🧊", layout="wide", initial_sidebar_state="expanded", ) col1, col2 = st.columns([1, 4]) # Adjust the ratio as needed with col1: st.image("logo.jpeg", width=100) # Smaller width fits better with col2: st.title("Toxic Comment Classifier") # Initialize session state variables if they don't exist if 'data' not in st.session_state: st.session_state.data = None if 'model' not in st.session_state: st.session_state.model = None if 'vectorizer' not in st.session_state: st.session_state.vectorizer = None if 'predictions' not in st.session_state: st.session_state.predictions = None if 'lightweight_model' not in st.session_state: st.session_state.lightweight_model = None if 'bert_model' not in st.session_state: st.session_state.bert_model = None # Create sidebar navigation st.sidebar.title("Navigation") page = st.sidebar.radio( "Select a page", ["Home", "Data Preprocessing", "Model Training", "Model Evaluation", "Prediction", "Visualization"] ) # Home page if page == "Home": st.header("Home") st.write(""" Welcome to the Toxic Comment Classifier application. This tool helps you to: 1. Upload and preprocess data 2. Train a machine learning model to detect toxic comments 3. Evaluate model performance 4. Make predictions on new data 5. Visualize results Please use the sidebar navigation to get started. """) # Add option to load BERT sentiment model if st.sidebar.checkbox("Use BERT for Sentiment Analysis"): st.subheader("BERT-Based Sentiment Analysis") st.write("This option uses a pre-trained BERT model for advanced sentiment analysis.") if st.button("Load BERT Model"): with st.spinner("Loading BERT model..."): st.session_state.bert_model = load_bert_model() st.write("DEBUG: bert_model in session_state after loading:", st.session_state.bert_model) # Sample data section st.subheader("Sample Data Format") show_sample_data_format() # Single comment moderation st.subheader("Try Comment Moderation") comment = st.text_area("Enter a comment to moderate:") col1, col2 = st.columns(2) with col1: use_default_model = st.checkbox("Use built-in model for demo", value=True) with col2: use_bert = st.checkbox("Use BERT model for sentiment (if loaded)", value=False) if st.button("Moderate Comment"): if comment: with st.spinner("Analyzing comment..."): st.write("DEBUG: bert_model in session_state before use:", st.session_state.bert_model) sentiment_model = st.session_state.bert_model if use_bert and st.session_state.bert_model is not None else None if use_default_model or st.session_state.model or st.session_state.lightweight_model: model_to_use = None if st.session_state.model: model_to_use = st.session_state.model elif st.session_state.lightweight_model: model_to_use = st.session_state.lightweight_model result = moderate_comment(comment, model_to_use, sentiment_model) display_classification_result(result) else: st.error("No model available. Please train a model first or enable the demo model.") else: st.warning("Please enter a comment to moderate.") # Data Preprocessing page elif page == "Data Preprocessing": st.header("Data Preprocessing") # File upload st.subheader("Upload Dataset") uploaded_file = st.file_uploader("Choose a CSV file", type="csv") if uploaded_file is not None: try: # Load data data = pd.read_csv(uploaded_file) # Display raw data st.subheader("Raw Data") st.write(data.head()) # Validate the dataset validation_result = validate_dataset(data, 'comment_text', ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']) if not validation_result: # Empty list means no issues st.success("Dataset is valid!") # Store data in session state st.session_state.data = data # Data cleaning st.subheader("Data Cleaning") st.write("Select columns to include in the analysis:") # Get all columns all_columns = data.columns.tolist() default_selected = ["comment_text", "toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"] default_selected = [col for col in default_selected if col in all_columns] selected_columns = st.multiselect( "Select columns", options=all_columns, default=default_selected ) if selected_columns: # Filter data by selected columns filtered_data = data[selected_columns] # Display cleaned data st.subheader("Filtered Data") st.write(filtered_data.head()) # Display data statistics st.subheader("Data Statistics") st.write(filtered_data.describe()) # Check for missing values st.subheader("Missing Values") missing_values = filtered_data.isnull().sum() st.write(missing_values) # Handle missing values if any if missing_values.sum() > 0: st.warning("There are missing values in the dataset.") if st.button("Handle Missing Values"): # Fill missing text with empty string text_columns = [col for col in selected_columns if filtered_data[col].dtype == 'object'] for col in text_columns: filtered_data[col] = filtered_data[col].fillna("") # Fill missing numerical values with 0 numerical_columns = [col for col in selected_columns if filtered_data[col].dtype != 'object'] for col in numerical_columns: filtered_data[col] = filtered_data[col].fillna(0) st.success("Missing values handled!") st.write(filtered_data.isnull().sum()) # Text preprocessing st.subheader("Text Preprocessing") # Select text column text_columns = [col for col in selected_columns if filtered_data[col].dtype == 'object'] if text_columns: text_column = st.selectbox("Select text column for preprocessing", text_columns) # Show sample of original text st.write("Sample original text:") sample_texts = filtered_data[text_column].head(3).tolist() for i, text in enumerate(sample_texts): st.text(f"Text {i + 1}: {text[:200]}...") # Preprocess text if st.button("Preprocess Text"): with st.spinner("Preprocessing text..."): filtered_data['processed_text'] = filtered_data[text_column].apply(preprocess_text) # Show sample of preprocessed text st.write("Sample preprocessed text:") sample_preprocessed = filtered_data['processed_text'].head(3).tolist() for i, text in enumerate(sample_preprocessed): st.text(f"Processed Text {i + 1}: {text[:200]}...") # Store preprocessed data st.session_state.data = filtered_data st.success("Text preprocessing completed!") else: st.warning("No text columns found in the selected columns.") else: st.warning("Please select at least one column.") else: st.error(f"Dataset validation failed: {validation_result['reason']}") st.warning("Please upload a valid dataset.") except Exception as e: st.error(f"Error loading data: {e}") st.warning("Please upload a valid CSV file.") else: st.info("Please upload a CSV file to begin preprocessing.") # Model Training page elif page == "Model Training": st.header("Model Training") # Check if data is available if st.session_state.data is not None: # Display data info st.subheader("Dataset Information") st.write(f"Number of samples: {len(st.session_state.data)}") if 'processed_text' in st.session_state.data.columns: st.write("Text preprocessing: Done") else: st.warning("Text preprocessing is not done. Please preprocess the data first.") # Model training options st.subheader("Training Options") # Select target column numerical_columns = [col for col in st.session_state.data.columns if st.session_state.data[col].dtype != 'object'] if numerical_columns: target_column = st.selectbox("Select target column", numerical_columns) # Set training parameters st.write("Training Parameters:") col1, col2 = st.columns(2) with col1: test_size = st.slider("Test Size", 0.1, 0.5, 0.2, 0.05) with col2: random_state = st.number_input("Random State", 0, 100, 42, 1) # Model selection model_type = st.radio( "Select model type", ["Standard Model", "Lightweight Model"] ) # Train model button if st.button("Train Model"): with st.spinner("Training model..."): # Check if processed text is available if 'processed_text' in st.session_state.data.columns: try: if model_type == "Standard Model": # Train standard model X_train = st.session_state.data['processed_text'] y_train = st.session_state.data[[target_column]] model = train_model(X_train, y_train, 'logistic_regression') # Store model in session state st.session_state.model = model st.session_state.vectorizer = None # Vectorizer is part of the pipeline st.success("Model training completed!") else: # Train lightweight model lightweight_model = train_lightweight_model( st.session_state.data, 'processed_text', target_column ) # Store lightweight model in session state st.session_state.lightweight_model = lightweight_model st.success("Lightweight model training completed!") except Exception as e: st.error(f"Error training model: {e}") else: st.error("Processed text not found. Please preprocess the data first.") else: st.warning("No numerical columns found in the dataset. Please ensure you have target columns.") else: st.info("Please upload and preprocess data before training a model.") # Model Evaluation page elif page == "Model Evaluation": st.header("Model Evaluation") # Check if model is available model_available = st.session_state.model is not None or st.session_state.lightweight_model is not None if model_available: # Display model info st.subheader("Model Information") if st.session_state.model is not None: st.write("Standard model is trained and ready.") if st.session_state.lightweight_model is not None: st.write("Lightweight model is trained and ready.") # Select model to evaluate model_choice = None if st.session_state.model is not None and st.session_state.lightweight_model is not None: model_choice = st.radio( "Select model to evaluate", ["Standard Model", "Lightweight Model"] ) # Model evaluation st.subheader("Evaluation Options") # Set evaluation parameters st.write("Evaluation Parameters:") col1, col2 = st.columns(2) with col1: test_size = st.slider("Test Size (Evaluation)", 0.1, 0.5, 0.2, 0.05) with col2: random_state = st.number_input("Random State (Evaluation)", 0, 100, 42, 1) # Target column selection if st.session_state.data is not None: numerical_columns = [col for col in st.session_state.data.columns if st.session_state.data[col].dtype != 'object'] if numerical_columns: target_column = st.selectbox("Select target column for evaluation", numerical_columns) # Evaluate model button if st.button("Evaluate Model"): with st.spinner("Evaluating model..."): try: # Determine which model to evaluate model_to_evaluate = None if model_choice == "Lightweight Model" or ( model_choice is None and st.session_state.model is None): model_to_evaluate = st.session_state.lightweight_model else: model_to_evaluate = st.session_state.model # 1️⃣ Evaluate model X_test = st.session_state.data['processed_text'] y_test = st.session_state.data[[target_column]] accuracy, roc_auc, predictions, pred_probs, fpr, tpr = evaluate_model(model_to_evaluate, X_test, y_test) # 2️⃣ Calculate additional metrics precision = precision_score(y_test, predictions, average='weighted', zero_division=0) recall = recall_score(y_test, predictions, average='weighted', zero_division=0) f1 = f1_score(y_test, predictions, average='weighted', zero_division=0) conf_matrix = confusion_matrix(y_test, predictions) classification_rep = classification_report(y_test, predictions, zero_division=0) # 3️⃣ Display evaluation metrics st.subheader("Evaluation Results") metrics_df = pd.DataFrame({ 'Metric': ['Accuracy', 'Precision', 'Recall', 'F1 Score', 'ROC AUC'], 'Value': [accuracy, precision, recall, f1, roc_auc] }) st.table(metrics_df) # 4️⃣ Confusion Matrix st.subheader("Confusion Matrix") fig, ax = plt.subplots(figsize=(8, 6)) sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues', ax=ax, cbar=False, annot_kws={"size": 16}) plt.xlabel('Predicted') plt.ylabel('Actual') plt.title('Confusion Matrix') st.pyplot(fig) # 5️⃣ ROC Curve st.subheader("ROC Curve") if fpr is not None and tpr is not None: fig, ax = plt.subplots(figsize=(8, 6)) ax.plot(fpr, tpr, label=f'ROC Curve (AUC = {roc_auc:.2f})') ax.plot([0, 1], [0, 1], 'k--') ax.set_xlabel('False Positive Rate') ax.set_ylabel('True Positive Rate') ax.set_title('ROC Curve') ax.legend(loc="lower right") st.pyplot(fig) else: st.info("ROC curve is not available for multi-label classification.") # 6️⃣ Classification Report st.subheader("Classification Report") st.text(classification_rep) # 7️⃣ Bias Detection Report st.subheader("Bias Detection Report") if 'comment_text' in st.session_state.data.columns: bias_summary = bias_report( st.session_state.data[["comment_text"]].reset_index(drop=True), y_test.reset_index(drop=True), predictions, "comment_text" ) st.markdown(bias_summary) else: st.info("No comment_text column found for bias analysis.") except Exception as e: st.error(f"Error evaluating model: {e}") else: st.warning("No numerical columns found in the dataset. Please ensure you have target columns.") else: st.warning("Dataset not available. Please upload and preprocess data first.") # Model download st.subheader("Model Download") # Create download button model_to_download = None if model_choice == "Lightweight Model" or (model_choice is None and st.session_state.model is None): model_to_download = st.session_state.lightweight_model else: model_to_download = st.session_state.model if model_to_download is not None: # Determine the appropriate filename based on model type filename = "lightweight_model.pkl" if model_choice == "Lightweight Model" or (model_choice is None and st.session_state.model is None) else "standard_model.pkl" download_link = get_model_download_link(model_to_download, filename) st.markdown(download_link, unsafe_allow_html=True) else: st.info("Please train a model before evaluation.") # Prediction page elif page == "Prediction": st.header("Prediction") # Check if model is available model_available = st.session_state.model is not None or st.session_state.lightweight_model is not None if model_available: # Display model info st.subheader("Model Information") if st.session_state.model is not None: st.write("Standard model is trained and ready.") if st.session_state.lightweight_model is not None: st.write("Lightweight model is trained and ready.") # Select model to use model_choice = None if st.session_state.model is not None and st.session_state.lightweight_model is not None: model_choice = st.radio( "Select model for prediction", ["Standard Model", "Lightweight Model"] ) # Determine which model to use model_to_use = None if model_choice == "Lightweight Model" or (model_choice is None and st.session_state.model is None): model_to_use = st.session_state.lightweight_model else: model_to_use = st.session_state.model # Prediction options st.subheader("Make Predictions") prediction_type = st.radio( "Select prediction type", ["Single Comment", "Multiple Comments"] ) # Option to use BERT model use_bert = False if st.session_state.bert_model is not None: use_bert = st.checkbox("Include sentiment analysis with BERT") # Single comment prediction if prediction_type == "Single Comment": comment = st.text_area("Enter a comment to classify:") if st.button("Classify Comment"): if comment: with st.spinner("Classifying comment..."): st.write("DEBUG: bert_model in session_state before use:", st.session_state.bert_model) sentiment_model = st.session_state.bert_model if use_bert and st.session_state.bert_model is not None else None result = moderate_comment(comment, model_to_use, sentiment_model) display_classification_result(result) else: st.warning("Please enter a comment to classify.") # Multiple comments prediction else: # File upload for prediction uploaded_file = st.file_uploader("Upload a CSV file with comments", type="csv") if uploaded_file is not None: try: # Load data pred_data = pd.read_csv(uploaded_file) # Display data st.subheader("Uploaded Data") st.write(pred_data.head()) # Select text column text_columns = [col for col in pred_data.columns if pred_data[col].dtype == 'object'] if text_columns: text_column = st.selectbox("Select text column for prediction", text_columns) # Batch prediction button if st.button("Run Batch Prediction"): with st.spinner("Classifying comments..."): try: # Preprocess text pred_data['processed_text'] = pred_data[text_column].apply(preprocess_text) # Run prediction sentiment_model = st.session_state.bert_model if use_bert else None predictions = extract_predictions(pred_data, text_column, model_to_use, sentiment_model) # Store predictions st.session_state.predictions = predictions # Display results st.subheader("Prediction Results") st.write(predictions.head()) # Summary st.subheader("Summary") toxic_count = predictions['is_toxic'].sum() total_count = len(predictions) toxic_percentage = (toxic_count / total_count) * 100 st.write(f"Total comments: {total_count}") st.write(f"Toxic comments: {toxic_count} ({toxic_percentage:.2f}%)") st.write( f"Non-toxic comments: {total_count - toxic_count} ({100 - toxic_percentage:.2f}%)") # Download predictions if not predictions.empty: csv = predictions.to_csv(index=False) b64 = base64.b64encode(csv.encode()).decode() href = f'Download Predictions CSV' st.markdown(href, unsafe_allow_html=True) except Exception as e: st.error(f"Error during prediction: {e}") else: st.warning("No text columns found in the uploaded file.") except Exception as e: st.error(f"Error loading data: {e}") st.warning("Please upload a valid CSV file.") else: st.info("Please upload a CSV file with comments for batch prediction.") else: st.info("Please train a model before making predictions.") # Visualization page elif page == "Visualization": st.header("Visualization") # Check if data is available if st.session_state.data is not None: # Data visualization st.subheader("Data Visualization") # Select visualization type viz_type = st.selectbox( "Select visualization type", ["Toxicity Distribution", "Comment Length Distribution", "Word Cloud", "Correlation Matrix"] ) # Toxicity Distribution if viz_type == "Toxicity Distribution": # Check if there are label columns label_columns = [col for col in st.session_state.data.columns if col in [ "toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate" ]] if label_columns: st.write("Toxicity Distribution:") # Plot toxicity distribution fig = plot_toxicity_distribution(st.session_state.data, label_columns) st.pyplot(fig) else: st.warning("No toxicity label columns found in the dataset.") # Comment Length Distribution elif viz_type == "Comment Length Distribution": # Check if there is text column text_columns = [col for col in st.session_state.data.columns if st.session_state.data[col].dtype == 'object'] if text_columns: text_column = st.selectbox("Select text column", text_columns) # Calculate comment lengths st.session_state.data['comment_length'] = st.session_state.data[text_column].apply( lambda x: len(str(x))) # Plot comment length distribution fig, ax = plt.subplots(figsize=(10, 6)) sns.histplot(st.session_state.data['comment_length'], bins=50, kde=True, ax=ax) plt.xlabel('Comment Length') plt.ylabel('Frequency') plt.title('Comment Length Distribution') st.pyplot(fig) # Statistics st.write("Comment Length Statistics:") st.write(st.session_state.data['comment_length'].describe()) else: st.warning("No text columns found in the dataset.") # Word Cloud elif viz_type == "Word Cloud": # Check if there is processed text if 'processed_text' in st.session_state.data.columns: try: from wordcloud import WordCloud # Create word cloud st.write("Word Cloud Visualization:") # Combine all processed text all_text = ' '.join(st.session_state.data['processed_text'].tolist()) # Generate word cloud wordcloud = WordCloud( width=800, height=400, background_color='white', max_words=200, contour_width=3, contour_color='steelblue' ).generate(all_text) # Display word cloud fig, ax = plt.subplots(figsize=(10, 6)) ax.imshow(wordcloud, interpolation='bilinear') ax.axis('off') plt.tight_layout() st.pyplot(fig) except ImportError: st.error("WordCloud package is not installed. Please install it to use this feature.") else: st.warning("Processed text not found. Please preprocess the data first.") # Correlation Matrix elif viz_type == "Correlation Matrix": # Get numerical columns numerical_columns = [col for col in st.session_state.data.columns if st.session_state.data[col].dtype != 'object'] if len(numerical_columns) > 1: # Select columns for correlation selected_columns = st.multiselect( "Select columns for correlation matrix", options=numerical_columns, default=[col for col in numerical_columns if col in [ "toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate" ]] ) if selected_columns and len(selected_columns) > 1: # Calculate correlation correlation = st.session_state.data[selected_columns].corr() # Plot correlation matrix fig, ax = plt.subplots(figsize=(10, 8)) sns.heatmap( correlation, annot=True, cmap='coolwarm', ax=ax, cbar=True, fmt='.2f', linewidths=0.5 ) plt.title('Correlation Matrix') st.pyplot(fig) else: st.warning("Please select at least two columns for correlation matrix.") else: st.warning("Not enough numerical columns for correlation analysis.") # Add more visualization types as needed # Prediction visualization if st.session_state.predictions is not None: st.subheader("Prediction Visualization") # Distribution of predictions st.write("Distribution of Predictions:") # Plot prediction distribution fig, ax = plt.subplots(figsize=(10, 6)) if 'toxicity_score' in st.session_state.predictions.columns: sns.histplot(st.session_state.predictions['toxicity_score'], bins=20, kde=True, ax=ax) plt.xlabel('Toxicity Score') plt.ylabel('Frequency') plt.title('Distribution of Toxicity Scores') st.pyplot(fig) # Toxicity threshold analysis st.write("Toxicity Threshold Analysis:") threshold = st.slider("Toxicity Threshold", 0.0, 1.0, 0.5, 0.05) # Calculate metrics at different thresholds st.session_state.predictions['is_toxic_at_threshold'] = st.session_state.predictions[ 'toxicity_score'] > threshold toxic_at_threshold = st.session_state.predictions['is_toxic_at_threshold'].sum() total_predictions = len(st.session_state.predictions) st.write(f"Threshold: {threshold}") st.write( f"Toxic comments: {toxic_at_threshold} ({toxic_at_threshold / total_predictions * 100:.2f}%)") st.write( f"Non-toxic comments: {total_predictions - toxic_at_threshold} ({(total_predictions - toxic_at_threshold) / total_predictions * 100:.2f}%)") else: st.warning("Toxicity scores not found in predictions.") else: st.info("Please upload and preprocess data for visualization.") if __name__ == "__main__": main()