|
import streamlit as st |
|
import pandas as pd |
|
import numpy as np |
|
import re |
|
import matplotlib.pyplot as plt |
|
import seaborn as sns |
|
import nltk |
|
from nltk.corpus import stopwords |
|
from nltk.stem.snowball import SnowballStemmer |
|
import pickle |
|
from transformers import pipeline as hf_pipeline |
|
from sklearn.utils.multiclass import type_of_target |
|
import io |
|
import base64 |
|
from sklearn.model_selection import train_test_split |
|
from sklearn.feature_extraction.text import TfidfVectorizer |
|
from sklearn.pipeline import Pipeline |
|
from sklearn.multiclass import OneVsRestClassifier |
|
from sklearn.linear_model import LogisticRegression |
|
from sklearn.naive_bayes import MultinomialNB |
|
from sklearn.metrics import roc_auc_score, accuracy_score, classification_report |
|
from textblob import TextBlob |
|
import warnings |
|
from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix |
|
from sklearn.metrics import roc_curve |
|
|
|
# Suppress ALL warnings process-wide, including those raised by the
# libraries imported above, so they do not clutter the Streamlit UI.
# NOTE(review): this also hides genuinely useful warnings (e.g. model
# convergence); consider narrowing the filter to specific categories.
warnings.filterwarnings('ignore')

# Make sure the NLTK English stopword corpus is available locally; download
# it on first run when `nltk.data.find` cannot locate it.
try:
    nltk.data.find('corpora/stopwords')
except LookupError:
    nltk.download('stopwords')

# Shared text-normalisation resources, built once at import time and used
# by `stemming` and `remove_stopwords`.
stemmer = SnowballStemmer('english')
# A set gives O(1) membership checks when filtering tokens.
stop_words_set = set(stopwords.words('english'))
|
|
|
|
|
|
|
def remove_stopwords(text):
    """Drop English stopwords from *text*.

    The input is coerced to ``str`` and split on whitespace. Each token is
    compared case-insensitively against ``stop_words_set``. Surviving tokens
    keep their original casing and are re-joined with single spaces.
    """
    kept = []
    for token in str(text).split():
        if token.lower() in stop_words_set:
            continue
        kept.append(token)
    return " ".join(kept)
|
|
|
|
|
def train_lightweight_model(data, text_column, label_column):
    """
    Train a lightweight TF-IDF + logistic-regression toxicity classifier.

    Every value of ``data[text_column]`` is passed through ``preprocess_text``
    before fitting. The caller's DataFrame is left unmodified; the processed
    text is kept in a local Series instead of being written back into
    ``data``.

    NOTE(review): ``preprocess_text`` is applied unconditionally. Passing an
    already-preprocessed column (e.g. ``'processed_text'``) therefore
    preprocesses that text a second time, while ``moderate_comment``
    preprocesses inference text only once. Confirm this is intended.

    Args:
        data: DataFrame containing the training data.
        text_column: Column name for text data.
        label_column: Column name for label data.

    Returns:
        A fitted sklearn ``Pipeline`` made of a ``'tfidf'`` TfidfVectorizer
        followed by a ``'clf'`` LogisticRegression. The vectorizer lives
        inside the pipeline; it is not returned separately.
    """
    processed_text = data[text_column].apply(preprocess_text)

    model = Pipeline([
        # Unigrams + bigrams, capped at 5000 features to keep the model small.
        ('tfidf', TfidfVectorizer(max_features=5000, ngram_range=(1, 2))),
        ('clf', LogisticRegression(random_state=42, max_iter=1000)),
    ])

    model.fit(processed_text, data[label_column])

    return model
|
|
|
|
|
def load_bert_model():
    """
    Build the Hugging Face ``sentiment-analysis`` pipeline.

    A success or error banner is rendered in the Streamlit UI.

    Returns:
        The pipeline object, or ``None`` if construction raised.
    """
    try:
        analyzer = hf_pipeline("sentiment-analysis")
        st.success("BERT model loaded successfully!")
    except Exception as e:
        st.error(f"Error loading BERT model: {e}")
        analyzer = None
    return analyzer
|
|
|
|
|
# Contraction expansions applied in order by ``clean_text``. Order matters:
# "what's" and "'scuse" must run before the generic "'s" rule (otherwise the
# apostrophe-s is consumed first), and "can't" before the generic "n't" rule.
_CONTRACTION_RULES = (
    (r"what's", "what is "),
    (r"'scuse", " excuse "),
    (r"'s", " "),
    (r"'ve", " have "),
    (r"can't", "can not "),
    (r"n't", " not "),
    (r"i'm", "i am "),
    (r"'re", " are "),
    (r"'d", " would "),
    (r"'ll", " will "),
)


def clean_text(text):
    """Normalise *text* for the bag-of-words models.

    Steps:
      1. coerce to ``str`` and lowercase;
      2. expand common English contractions (see ``_CONTRACTION_RULES``);
      3. replace every non-word character with a space;
      4. collapse whitespace runs to a single space and strip the ends.
    """
    text = str(text).lower()
    for pattern, replacement in _CONTRACTION_RULES:
        text = re.sub(pattern, replacement, text)
    text = re.sub(r'\W', ' ', text)
    return re.sub(r'\s+', ' ', text).strip()
|
|
|
|
|
def stemming(sentence):
    """Snowball-stem every whitespace-separated token of *sentence*.

    Uses the module-level English ``stemmer``; the stemmed tokens are
    joined with single spaces.
    """
    tokens = str(sentence).split()
    return " ".join(map(stemmer.stem, tokens))
|
|
|
|
|
def preprocess_text(text):
    """Run the full normalisation pipeline used for training and inference.

    The order is fixed:
      1. stopwords are removed from the raw tokens;
      2. the text is cleaned (lowercased, contractions expanded,
         punctuation stripped);
      3. each remaining token is stemmed.

    NOTE(review): because stopwords are filtered before punctuation is
    stripped, tokens such as "the," survive step 1. Confirm this ordering is
    intended before changing it, since trained models depend on it.
    """
    without_stopwords = remove_stopwords(text)
    cleaned = clean_text(without_stopwords)
    return stemming(cleaned)
|
|
|
|
|
|
|
def get_sentiment(text):
    """Label *text* by the sign of its TextBlob polarity.

    Returns:
        A ``(label, polarity)`` tuple. The label is ``"Positive"`` when
        polarity > 0, ``"Negative"`` when polarity < 0, and ``"Neutral"``
        otherwise.
    """
    polarity = TextBlob(text).sentiment.polarity
    label = "Positive" if polarity > 0 else ("Negative" if polarity < 0 else "Neutral")
    return label, polarity
|
|
|
|
|
|
|
# Words replaced by "[inappropriate]" when a comment is moderated rather than
# deleted. Each token is compared after lowercasing and stripping punctuation.
_MODERATION_TOXIC_WORDS = frozenset({
    "stupid", "idiot", "dumb", "hate", "sucks", "terrible",
    "awful", "garbage", "trash", "pathetic", "ridiculous",
})


def _mask_toxic_words(text):
    """Replace each whitespace token whose cleaned form is a listed toxic word."""
    masked = []
    for word in text.split():
        clean_word = re.sub(r'[^\w\s]', '', word.lower())
        masked.append("[inappropriate]" if clean_word in _MODERATION_TOXIC_WORDS else word)
    return " ".join(masked)


def moderate_text(text, predictions, threshold_moderate=0.5, threshold_delete=0.8):
    """Decide whether to keep, mask, or delete a comment from its toxicity scores.

    Args:
        text: The original comment.
        predictions: Probability vector for one comment.
            - A length-2 vector is treated as binary
              ``[P(non-toxic), P(toxic)]``.
            - Any other length is treated as per-category probabilities; the
              comment counts as toxic if ANY category crosses a threshold.
            NOTE(review): a 2-label multi-label model would be misread as
            binary by this length heuristic.
        threshold_moderate: Score at or above which listed toxic words are
            masked.
        threshold_delete: Score at or above which the whole comment is
            replaced.

    Returns:
        ``(text_to_show, action)``, where action is ``"delete"``,
        ``"moderate"`` or ``"keep"``. On ``"keep"`` the original text is
        returned unchanged.
    """
    if len(predictions) == 2:
        scores = [predictions[1]]
    else:
        scores = list(predictions)

    if any(score >= threshold_delete for score in scores):
        return "*** COMMENT DELETED DUE TO HIGH TOXICITY ***", "delete"
    if any(score >= threshold_moderate for score in scores):
        return _mask_toxic_words(text), "moderate"
    return text, "keep"
|
|
|
|
|
|
|
def train_model(X_train, y_train, model_type='logistic_regression', label_columns=None):
    """Train a one-vs-rest TF-IDF text classifier over the toxicity labels.

    Args:
        X_train: Iterable of (preprocessed) comment strings.
        y_train: DataFrame (or Series) of 0/1 labels aligned with
            ``X_train``. It is NOT modified.
            - Labels listed in ``label_columns`` but absent from ``y_train``
              are filled with 0 in a copy.
            - Columns not in ``label_columns`` are ignored, with a Streamlit
              warning.
        model_type: ``'logistic_regression'`` selects LogisticRegression;
            any other value selects MultinomialNB.
        label_columns: Target label names, in output order. Defaults to the
            six Jigsaw toxicity labels.

    Returns:
        A fitted sklearn ``Pipeline`` (``'tfidf'`` -> ``'clf'``) whose
        ``predict_proba`` yields one column per entry of ``label_columns``.
    """
    st.write("Training model...")

    if label_columns is None:
        label_columns = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']
    label_columns = list(label_columns)

    if isinstance(y_train, pd.Series):
        y_train = y_train.to_frame()

    # Previously these columns were dropped silently, which could leave the
    # model training on all-zero labels.
    ignored = [col for col in y_train.columns if col not in label_columns]
    if ignored:
        st.warning(
            f"Ignoring target column(s) not in {label_columns}: {', '.join(map(str, ignored))}"
        )

    # reindex returns a new frame: missing labels become all-zero columns and
    # the caller's y_train is left untouched.
    y_train = y_train.reindex(columns=label_columns, fill_value=0)

    if model_type == 'logistic_regression':
        classifier = LogisticRegression(max_iter=1000)
    else:
        classifier = MultinomialNB()

    pipeline = Pipeline([
        ('tfidf', TfidfVectorizer(stop_words='english', max_features=50000)),
        ('clf', OneVsRestClassifier(classifier, n_jobs=-1)),
    ])

    pipeline.fit(X_train, y_train)

    return pipeline
|
|
|
|
|
|
|
|
|
def _binary_target_scores(pred_probs, multilabel_model, label_index=0):
    """Return the score vector to compare against a single binary target.

    - Multi-label model: the columns of ``pred_probs`` are per-label
      probabilities, so ``label_index`` selects the same label that the hard
      predictions were reduced to.
    - Otherwise: the columns are per-class probabilities and column 1 is the
      positive class.
    - 1-D or single-column input is returned unchanged.
    """
    if pred_probs.ndim > 1 and pred_probs.shape[1] > 1:
        return pred_probs[:, label_index] if multilabel_model else pred_probs[:, 1]
    return pred_probs


def evaluate_model(pipeline, X_test, y_test, label_index=0):
    """
    Evaluate the given trained pipeline on test data.

    Args:
        pipeline: Fitted estimator exposing ``predict`` and ``predict_proba``.
        X_test: Inputs accepted by ``pipeline``.
        y_test: True labels. The per-label ROC AUC loop indexes it with
            ``.iloc``, so a multi-column target must be a DataFrame.
        label_index: Used when the model predicts several labels but
            ``y_test`` is a single binary target. It selects the model output
            column compared with ``y_test``, for both the hard predictions and
            the ROC scores. Defaults to 0, the first label.

    Returns:
        accuracy: Accuracy score.
        roc_auc: ROC AUC score. For multi-column targets it is the mean over
            the labels where it could be computed; 0.0 on failure.
        predictions: Predicted labels, reshaped to ``y_test``'s target type
            when one is binary and the other multi-label.
        pred_probs: Predicted probabilities.
        fpr: False Positive Rate array for the ROC curve (None for
            multi-column targets or on failure).
        tpr: True Positive Rate array for the ROC curve (None for
            multi-column targets or on failure).
    """
    predictions = pipeline.predict(X_test)
    pred_probs = pipeline.predict_proba(X_test)

    # Some estimators return a list of per-output arrays; unwrap the
    # single-output case.
    if isinstance(pred_probs, list) and len(pred_probs) == 1:
        pred_probs = pred_probs[0]

    y_type = type_of_target(y_test)
    pred_type = type_of_target(predictions)
    multilabel_model = pred_type == "multilabel-indicator"

    if y_type != pred_type:
        if y_type == "multilabel-indicator" and pred_type == "binary":
            # Broadcast each single prediction to every target label.
            predictions = np.array([[pred] * y_test.shape[1] for pred in predictions])
        elif y_type == "binary" and multilabel_model:
            # Reduce to the one label being evaluated; the ROC scores below
            # use the same column so the two metrics agree.
            predictions = predictions[:, label_index]

    accuracy = accuracy_score(y_test, predictions)

    try:
        if len(y_test.shape) > 1 and y_test.shape[1] > 1:
            # Average per-label ROC AUC, skipping labels where roc_auc_score
            # raises (e.g. only one class present in y_test).
            roc_auc_sum = 0
            valid_labels = 0
            for i in range(y_test.shape[1]):
                try:
                    roc_auc_sum += roc_auc_score(y_test.iloc[:, i], pred_probs[:, i])
                    valid_labels += 1
                except Exception:
                    continue
            roc_auc = roc_auc_sum / valid_labels if valid_labels > 0 else 0.0
        else:
            roc_auc = roc_auc_score(
                y_test, _binary_target_scores(pred_probs, multilabel_model, label_index)
            )
    except Exception as e:
        print(f"Warning: Could not compute ROC AUC - {e}")
        roc_auc = 0.0

    try:
        if len(y_test.shape) == 1 or (len(y_test.shape) > 1 and y_test.shape[1] == 1):
            fpr, tpr, _ = roc_curve(
                y_test,
                _binary_target_scores(pred_probs, multilabel_model, label_index)
            )
        else:
            fpr, tpr = None, None
    except Exception as e:
        print(f"Warning: Could not compute ROC Curve - {e}")
        fpr, tpr = None, None

    return accuracy, roc_auc, predictions, pred_probs, fpr, tpr
|
|
|
|
|
|
|
|
|
def get_model_download_link(model, filename):
    """Build an HTML anchor that downloads *model* as a pickle file.

    The model is pickled, base64-encoded and embedded in a ``data:`` URI,
    with *filename* as the suggested download name. The result is raw HTML
    and must be rendered with ``unsafe_allow_html=True``.
    """
    encoded = base64.b64encode(pickle.dumps(model)).decode()
    return (
        f'<a href="data:file/pickle;base64,{encoded}" '
        f'download="{filename}">Download Trained Model</a>'
    )
|
|
|
|
|
|
|
def plot_toxicity_distribution(df, toxicity_columns):
    """Bar-chart the column sums of the given toxicity label columns.

    For 0/1 label columns, each bar is the number of positive rows.

    Args:
        df: DataFrame holding the label columns.
        toxicity_columns: Names of the columns to sum and plot.

    Returns:
        The matplotlib Figure containing the chart.
    """
    fig, ax = plt.subplots(figsize=(10, 6))

    counts = df[toxicity_columns].sum()
    sns.barplot(x=counts.index, y=counts.values, alpha=0.8, palette='viridis', ax=ax)

    ax.set_title('Toxicity Distribution')
    ax.set_ylabel('Count')
    ax.set_xlabel('Toxicity Category')
    plt.setp(ax.get_xticklabels(), rotation=45)

    return fig
|
|
|
|
|
|
|
def show_sample_data_format():
    """Render an example training CSV on the Streamlit page.

    Shows:
      - a four-row sample table;
      - a data-URI link to download that table as CSV;
      - a note describing the required columns.
    """
    st.subheader("Sample Data Format")

    comments = [
        "This is a normal comment.",
        "This is a toxic comment you idiot!",
        "You're all worthless and should die.",
        "I respectfully disagree with your point."
    ]
    labels = {
        'toxic': [0, 1, 1, 0],
        'severe_toxic': [0, 0, 1, 0],
        'obscene': [0, 1, 0, 0],
        'threat': [0, 0, 1, 0],
        'insult': [0, 1, 1, 0],
        'identity_hate': [0, 0, 0, 0]
    }
    sample_df = pd.DataFrame({'comment_text': comments, **labels})
    st.dataframe(sample_df)

    encoded_csv = base64.b64encode(sample_df.to_csv(index=False).encode()).decode()
    st.markdown(
        f'<a href="data:file/csv;base64,{encoded_csv}" download="sample_toxic_data.csv">Download Sample CSV</a>',
        unsafe_allow_html=True
    )

    st.info("""
    Your CSV file should contain:
    1. A column with comment text
    2. One or more columns with binary values (0 or 1) for each toxicity category
    """)
|
|
|
|
|
|
|
def validate_dataset(df, comment_column, toxicity_columns):
    """Check that *df* has the structure needed for training.

    Args:
        df: The uploaded dataset.
        comment_column: Name of the free-text column.
        toxicity_columns: Names of the label columns; each must be numeric
            and contain only 0/1 (NaN is ignored).

    Returns:
        A list of human-readable problem descriptions; empty when the
        dataset is valid.
    """
    issues = []

    comment_present = comment_column in df.columns
    if not comment_present:
        issues.append(f"Comment column '{comment_column}' not found in the dataset")

    missing_columns = [col for col in toxicity_columns if col not in df.columns]
    if missing_columns:
        issues.append(f"Missing toxicity columns: {', '.join(missing_columns)}")

    for col in toxicity_columns:
        if col not in df.columns:
            continue
        if not pd.api.types.is_numeric_dtype(df[col]):
            issues.append(f"Column '{col}' contains non-numeric values")
        elif df[col].dropna().apply(lambda x: x not in [0, 1, 0.0, 1.0]).any():
            issues.append(f"Column '{col}' contains values other than 0 and 1")

    if df.empty:
        issues.append("Dataset is empty")
    # Only inspect the comment column when it exists: its absence was already
    # reported above, and indexing it would raise KeyError.
    elif comment_present and df[comment_column].isna().all():
        issues.append("Comment column contains no data")

    return issues
|
|
|
|
|
|
|
def extract_predictions(predictions_proba, toxicity_categories):
    """
    Helper function to extract per-category probabilities for ONE sample
    from model output, handling different output formats.

    Supported shapes:
      * a list with one array per category: the positive-class probability
        (column 1, or column 0 if single-column) of the first row of each
        array;
      * a bare 2-D ndarray, or a one-element list wrapping one, whose first
        row is either one score per category, or ``[P(neg), P(pos)]`` from a
        binary model (returned as a one-element array).

    Any other shape produces all-zero scores, with a UI warning.
    """
    # `debug_mode` is not guaranteed to exist in session_state, and attribute
    # access on a missing key raises, so read it with a default.
    if st.session_state.get("debug_mode", False):
        st.write(f"Predictions type: {type(predictions_proba)}")
        st.write(
            f"Predictions shape/length: {np.shape(predictions_proba) if hasattr(predictions_proba, 'shape') else len(predictions_proba)}")

    n_categories = len(toxicity_categories)

    # One predict_proba array per category.
    if isinstance(predictions_proba, list) and len(predictions_proba) == n_categories:
        return [pred_array[:, 1][0] if pred_array.shape[1] > 1 else pred_array[0]
                for pred_array in predictions_proba]

    # A single 2-D array, either bare or wrapped in a one-element list.
    if isinstance(predictions_proba, list) and len(predictions_proba) == 1:
        pred_array = predictions_proba[0]
    elif isinstance(predictions_proba, np.ndarray):
        pred_array = predictions_proba
    else:
        pred_array = None

    if pred_array is not None and len(pred_array.shape) == 2:
        if pred_array.shape[1] == n_categories:
            return pred_array[0]
        if pred_array.shape[1] == 2:
            return np.array([pred_array[0, 1]])

    st.warning("Unexpected prediction format. Creating default predictions.")
    return np.zeros(n_categories)
|
|
|
|
|
def display_classification_result(result):
    """Render a moderation result dict (as built by ``moderate_comment``).

    Reads these keys from ``result``:
      - ``original_text`` and ``moderated_text``;
      - ``action`` (``"keep"``, ``"moderate"`` or ``"delete"``);
      - ``toxicity_scores``, a mapping of label -> probability;
      - optionally ``sentiment``, a dict with ``label`` and ``score``, shown
        only when present.
    """
    st.subheader("Classification Result")

    col1, col2 = st.columns(2)
    with col1:
        st.markdown("**Original Text**")
        st.code(result["original_text"], language="text")
    with col2:
        st.markdown("**Moderated Text**")
        st.code(result["moderated_text"], language="text")

    action = result["action"]
    if action == "keep":
        st.success("✅ This comment is allowed (Non-toxic).")
    elif action == "moderate":
        st.warning("⚠️ This comment is moderated (Potentially toxic).")
    elif action == "delete":
        st.error("🚫 This comment is deleted (Highly toxic).")

    st.markdown("**Toxicity Scores:**")
    scores = result["toxicity_scores"]
    if scores:
        score_cols = st.columns(len(scores))
        for column, (label, score) in zip(score_cols, scores.items()):
            column.metric(label.capitalize(), f"{score:.2%}")
    else:
        # st.columns requires a positive count, so an empty score mapping
        # would otherwise raise.
        st.info("No toxicity scores available.")

    if "sentiment" in result:
        st.markdown("**Sentiment Analysis:**")
        st.info(f"{result['sentiment']['label']} (score: {result['sentiment']['score']:.2%})")
|
|
|
|
|
def moderate_comment(comment, model, sentiment_model=None):
    """
    Moderate a single comment using the trained model and optionally BERT sentiment analysis.

    Args:
        comment: The raw comment text to moderate.
        model: Fitted estimator exposing ``predict_proba``. It is called on a
            one-element list holding the preprocessed comment.
        sentiment_model: Optional callable (e.g. the pipeline from
            ``load_bert_model``). Its first returned item must have
            ``"label"`` and ``"score"`` keys.

    Returns:
        Dictionary with ``original_text``, ``moderated_text``, ``action``,
        ``toxicity_scores`` and, when sentiment analysis produced a result,
        ``sentiment``.
        - ``toxicity_scores`` is ``{"toxic", "non_toxic"}`` for a
          two-probability output.
        - Otherwise it holds one entry per known category, in order, for as
          many probabilities as the model returned.

    Raises:
        ValueError: If ``model`` is None (no trained model available).
    """
    if model is None:
        raise ValueError("No toxicity model available. Please train a model first.")

    processed_text = preprocess_text(comment)
    predictions = model.predict_proba([processed_text])[0]

    sentiment = sentiment_model(comment)[0] if sentiment_model else None

    moderated_text, action = moderate_text(comment, predictions)

    if len(predictions) == 2:
        # Binary classifier: [P(non-toxic), P(toxic)].
        toxicity_scores = {
            "toxic": float(predictions[1]),
            "non_toxic": float(predictions[0])
        }
    else:
        categories = ("toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate")
        # zip stops at the shorter sequence, so a model with fewer outputs
        # only fills the leading categories.
        toxicity_scores = {
            category: float(prob) for category, prob in zip(categories, predictions)
        }

    result = {
        "original_text": comment,
        "moderated_text": moderated_text,
        "action": action,
        "toxicity_scores": toxicity_scores
    }

    if sentiment:
        result["sentiment"] = {
            "label": sentiment["label"],
            "score": float(sentiment["score"])
        }

    return result
|
|
|
|
|
|
|
def detect_subgroup(text):
    """Return the sensitive subgroups mentioned in *text*.

    Keywords are matched against whole lowercase words, and a trailing
    plural "s" is tolerated ("girls", "muslims"). Whole-word matching means
    "the" does not count as "he" and "many" does not count as "man"; a plain
    substring test would flag almost every English sentence as "gender".

    Args:
        text: Comment text; non-string values are coerced with ``str``.

    Returns:
        A list drawn from ``["gender", "ethnicity"]``, always in that order.
    """
    gender_keywords = {"he", "she", "him", "her", "man", "woman", "boy", "girl", "male", "female"}
    ethnicity_keywords = {
        "asian", "black", "white", "hispanic", "latino", "indian", "african", "european", "arab", "jewish", "muslim", "christian"
    }

    words = set(re.findall(r"[a-z]+", str(text).lower()))
    # Accept simple plurals as well.
    words |= {word[:-1] for word in words if word.endswith("s")}

    subgroups = []
    if words & gender_keywords:
        subgroups.append("gender")
    if words & ethnicity_keywords:
        subgroups.append("ethnicity")
    return subgroups
|
|
|
def bias_report(X, y_true, y_pred, text_column_name):
    """Summarise how often comments mentioning each sensitive subgroup are predicted toxic.

    Args:
        X: DataFrame containing ``text_column_name``. Rows are matched to
            ``y_pred`` by POSITION, so ``X`` need not have a 0..n-1 index.
        y_true: Currently unused. NOTE(review): ground truth could be used to
            report per-subgroup false-positive rates.
        y_pred: Array-like of predictions aligned with ``X``.
            - 1-D: a row is toxic if its value is > 0.
            - 2-D (multi-label): a row is toxic if any label is positive.
        text_column_name: Column of ``X`` holding the comment text.

    Returns:
        A Markdown bullet list with one line per detected subgroup, in
        first-seen order, or a fixed message when no subgroup is detected.
    """
    y_pred = np.asarray(y_pred)
    multilabel = y_pred.ndim > 1

    # subgroup -> (flagged_toxic, total). Dicts keep insertion order, which
    # gives the first-seen ordering of the report lines.
    tallies = {}
    for position, text in enumerate(X[text_column_name]):
        subgroups = detect_subgroup(text)
        if not subgroups:
            continue
        row_pred = y_pred[position]
        is_toxic = int(row_pred.sum() > 0) if multilabel else int(row_pred > 0)
        for subgroup in subgroups:
            toxic, total = tallies.get(subgroup, (0, 0))
            tallies[subgroup] = (toxic + is_toxic, total + 1)

    if not tallies:
        return "No sensitive subgroups detected in the evaluation set."

    lines = []
    for subgroup, (toxic, total) in tallies.items():
        rate = toxic / total if total > 0 else 0
        lines.append(f"- **{subgroup.capitalize()}**: {toxic}/{total} ({rate:.1%}) flagged as toxic\n")
    return "".join(lines)
|
|
|
|
|
|
|
def main(): |
|
st.set_page_config( |
|
page_title="Toxic Comment Classifier", |
|
page_icon="🧊", |
|
layout="wide", |
|
initial_sidebar_state="expanded", |
|
) |
|
|
|
col1, col2 = st.columns([1, 4]) |
|
|
|
with col1: |
|
st.image("logo.jpeg", width=100) |
|
|
|
with col2: |
|
st.title("Toxic Comment Classifier") |
|
|
|
|
|
|
|
if 'data' not in st.session_state: |
|
st.session_state.data = None |
|
if 'model' not in st.session_state: |
|
st.session_state.model = None |
|
if 'vectorizer' not in st.session_state: |
|
st.session_state.vectorizer = None |
|
if 'predictions' not in st.session_state: |
|
st.session_state.predictions = None |
|
if 'lightweight_model' not in st.session_state: |
|
st.session_state.lightweight_model = None |
|
if 'bert_model' not in st.session_state: |
|
st.session_state.bert_model = None |
|
|
|
|
|
st.sidebar.title("Navigation") |
|
page = st.sidebar.radio( |
|
"Select a page", |
|
["Home", "Data Preprocessing", "Model Training", "Model Evaluation", "Prediction", "Visualization"] |
|
) |
|
|
|
|
|
if page == "Home": |
|
|
|
st.header("Home") |
|
st.write(""" |
|
Welcome to the Toxic Comment Classifier application. This tool helps you to: |
|
1. Upload and preprocess data |
|
2. Train a machine learning model to detect toxic comments |
|
3. Evaluate model performance |
|
4. Make predictions on new data |
|
5. Visualize results |
|
|
|
Please use the sidebar navigation to get started. |
|
""") |
|
|
|
|
|
if st.sidebar.checkbox("Use BERT for Sentiment Analysis"): |
|
st.subheader("BERT-Based Sentiment Analysis") |
|
st.write("This option uses a pre-trained BERT model for advanced sentiment analysis.") |
|
|
|
if st.button("Load BERT Model"): |
|
with st.spinner("Loading BERT model..."): |
|
st.session_state.bert_model = load_bert_model() |
|
st.write("DEBUG: bert_model in session_state after loading:", st.session_state.bert_model) |
|
|
|
|
|
st.subheader("Sample Data Format") |
|
show_sample_data_format() |
|
|
|
|
|
st.subheader("Try Comment Moderation") |
|
comment = st.text_area("Enter a comment to moderate:") |
|
|
|
col1, col2 = st.columns(2) |
|
with col1: |
|
use_default_model = st.checkbox("Use built-in model for demo", value=True) |
|
|
|
with col2: |
|
use_bert = st.checkbox("Use BERT model for sentiment (if loaded)", value=False) |
|
|
|
if st.button("Moderate Comment"): |
|
if comment: |
|
with st.spinner("Analyzing comment..."): |
|
st.write("DEBUG: bert_model in session_state before use:", st.session_state.bert_model) |
|
sentiment_model = st.session_state.bert_model if use_bert and st.session_state.bert_model is not None else None |
|
|
|
if use_default_model or st.session_state.model or st.session_state.lightweight_model: |
|
model_to_use = None |
|
if st.session_state.model: |
|
model_to_use = st.session_state.model |
|
elif st.session_state.lightweight_model: |
|
model_to_use = st.session_state.lightweight_model |
|
|
|
result = moderate_comment(comment, model_to_use, sentiment_model) |
|
|
|
display_classification_result(result) |
|
else: |
|
st.error("No model available. Please train a model first or enable the demo model.") |
|
else: |
|
st.warning("Please enter a comment to moderate.") |
|
|
|
|
|
elif page == "Data Preprocessing": |
|
st.header("Data Preprocessing") |
|
|
|
|
|
st.subheader("Upload Dataset") |
|
uploaded_file = st.file_uploader("Choose a CSV file", type="csv") |
|
|
|
if uploaded_file is not None: |
|
try: |
|
|
|
data = pd.read_csv(uploaded_file) |
|
|
|
|
|
st.subheader("Raw Data") |
|
st.write(data.head()) |
|
|
|
|
|
validation_result = validate_dataset(data, 'comment_text', ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']) |
|
|
|
if not validation_result: |
|
st.success("Dataset is valid!") |
|
|
|
|
|
st.session_state.data = data |
|
|
|
|
|
st.subheader("Data Cleaning") |
|
st.write("Select columns to include in the analysis:") |
|
|
|
|
|
all_columns = data.columns.tolist() |
|
default_selected = ["comment_text", "toxic", "severe_toxic", "obscene", "threat", "insult", |
|
"identity_hate"] |
|
default_selected = [col for col in default_selected if col in all_columns] |
|
|
|
selected_columns = st.multiselect( |
|
"Select columns", |
|
options=all_columns, |
|
default=default_selected |
|
) |
|
|
|
if selected_columns: |
|
|
|
filtered_data = data[selected_columns] |
|
|
|
|
|
st.subheader("Filtered Data") |
|
st.write(filtered_data.head()) |
|
|
|
|
|
st.subheader("Data Statistics") |
|
st.write(filtered_data.describe()) |
|
|
|
|
|
st.subheader("Missing Values") |
|
missing_values = filtered_data.isnull().sum() |
|
st.write(missing_values) |
|
|
|
|
|
if missing_values.sum() > 0: |
|
st.warning("There are missing values in the dataset.") |
|
|
|
if st.button("Handle Missing Values"): |
|
|
|
text_columns = [col for col in selected_columns if filtered_data[col].dtype == 'object'] |
|
for col in text_columns: |
|
filtered_data[col] = filtered_data[col].fillna("") |
|
|
|
|
|
numerical_columns = [col for col in selected_columns if |
|
filtered_data[col].dtype != 'object'] |
|
for col in numerical_columns: |
|
filtered_data[col] = filtered_data[col].fillna(0) |
|
|
|
st.success("Missing values handled!") |
|
st.write(filtered_data.isnull().sum()) |
|
|
|
|
|
st.subheader("Text Preprocessing") |
|
|
|
|
|
text_columns = [col for col in selected_columns if filtered_data[col].dtype == 'object'] |
|
|
|
if text_columns: |
|
text_column = st.selectbox("Select text column for preprocessing", text_columns) |
|
|
|
|
|
st.write("Sample original text:") |
|
sample_texts = filtered_data[text_column].head(3).tolist() |
|
for i, text in enumerate(sample_texts): |
|
st.text(f"Text {i + 1}: {text[:200]}...") |
|
|
|
|
|
if st.button("Preprocess Text"): |
|
with st.spinner("Preprocessing text..."): |
|
filtered_data['processed_text'] = filtered_data[text_column].apply(preprocess_text) |
|
|
|
|
|
st.write("Sample preprocessed text:") |
|
sample_preprocessed = filtered_data['processed_text'].head(3).tolist() |
|
for i, text in enumerate(sample_preprocessed): |
|
st.text(f"Processed Text {i + 1}: {text[:200]}...") |
|
|
|
|
|
st.session_state.data = filtered_data |
|
st.success("Text preprocessing completed!") |
|
else: |
|
st.warning("No text columns found in the selected columns.") |
|
else: |
|
st.warning("Please select at least one column.") |
|
else: |
|
st.error(f"Dataset validation failed: {validation_result['reason']}") |
|
st.warning("Please upload a valid dataset.") |
|
|
|
except Exception as e: |
|
st.error(f"Error loading data: {e}") |
|
st.warning("Please upload a valid CSV file.") |
|
else: |
|
st.info("Please upload a CSV file to begin preprocessing.") |
|
|
|
|
|
elif page == "Model Training": |
|
st.header("Model Training") |
|
|
|
|
|
if st.session_state.data is not None: |
|
|
|
st.subheader("Dataset Information") |
|
st.write(f"Number of samples: {len(st.session_state.data)}") |
|
|
|
if 'processed_text' in st.session_state.data.columns: |
|
st.write("Text preprocessing: Done") |
|
else: |
|
st.warning("Text preprocessing is not done. Please preprocess the data first.") |
|
|
|
|
|
st.subheader("Training Options") |
|
|
|
|
|
numerical_columns = [col for col in st.session_state.data.columns if |
|
st.session_state.data[col].dtype != 'object'] |
|
|
|
if numerical_columns: |
|
target_column = st.selectbox("Select target column", numerical_columns) |
|
|
|
|
|
st.write("Training Parameters:") |
|
col1, col2 = st.columns(2) |
|
|
|
with col1: |
|
test_size = st.slider("Test Size", 0.1, 0.5, 0.2, 0.05) |
|
|
|
with col2: |
|
random_state = st.number_input("Random State", 0, 100, 42, 1) |
|
|
|
|
|
model_type = st.radio( |
|
"Select model type", |
|
["Standard Model", "Lightweight Model"] |
|
) |
|
|
|
|
|
if st.button("Train Model"): |
|
with st.spinner("Training model..."): |
|
|
|
if 'processed_text' in st.session_state.data.columns: |
|
try: |
|
if model_type == "Standard Model": |
|
|
|
X_train = st.session_state.data['processed_text'] |
|
y_train = st.session_state.data[[target_column]] |
|
model = train_model(X_train, y_train, 'logistic_regression') |
|
|
|
|
|
st.session_state.model = model |
|
st.session_state.vectorizer = None |
|
|
|
st.success("Model training completed!") |
|
else: |
|
|
|
lightweight_model = train_lightweight_model( |
|
st.session_state.data, |
|
'processed_text', |
|
target_column |
|
) |
|
|
|
|
|
st.session_state.lightweight_model = lightweight_model |
|
|
|
st.success("Lightweight model training completed!") |
|
|
|
except Exception as e: |
|
st.error(f"Error training model: {e}") |
|
else: |
|
st.error("Processed text not found. Please preprocess the data first.") |
|
else: |
|
st.warning("No numerical columns found in the dataset. Please ensure you have target columns.") |
|
else: |
|
st.info("Please upload and preprocess data before training a model.") |
|
|
|
|
|
elif page == "Model Evaluation": |
|
st.header("Model Evaluation") |
|
|
|
|
|
model_available = st.session_state.model is not None or st.session_state.lightweight_model is not None |
|
|
|
if model_available: |
|
|
|
st.subheader("Model Information") |
|
if st.session_state.model is not None: |
|
st.write("Standard model is trained and ready.") |
|
if st.session_state.lightweight_model is not None: |
|
st.write("Lightweight model is trained and ready.") |
|
|
|
|
|
model_choice = None |
|
if st.session_state.model is not None and st.session_state.lightweight_model is not None: |
|
model_choice = st.radio( |
|
"Select model to evaluate", |
|
["Standard Model", "Lightweight Model"] |
|
) |
|
|
|
|
|
st.subheader("Evaluation Options") |
|
|
|
|
|
st.write("Evaluation Parameters:") |
|
col1, col2 = st.columns(2) |
|
|
|
with col1: |
|
test_size = st.slider("Test Size (Evaluation)", 0.1, 0.5, 0.2, 0.05) |
|
|
|
with col2: |
|
random_state = st.number_input("Random State (Evaluation)", 0, 100, 42, 1) |
|
|
|
|
|
if st.session_state.data is not None: |
|
numerical_columns = [col for col in st.session_state.data.columns if |
|
st.session_state.data[col].dtype != 'object'] |
|
|
|
if numerical_columns: |
|
target_column = st.selectbox("Select target column for evaluation", numerical_columns) |
|
|
|
|
|
if st.button("Evaluate Model"): |
|
with st.spinner("Evaluating model..."): |
|
try: |
|
|
|
model_to_evaluate = None |
|
if model_choice == "Lightweight Model" or ( |
|
model_choice is None and st.session_state.model is None): |
|
model_to_evaluate = st.session_state.lightweight_model |
|
else: |
|
model_to_evaluate = st.session_state.model |
|
|
|
|
|
X_test = st.session_state.data['processed_text'] |
|
y_test = st.session_state.data[[target_column]] |
|
|
|
accuracy, roc_auc, predictions, pred_probs, fpr, tpr = evaluate_model(model_to_evaluate, |
|
X_test, y_test) |
|
|
|
|
|
precision = precision_score(y_test, predictions, average='weighted', zero_division=0) |
|
recall = recall_score(y_test, predictions, average='weighted', zero_division=0) |
|
f1 = f1_score(y_test, predictions, average='weighted', zero_division=0) |
|
conf_matrix = confusion_matrix(y_test, predictions) |
|
classification_rep = classification_report(y_test, predictions, zero_division=0) |
|
|
|
|
|
st.subheader("Evaluation Results") |
|
metrics_df = pd.DataFrame({ |
|
'Metric': ['Accuracy', 'Precision', 'Recall', 'F1 Score', 'ROC AUC'], |
|
'Value': [accuracy, precision, recall, f1, roc_auc] |
|
}) |
|
st.table(metrics_df) |
|
|
|
|
|
st.subheader("Confusion Matrix") |
|
fig, ax = plt.subplots(figsize=(8, 6)) |
|
sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues', ax=ax, cbar=False, |
|
annot_kws={"size": 16}) |
|
plt.xlabel('Predicted') |
|
plt.ylabel('Actual') |
|
plt.title('Confusion Matrix') |
|
st.pyplot(fig) |
|
|
|
|
|
st.subheader("ROC Curve") |
|
if fpr is not None and tpr is not None: |
|
fig, ax = plt.subplots(figsize=(8, 6)) |
|
ax.plot(fpr, tpr, label=f'ROC Curve (AUC = {roc_auc:.2f})') |
|
ax.plot([0, 1], [0, 1], 'k--') |
|
ax.set_xlabel('False Positive Rate') |
|
ax.set_ylabel('True Positive Rate') |
|
ax.set_title('ROC Curve') |
|
ax.legend(loc="lower right") |
|
st.pyplot(fig) |
|
else: |
|
st.info("ROC curve is not available for multi-label classification.") |
|
|
|
|
|
st.subheader("Classification Report") |
|
st.text(classification_rep) |
|
|
|
except Exception as e: |
|
st.error(f"Error evaluating model: {e}") |
|
|
|
|
|
st.subheader("Classification Report") |
|
st.text(classification_rep) |
|
|
|
|
|
|
|
st.subheader("Bias Detection Report") |
|
if 'comment_text' in st.session_state.data.columns: |
|
bias_summary = bias_report( |
|
st.session_state.data[["comment_text"]].reset_index(drop=True), |
|
y_test.reset_index(drop=True), |
|
predictions, |
|
"comment_text" |
|
) |
|
st.markdown(bias_summary) |
|
else: |
|
st.info("No comment_text column found for bias analysis.") |
|
|
|
|
|
|
|
except Exception as e: |
|
st.error(f"Error evaluating model: {e}") |
|
else: |
|
st.warning("No numerical columns found in the dataset. Please ensure you have target columns.") |
|
else: |
|
st.warning("Dataset not available. Please upload and preprocess data first.") |
|
|
|
|
|
st.subheader("Model Download") |
|
|
|
|
|
model_to_download = None |
|
if model_choice == "Lightweight Model" or (model_choice is None and st.session_state.model is None): |
|
model_to_download = st.session_state.lightweight_model |
|
else: |
|
model_to_download = st.session_state.model |
|
|
|
if model_to_download is not None: |
|
|
|
filename = "lightweight_model.pkl" if model_choice == "Lightweight Model" or (model_choice is None and st.session_state.model is None) else "standard_model.pkl" |
|
download_link = get_model_download_link(model_to_download, filename) |
|
st.markdown(download_link, unsafe_allow_html=True) |
|
else: |
|
st.info("Please train a model before evaluation.") |
|
|
|
|
|
elif page == "Prediction": |
|
st.header("Prediction") |
|
|
|
|
|
model_available = st.session_state.model is not None or st.session_state.lightweight_model is not None |
|
|
|
if model_available: |
|
|
|
st.subheader("Model Information") |
|
if st.session_state.model is not None: |
|
st.write("Standard model is trained and ready.") |
|
if st.session_state.lightweight_model is not None: |
|
st.write("Lightweight model is trained and ready.") |
|
|
|
|
|
model_choice = None |
|
if st.session_state.model is not None and st.session_state.lightweight_model is not None: |
|
model_choice = st.radio( |
|
"Select model for prediction", |
|
["Standard Model", "Lightweight Model"] |
|
) |
|
|
|
|
|
model_to_use = None |
|
if model_choice == "Lightweight Model" or (model_choice is None and st.session_state.model is None): |
|
model_to_use = st.session_state.lightweight_model |
|
else: |
|
model_to_use = st.session_state.model |
|
|
|
|
|
st.subheader("Make Predictions") |
|
|
|
prediction_type = st.radio( |
|
"Select prediction type", |
|
["Single Comment", "Multiple Comments"] |
|
) |
|
|
|
|
|
use_bert = False |
|
if st.session_state.bert_model is not None: |
|
use_bert = st.checkbox("Include sentiment analysis with BERT") |
|
|
|
|
|
if prediction_type == "Single Comment": |
|
comment = st.text_area("Enter a comment to classify:") |
|
|
|
if st.button("Classify Comment"): |
|
if comment: |
|
with st.spinner("Classifying comment..."): |
|
st.write("DEBUG: bert_model in session_state before use:", st.session_state.bert_model) |
|
sentiment_model = st.session_state.bert_model if use_bert and st.session_state.bert_model is not None else None |
|
result = moderate_comment(comment, model_to_use, sentiment_model) |
|
|
|
display_classification_result(result) |
|
else: |
|
st.warning("Please enter a comment to classify.") |
|
|
|
|
|
else: |
|
|
|
uploaded_file = st.file_uploader("Upload a CSV file with comments", type="csv") |
|
|
|
if uploaded_file is not None: |
|
try: |
|
|
|
pred_data = pd.read_csv(uploaded_file) |
|
|
|
|
|
st.subheader("Uploaded Data") |
|
st.write(pred_data.head()) |
|
|
|
|
|
text_columns = [col for col in pred_data.columns if pred_data[col].dtype == 'object'] |
|
|
|
if text_columns: |
|
text_column = st.selectbox("Select text column for prediction", text_columns) |
|
|
|
|
|
if st.button("Run Batch Prediction"): |
|
with st.spinner("Classifying comments..."): |
|
try: |
|
|
|
pred_data['processed_text'] = pred_data[text_column].apply(preprocess_text) |
|
|
|
|
|
sentiment_model = st.session_state.bert_model if use_bert else None |
|
predictions = extract_predictions(pred_data, text_column, model_to_use, |
|
sentiment_model) |
|
|
|
|
|
st.session_state.predictions = predictions |
|
|
|
|
|
st.subheader("Prediction Results") |
|
st.write(predictions.head()) |
|
|
|
|
|
st.subheader("Summary") |
|
toxic_count = predictions['is_toxic'].sum() |
|
total_count = len(predictions) |
|
toxic_percentage = (toxic_count / total_count) * 100 |
|
|
|
st.write(f"Total comments: {total_count}") |
|
st.write(f"Toxic comments: {toxic_count} ({toxic_percentage:.2f}%)") |
|
st.write( |
|
f"Non-toxic comments: {total_count - toxic_count} ({100 - toxic_percentage:.2f}%)") |
|
|
|
|
|
if not predictions.empty: |
|
csv = predictions.to_csv(index=False) |
|
b64 = base64.b64encode(csv.encode()).decode() |
|
href = f'<a href="data:file/csv;base64,{b64}" download="predictions.csv">Download Predictions CSV</a>' |
|
st.markdown(href, unsafe_allow_html=True) |
|
|
|
except Exception as e: |
|
st.error(f"Error during prediction: {e}") |
|
else: |
|
st.warning("No text columns found in the uploaded file.") |
|
|
|
except Exception as e: |
|
st.error(f"Error loading data: {e}") |
|
st.warning("Please upload a valid CSV file.") |
|
else: |
|
st.info("Please upload a CSV file with comments for batch prediction.") |
|
else: |
|
st.info("Please train a model before making predictions.") |
|
|
|
|
|
elif page == "Visualization": |
|
st.header("Visualization") |
|
|
|
|
|
if st.session_state.data is not None: |
|
|
|
st.subheader("Data Visualization") |
|
|
|
|
|
viz_type = st.selectbox( |
|
"Select visualization type", |
|
["Toxicity Distribution", "Comment Length Distribution", "Word Cloud", "Correlation Matrix"] |
|
) |
|
|
|
|
|
if viz_type == "Toxicity Distribution": |
|
|
|
label_columns = [col for col in st.session_state.data.columns if col in [ |
|
"toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate" |
|
]] |
|
|
|
if label_columns: |
|
st.write("Toxicity Distribution:") |
|
|
|
|
|
fig = plot_toxicity_distribution(st.session_state.data, label_columns) |
|
st.pyplot(fig) |
|
else: |
|
st.warning("No toxicity label columns found in the dataset.") |
|
|
|
|
|
elif viz_type == "Comment Length Distribution": |
|
|
|
text_columns = [col for col in st.session_state.data.columns if |
|
st.session_state.data[col].dtype == 'object'] |
|
|
|
if text_columns: |
|
text_column = st.selectbox("Select text column", text_columns) |
|
|
|
|
|
st.session_state.data['comment_length'] = st.session_state.data[text_column].apply( |
|
lambda x: len(str(x))) |
|
|
|
|
|
fig, ax = plt.subplots(figsize=(10, 6)) |
|
sns.histplot(st.session_state.data['comment_length'], bins=50, kde=True, ax=ax) |
|
plt.xlabel('Comment Length') |
|
plt.ylabel('Frequency') |
|
plt.title('Comment Length Distribution') |
|
st.pyplot(fig) |
|
|
|
|
|
st.write("Comment Length Statistics:") |
|
st.write(st.session_state.data['comment_length'].describe()) |
|
else: |
|
st.warning("No text columns found in the dataset.") |
|
|
|
|
|
elif viz_type == "Word Cloud": |
|
|
|
if 'processed_text' in st.session_state.data.columns: |
|
try: |
|
from wordcloud import WordCloud |
|
|
|
|
|
st.write("Word Cloud Visualization:") |
|
|
|
|
|
all_text = ' '.join(st.session_state.data['processed_text'].tolist()) |
|
|
|
|
|
wordcloud = WordCloud( |
|
width=800, |
|
height=400, |
|
background_color='white', |
|
max_words=200, |
|
contour_width=3, |
|
contour_color='steelblue' |
|
).generate(all_text) |
|
|
|
|
|
fig, ax = plt.subplots(figsize=(10, 6)) |
|
ax.imshow(wordcloud, interpolation='bilinear') |
|
ax.axis('off') |
|
plt.tight_layout() |
|
st.pyplot(fig) |
|
|
|
except ImportError: |
|
st.error("WordCloud package is not installed. Please install it to use this feature.") |
|
else: |
|
st.warning("Processed text not found. Please preprocess the data first.") |
|
|
|
|
|
elif viz_type == "Correlation Matrix": |
|
|
|
numerical_columns = [col for col in st.session_state.data.columns if |
|
st.session_state.data[col].dtype != 'object'] |
|
|
|
if len(numerical_columns) > 1: |
|
|
|
selected_columns = st.multiselect( |
|
"Select columns for correlation matrix", |
|
options=numerical_columns, |
|
default=[col for col in numerical_columns if col in [ |
|
"toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate" |
|
]] |
|
) |
|
|
|
if selected_columns and len(selected_columns) > 1: |
|
|
|
correlation = st.session_state.data[selected_columns].corr() |
|
|
|
|
|
fig, ax = plt.subplots(figsize=(10, 8)) |
|
sns.heatmap( |
|
correlation, |
|
annot=True, |
|
cmap='coolwarm', |
|
ax=ax, |
|
cbar=True, |
|
fmt='.2f', |
|
linewidths=0.5 |
|
) |
|
plt.title('Correlation Matrix') |
|
st.pyplot(fig) |
|
else: |
|
st.warning("Please select at least two columns for correlation matrix.") |
|
else: |
|
st.warning("Not enough numerical columns for correlation analysis.") |
|
|
|
|
|
|
|
|
|
if st.session_state.predictions is not None: |
|
st.subheader("Prediction Visualization") |
|
|
|
|
|
st.write("Distribution of Predictions:") |
|
|
|
|
|
fig, ax = plt.subplots(figsize=(10, 6)) |
|
|
|
if 'toxicity_score' in st.session_state.predictions.columns: |
|
sns.histplot(st.session_state.predictions['toxicity_score'], bins=20, kde=True, ax=ax) |
|
plt.xlabel('Toxicity Score') |
|
plt.ylabel('Frequency') |
|
plt.title('Distribution of Toxicity Scores') |
|
st.pyplot(fig) |
|
|
|
|
|
st.write("Toxicity Threshold Analysis:") |
|
|
|
threshold = st.slider("Toxicity Threshold", 0.0, 1.0, 0.5, 0.05) |
|
|
|
|
|
st.session_state.predictions['is_toxic_at_threshold'] = st.session_state.predictions[ |
|
'toxicity_score'] > threshold |
|
|
|
toxic_at_threshold = st.session_state.predictions['is_toxic_at_threshold'].sum() |
|
total_predictions = len(st.session_state.predictions) |
|
|
|
st.write(f"Threshold: {threshold}") |
|
st.write( |
|
f"Toxic comments: {toxic_at_threshold} ({toxic_at_threshold / total_predictions * 100:.2f}%)") |
|
st.write( |
|
f"Non-toxic comments: {total_predictions - toxic_at_threshold} ({(total_predictions - toxic_at_threshold) / total_predictions * 100:.2f}%)") |
|
else: |
|
st.warning("Toxicity scores not found in predictions.") |
|
|
|
else: |
|
st.info("Please upload and preprocess data for visualization.") |
|
|
|
|
|
# Script entry point: build the app UI only when this file is executed
# directly (module name is "__main__"); importing the module defines the
# functions above without rendering anything.
if __name__ == "__main__":
    main()