ChiragPatankar committed on
Commit
c48986c
·
verified ·
1 Parent(s): 6f8594d

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +1324 -0
app.py ADDED
@@ -0,0 +1,1324 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ import numpy as np
4
+ import re
5
+ import matplotlib.pyplot as plt
6
+ import seaborn as sns
7
+ import nltk
8
+ from nltk.corpus import stopwords
9
+ from nltk.stem.snowball import SnowballStemmer
10
+ import pickle
11
+ from transformers import pipeline as hf_pipeline
12
+ from sklearn.utils.multiclass import type_of_target
13
+ import io
14
+ import base64
15
+ from sklearn.model_selection import train_test_split
16
+ from sklearn.feature_extraction.text import TfidfVectorizer
17
+ from sklearn.pipeline import Pipeline
18
+ from sklearn.multiclass import OneVsRestClassifier
19
+ from sklearn.linear_model import LogisticRegression
20
+ from sklearn.naive_bayes import MultinomialNB
21
+ from sklearn.metrics import roc_auc_score, accuracy_score, classification_report
22
+ from textblob import TextBlob
23
+ import warnings
24
+ from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix
25
+ from sklearn.metrics import roc_curve
26
+
27
# Suppress all Python warnings process-wide.
warnings.filterwarnings('ignore')

# Look up the NLTK stop-word corpus; download it only when it is missing.
try:
    nltk.data.find('corpora/stopwords')
except LookupError:
    nltk.download('stopwords')

# Module-level text-preprocessing resources: the English stop-word set and
# the English Snowball stemmer.
stop_words_set = set(stopwords.words('english'))
stemmer = SnowballStemmer('english')
38
+
39
+
40
+ # Text preprocessing functions
41
def remove_stopwords(text):
    """Return *text* with English stop words removed.

    The input is coerced to ``str`` and split on whitespace. Each token is
    compared, lower-cased, against the module-level ``stop_words_set``;
    tokens that are not in the set are kept (with their original casing) and
    re-joined with single spaces.
    """
    kept_tokens = []
    for token in str(text).split():
        if token.lower() in stop_words_set:
            continue
        kept_tokens.append(token)
    return " ".join(kept_tokens)
43
+
44
+
45
def train_lightweight_model(data, text_column, label_column):
    """
    Train a lightweight model for toxicity detection.

    The text column is run through ``preprocess_text`` and fed to a
    TF-IDF (unigrams + bigrams, 5000 features) + logistic-regression
    pipeline.

    Args:
        data: DataFrame containing the data. It is not modified.
        text_column: Column name for text data
        label_column: Column name for label data

    Returns:
        The fitted scikit-learn ``Pipeline``. The TF-IDF vectorizer is the
        pipeline's first step; it is not returned separately.
    """
    # Keep the preprocessed text in a local Series instead of writing a
    # 'processed_text' column into the caller's DataFrame: the caller may
    # pass 'processed_text' itself as text_column, in which case writing
    # back would silently overwrite its stored data.
    processed_text = data[text_column].apply(preprocess_text)

    # Create a pipeline with TF-IDF and Logistic Regression
    model = Pipeline([
        ('tfidf', TfidfVectorizer(max_features=5000, ngram_range=(1, 2))),
        ('clf', LogisticRegression(random_state=42, max_iter=1000)),
    ])

    # Train the model
    model.fit(processed_text, data[label_column])

    return model
74
+
75
+
76
def load_bert_model():
    """
    Create the Hugging Face ``sentiment-analysis`` pipeline.

    The outcome is reported in the Streamlit UI: a success message when the
    pipeline was created, an error message carrying the exception text
    otherwise.

    Returns:
        The pipeline object, or ``None`` when anything in the loading step
        raised.
    """
    try:
        analyzer = hf_pipeline("sentiment-analysis")
        st.success("BERT model loaded successfully!")
    except Exception as e:
        st.error(f"Error loading BERT model: {e}")
        analyzer = None
    return analyzer
91
+
92
+
93
# Ordered (compiled pattern, replacement) rules used by clean_text to expand
# contractions. Order matters: a specific form must come before the generic
# suffix that would otherwise consume it ("can't" before "n't", and "'scuse"
# before "'s" -- with "'s" first the "'scuse" rule could never match).
_CONTRACTION_RULES = tuple(
    (re.compile(pattern), replacement)
    for pattern, replacement in (
        (r"what's", "what is "),
        (r"\'scuse", " excuse "),
        (r"\'s", " "),
        (r"\'ve", " have "),
        (r"can't", "can not "),
        (r"n't", " not "),
        (r"i'm", "i am "),
        (r"\'re", " are "),
        (r"\'d", " would "),
        (r"\'ll", " will "),
    )
)


def clean_text(text):
    """Lower-case *text*, expand common contractions and strip punctuation.

    Args:
        text: Any value; it is coerced with ``str()`` first.

    Returns:
        The cleaned string: lower-cased, contractions expanded, every
        non-word character replaced by a space, runs of whitespace collapsed
        to one space, and leading/trailing whitespace removed.
    """
    text = str(text).lower()
    for pattern, replacement in _CONTRACTION_RULES:
        text = pattern.sub(replacement, text)
    text = re.sub(r'\W', ' ', text)  # Remove non-word characters
    return re.sub(r'\s+', ' ', text).strip()  # Remove extra spaces
108
+
109
+
110
def stemming(sentence):
    """Stem every whitespace-separated token of *sentence*.

    The input is coerced to ``str``; each token is passed through the
    module-level ``stemmer`` and the results are joined with single spaces.
    """
    stems = map(stemmer.stem, str(sentence).split())
    return " ".join(stems)
112
+
113
+
114
def preprocess_text(text):
    """Apply the preprocessing chain to *text* and return the result.

    The steps run in this order: ``remove_stopwords``, then ``clean_text``,
    then ``stemming``.
    """
    for step in (remove_stopwords, clean_text, stemming):
        text = step(text)
    return text
119
+
120
+
121
+ # Function to get sentiment
122
def get_sentiment(text):
    """Classify *text* by its TextBlob polarity.

    Returns:
        A ``(label, polarity)`` tuple where *label* is ``"Positive"`` for a
        polarity above zero, ``"Negative"`` for a polarity below zero and
        ``"Neutral"`` otherwise.
    """
    polarity = TextBlob(text).sentiment.polarity
    if polarity > 0:
        label = "Positive"
    elif polarity < 0:
        label = "Negative"
    else:
        label = "Neutral"
    return label, polarity
130
+
131
+
132
+ # Function to moderate text based on toxicity
133
# Words that are masked when a comment is moderated rather than deleted.
_TOXIC_WORDS = frozenset({
    "stupid", "idiot", "dumb", "hate", "sucks", "terrible",
    "awful", "garbage", "trash", "pathetic", "ridiculous",
})
_DELETED_NOTICE = "*** COMMENT DELETED DUE TO HIGH TOXICITY ***"
# Anything that is neither a word character nor whitespace.
_PUNCTUATION_RE = re.compile(r'[^\w\s]')


def _censor_toxic_words(text):
    """Replace each whitespace-separated word of *text* found in
    ``_TOXIC_WORDS`` (compared lower-cased, punctuation stripped) with
    ``[inappropriate]``; the result is re-joined with single spaces."""
    return " ".join(
        "[inappropriate]"
        if _PUNCTUATION_RE.sub('', word.lower()) in _TOXIC_WORDS
        else word
        for word in text.split()
    )


def moderate_text(text, predictions, threshold_moderate=0.5, threshold_delete=0.8):
    """Decide what to do with *text* given its toxicity probabilities.

    Args:
        text: The original comment.
        predictions: Sequence of probabilities. A sequence of exactly two
            values is treated as binary output and only index 1 (the
            'toxic' probability) is checked; any other length is treated as
            one probability per label and all of them are checked.
        threshold_moderate: Minimum probability that triggers word masking.
        threshold_delete: Minimum probability that triggers deletion.

    Returns:
        A ``(moderated_text, action)`` tuple where *action* is ``"delete"``
        (text replaced by a deletion notice), ``"moderate"`` (listed toxic
        words replaced by ``[inappropriate]``) or ``"keep"`` (text returned
        unchanged).
    """
    # If binary classification, only check the 'toxic' probability (index 1)
    if len(predictions) == 2:
        scores = [predictions[1]]
    else:
        # Multi-label: check all classes
        scores = predictions

    if any(score >= threshold_delete for score in scores):
        return _DELETED_NOTICE, "delete"
    if any(score >= threshold_moderate for score in scores):
        return _censor_toxic_words(text), "moderate"
    return text, "keep"
187
+
188
+
189
+ # Function to train and save the model
190
def train_model(X_train, y_train, model_type='logistic_regression'):
    """Fit a TF-IDF + one-vs-rest classifier on the six toxicity labels.

    Args:
        X_train: Iterable of training texts.
        y_train: DataFrame of label columns. Any of the six expected label
            columns that is missing is treated as all zeros. The caller's
            DataFrame is not modified.
        model_type: ``'logistic_regression'`` selects logistic regression;
            any other value selects multinomial Naive Bayes.

    Returns:
        The fitted scikit-learn ``Pipeline``.
    """
    st.write("Training model...")

    label_columns = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']

    # reindex returns a NEW frame holding exactly these columns in this
    # order, creating any missing one filled with 0. The original assigned
    # the missing columns in place, which mutated the caller's object.
    y_train = y_train.reindex(columns=label_columns, fill_value=0)

    if model_type == 'logistic_regression':
        base_estimator = LogisticRegression(max_iter=1000)
    else:  # Naive Bayes
        base_estimator = MultinomialNB()

    pipeline = Pipeline([
        ('tfidf', TfidfVectorizer(stop_words='english', max_features=50000)),
        ('clf', OneVsRestClassifier(base_estimator, n_jobs=-1)),
    ])

    pipeline.fit(X_train, y_train)

    return pipeline
218
+
219
+
220
+
221
+
222
def _positive_class_scores(pred_probs):
    """Return column 1 of a 2-D probability array with more than one column;
    otherwise return *pred_probs* unchanged."""
    if pred_probs.ndim > 1 and pred_probs.shape[1] > 1:
        return pred_probs[:, 1]
    return pred_probs


def evaluate_model(pipeline, X_test, y_test):
    """
    Evaluates the given trained pipeline on test data.

    Args:
        pipeline: Fitted estimator exposing ``predict`` and ``predict_proba``.
        X_test: Test inputs passed straight to the pipeline.
        y_test: True labels; must expose ``.shape`` (1-D for binary targets,
            2-D with several columns for multi-label targets). Both pandas
            and NumPy containers are accepted.

    Returns:
        accuracy: Accuracy score
        roc_auc: ROC AUC score. For multi-label targets this is the mean over
            the labels whose AUC could be computed; 0.0 when none could, or
            when the computation failed.
        predictions: Predicted labels
        pred_probs: Predicted probabilities
        fpr: False Positive Rate array (for ROC curve), ``None`` for
            multi-label targets or on failure
        tpr: True Positive Rate array (for ROC curve), ``None`` for
            multi-label targets or on failure
    """

    # Get predictions and prediction probabilities
    predictions = pipeline.predict(X_test)
    pred_probs = pipeline.predict_proba(X_test)

    if isinstance(pred_probs, list) and len(pred_probs) == 1:
        pred_probs = pred_probs[0]  # Handle list with 1 array

    # Ensure predictions format matches y_test
    y_type = type_of_target(y_test)
    pred_type = type_of_target(predictions)

    if y_type != pred_type:
        if y_type == "multilabel-indicator" and pred_type == "binary":
            # Expand binary predictions to multilabel shape
            predictions = np.array([[pred] * y_test.shape[1] for pred in predictions])
        elif y_type == "binary" and pred_type == "multilabel-indicator":
            # Collapse multilabel predictions to binary
            predictions = predictions[:, 0]

    # Calculate accuracy
    accuracy = accuracy_score(y_test, predictions)

    # Calculate ROC AUC
    try:
        if len(y_test.shape) > 1 and y_test.shape[1] > 1:
            # Multi-label case. Index columns through a NumPy view so that
            # ndarray labels work too; `.iloc` exists only on pandas objects
            # and its failure used to be swallowed, silently yielding 0.0.
            y_true_matrix = np.asarray(y_test)
            label_aucs = []
            for i in range(y_true_matrix.shape[1]):
                try:
                    label_aucs.append(roc_auc_score(y_true_matrix[:, i], pred_probs[:, i]))
                except (ValueError, IndexError):
                    # ValueError: e.g. only one class present for this label.
                    # IndexError: no probability column for this label.
                    continue
            roc_auc = sum(label_aucs) / len(label_aucs) if label_aucs else 0.0
        else:
            # Binary case
            roc_auc = roc_auc_score(y_test, _positive_class_scores(pred_probs))
    except Exception as e:
        print(f"Warning: Could not compute ROC AUC - {e}")
        roc_auc = 0.0

    # Calculate FPR, TPR for ROC curve (only for binary classification)
    try:
        if len(y_test.shape) == 1 or (len(y_test.shape) > 1 and y_test.shape[1] == 1):
            fpr, tpr, _ = roc_curve(y_test, _positive_class_scores(pred_probs))
        else:
            fpr, tpr = None, None  # ROC Curve not available for multilabel
    except Exception as e:
        print(f"Warning: Could not compute ROC Curve - {e}")
        fpr, tpr = None, None

    return accuracy, roc_auc, predictions, pred_probs, fpr, tpr
290
+
291
+
292
+
293
+ # Function to create a download link for the trained model
294
def get_model_download_link(model, filename):
    """Build an HTML download link that embeds the pickled *model*.

    Args:
        model: Any picklable object.
        filename: Name suggested to the browser for the downloaded file.

    Returns:
        An ``<a>`` element (as a string) whose ``href`` is a base64 data URI
        holding ``pickle.dumps(model)``.
    """
    import html  # local import: only needed here, for attribute escaping

    # NOTE(review): the payload is a pickle; whoever loads the downloaded
    # file executes whatever it contains, so it must only be unpickled by
    # trusted code.
    b64 = base64.b64encode(pickle.dumps(model)).decode()
    # Escape the filename so quotes or angle brackets in it cannot break out
    # of the `download` attribute.
    safe_filename = html.escape(str(filename), quote=True)
    return f'<a href="data:file/pickle;base64,{b64}" download="{safe_filename}">Download Trained Model</a>'
299
+
300
+
301
+ # Function to plot toxicity distribution
302
def plot_toxicity_distribution(df, toxicity_columns):
    """Draw a bar chart of the per-category sums of the toxicity columns.

    Args:
        df: DataFrame holding the toxicity columns.
        toxicity_columns: Names of the columns to sum and plot.

    Returns:
        The matplotlib ``Figure`` containing the chart.
    """
    fig, ax = plt.subplots(figsize=(10, 6))

    counts = df[toxicity_columns].sum()
    sns.barplot(x=counts.index, y=counts.values, alpha=0.8, palette='viridis', ax=ax)

    # Label through the Axes object we own rather than through pyplot's
    # implicit "current axes" global state.
    ax.set_title('Toxicity Distribution')
    ax.set_ylabel('Count')
    ax.set_xlabel('Toxicity Category')
    ax.tick_params(axis='x', labelrotation=45)

    return fig
314
+
315
+
316
+ # Function to provide sample data format
317
def show_sample_data_format():
    """Render an example of the expected CSV layout in the Streamlit UI.

    Shows a four-row sample DataFrame (one text column plus six binary
    toxicity columns), a base64 data-URI link to download it as CSV, and an
    info box describing the required columns.
    """
    st.subheader("Sample Data Format")

    # Sample comments and their per-category labels
    comments = [
        "This is a normal comment.",
        "This is a toxic comment you idiot!",
        "You're all worthless and should die.",
        "I respectfully disagree with your point."
    ]
    labels = {
        'toxic': [0, 1, 1, 0],
        'severe_toxic': [0, 0, 1, 0],
        'obscene': [0, 1, 0, 0],
        'threat': [0, 0, 1, 0],
        'insult': [0, 1, 1, 0],
        'identity_hate': [0, 0, 0, 0]
    }
    sample_df = pd.DataFrame({'comment_text': comments, **labels})
    st.dataframe(sample_df)

    # Offer the sample as a downloadable CSV embedded in a data URI
    encoded_csv = base64.b64encode(sample_df.to_csv(index=False).encode()).decode()
    st.markdown(
        f'<a href="data:file/csv;base64,{encoded_csv}" download="sample_toxic_data.csv">Download Sample CSV</a>',
        unsafe_allow_html=True,
    )

    st.info("""
    Your CSV file should contain:
    1. A column with comment text
    2. One or more columns with binary values (0 or 1) for each toxicity category
    """)
350
+
351
+
352
+ # Function to validate dataset
353
def validate_dataset(df, comment_column, toxicity_columns):
    """Check that *df* has the expected layout for toxicity training data.

    Args:
        df: DataFrame to validate.
        comment_column: Name of the column expected to hold comment text.
        toxicity_columns: Names of the expected binary label columns.

    Returns:
        A list of human-readable issue strings; an empty list means no
        problem was found.
    """
    issues = []

    # Check if comment column exists
    has_comment_column = comment_column in df.columns
    if not has_comment_column:
        issues.append(f"Comment column '{comment_column}' not found in the dataset")

    # Check if toxicity columns exist
    missing_columns = [col for col in toxicity_columns if col not in df.columns]
    if missing_columns:
        issues.append(f"Missing toxicity columns: {', '.join(missing_columns)}")

    # Check if values in toxicity columns are valid (0 or 1)
    for col in toxicity_columns:
        if col not in df.columns:
            continue
        if not pd.api.types.is_numeric_dtype(df[col]):
            issues.append(f"Column '{col}' contains non-numeric values")
            continue
        # Missing values are ignored here; only present values must be 0/1.
        values = df[col].dropna()
        if ((values != 0) & (values != 1)).any():
            issues.append(f"Column '{col}' contains values other than 0 and 1")

    # Check for empty data
    if df.empty:
        issues.append("Dataset is empty")
    elif has_comment_column and df[comment_column].isna().all():
        # Guarded by has_comment_column: indexing a missing column would
        # raise KeyError instead of returning the issues collected so far.
        issues.append("Comment column contains no data")

    return issues
384
+
385
+
386
+ # Function to extract predictions from model output
387
def extract_predictions(predictions_proba, toxicity_categories):
    """
    Helper function to extract probabilities from model output,
    handling different output formats.

    Args:
        predictions_proba: Model output; either a list of arrays or a single
            NumPy array. Only the first row of any array is used.
        toxicity_categories: Sequence of category names; its length decides
            which output layout is recognised.

    Returns:
        The extracted probabilities. When the layout is not recognised, a
        warning is shown in the UI and an all-zero array with one entry per
        category is returned.
    """
    n_categories = len(toxicity_categories)

    # Debug information. The flag is read with a default so that a session
    # in which 'debug_mode' was never set does not raise AttributeError.
    if getattr(st.session_state, "debug_mode", False):
        st.write(f"Predictions type: {type(predictions_proba)}")
        st.write(
            f"Predictions shape/length: {np.shape(predictions_proba) if hasattr(predictions_proba, 'shape') else len(predictions_proba)}")

    # Case 1: List of arrays with one element per toxicity category
    if isinstance(predictions_proba, list) and len(predictions_proba) == n_categories:
        return [pred_array[:, 1][0] if pred_array.shape[1] > 1 else pred_array[0] for pred_array in predictions_proba]

    # Case 2: List with a single array (common for OneVsRestClassifier)
    elif isinstance(predictions_proba, list) and len(predictions_proba) == 1:
        pred_array = predictions_proba[0]
        # If it's a 2D array with number of columns equal to number of categories
        if len(pred_array.shape) == 2 and pred_array.shape[1] == n_categories:
            return pred_array[0]  # Return first row, which contains all probabilities
        # If it's a 2D array with 2 columns per category (common binary classifier output)
        elif len(pred_array.shape) == 2 and pred_array.shape[1] == 2:
            return np.array([pred_array[0, 1]])

    # Case 3: Direct numpy array
    elif isinstance(predictions_proba, np.ndarray):
        # If it's already the right shape
        if len(predictions_proba.shape) == 2 and predictions_proba.shape[1] == n_categories:
            return predictions_proba[0]
        # If it's a 2D array with two columns (binary classification)
        elif len(predictions_proba.shape) == 2 and predictions_proba.shape[1] == 2:
            # For binary classification, return the probability of positive class
            return np.array([predictions_proba[0, 1]])

    # NOTE(review): the original had one more fallback here that repeated the
    # positive-class probability for every category when given a one-element
    # list holding an (n, 2) array. That input is always returned by Case 1
    # or Case 2 above, so the fallback could never run and has been removed.
    # If repeating was the intended result, Case 2 is the place to change.

    # Last resort fallback
    st.warning(f"Unexpected prediction format. Creating default predictions.")
    return np.zeros(n_categories)
433
+
434
+
435
def display_classification_result(result):
    """Render one moderation result in the Streamlit UI.

    Args:
        result: Dict with the keys ``"original_text"``, ``"moderated_text"``,
            ``"action"`` (``"keep"``, ``"moderate"`` or ``"delete"``) and
            ``"toxicity_scores"`` (label -> probability), plus an optional
            ``"sentiment"`` dict holding ``"label"`` and ``"score"``.
    """
    st.subheader("Classification Result")

    # Show original and moderated text side by side
    original_col, moderated_col = st.columns(2)
    with original_col:
        st.markdown("**Original Text**")
        st.code(result["original_text"], language="text")
    with moderated_col:
        st.markdown("**Moderated Text**")
        st.code(result["moderated_text"], language="text")

    # Show action with color
    action = result["action"]
    if action == "keep":
        st.success("✅ This comment is allowed (Non-toxic).")
    elif action == "moderate":
        st.warning("⚠️ This comment is moderated (Potentially toxic).")
    elif action == "delete":
        st.error("🚫 This comment is deleted (Highly toxic).")

    # Show toxicity scores
    st.markdown("**Toxicity Scores:**")
    scores = result["toxicity_scores"]
    if scores:
        # Only lay out columns when there is something to show: asking
        # Streamlit for zero columns raises instead of rendering nothing.
        score_cols = st.columns(len(scores))
        for column, (label, score) in zip(score_cols, scores.items()):
            column.metric(label.capitalize(), f"{score:.2%}")

    # Show sentiment if available
    if "sentiment" in result:
        st.markdown("**Sentiment Analysis:**")
        st.info(f"{result['sentiment']['label']} (score: {result['sentiment']['score']:.2%})")
466
+
467
+
468
def moderate_comment(comment, model, sentiment_model=None):
    """
    Moderate a single comment using the trained model and optionally BERT sentiment analysis.

    Args:
        comment: The comment text to moderate
        model: The trained model to use for toxicity detection
        sentiment_model: Optional BERT model for sentiment analysis

    Returns:
        Dictionary containing moderation results: ``original_text``,
        ``moderated_text``, ``action``, ``toxicity_scores`` and, when a
        sentiment model was given and returned a result, ``sentiment``.

    Raises:
        ValueError: If *model* is ``None``.
    """
    # Fail with a clear message instead of an AttributeError on
    # `None.predict_proba` when no model has been trained yet.
    if model is None:
        raise ValueError("No model available. Please train a model first.")

    # Preprocess the comment
    processed_text = preprocess_text(comment)

    # Get model predictions
    predictions = model.predict_proba([processed_text])[0]

    # Get sentiment if BERT model is available
    sentiment = None
    if sentiment_model is not None:
        sentiment = sentiment_model(comment)[0]

    # Moderate the text
    moderated_text, action = moderate_text(comment, predictions)

    # Handle both binary and multi-class predictions
    if len(predictions) == 2:  # Binary classification
        toxicity_scores = {
            "toxic": float(predictions[1]),  # Probability of positive class
            "non_toxic": float(predictions[0])  # Probability of negative class
        }
    else:  # Multi-class classification
        categories = ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]
        # zip stops at the shorter sequence, so only categories that have a
        # prediction get a score.
        toxicity_scores = {
            category: float(probability)
            for category, probability in zip(categories, predictions)
        }

    # Prepare result
    result = {
        "original_text": comment,
        "moderated_text": moderated_text,
        "action": action,
        "toxicity_scores": toxicity_scores
    }

    if sentiment:
        result["sentiment"] = {
            "label": sentiment["label"],
            "score": float(sentiment["score"])
        }

    return result
524
+
525
+
526
+ # --- Bias Detection Module ---
527
_GENDER_KEYWORDS = frozenset({
    "he", "she", "him", "her", "man", "woman", "boy", "girl", "male", "female",
})
_ETHNICITY_KEYWORDS = frozenset({
    "asian", "black", "white", "hispanic", "latino", "indian", "african",
    "european", "arab", "jewish", "muslim", "christian",
})
_SUBGROUP_WORD_RE = re.compile(r"[a-z]+")


def detect_subgroup(text):
    """Return the sensitive subgroups mentioned in *text*.

    Matching is done on whole words (case-insensitive), so "the" or "there"
    no longer count as the keyword "he". A trailing "s" is also tried
    stripped, so simple plurals such as "muslims" still match.

    Args:
        text: The text to scan; coerced with ``str()`` first.

    Returns:
        A list containing ``"gender"`` and/or ``"ethnicity"`` (in that
        order), or an empty list when no keyword is present.
    """
    words = set(_SUBGROUP_WORD_RE.findall(str(text).lower()))
    # Also try each word without a trailing "s" to catch simple plurals.
    words |= {word[:-1] for word in words if word.endswith("s")}

    subgroups = []
    if words & _GENDER_KEYWORDS:
        subgroups.append("gender")
    if words & _ETHNICITY_KEYWORDS:
        subgroups.append("ethnicity")
    return subgroups
539
+
540
def bias_report(X, y_true, y_pred, text_column_name):
    """Summarise how often comments mentioning a sensitive subgroup are
    predicted toxic.

    Args:
        X: DataFrame with the text column.
        y_true: True labels (accepted for interface symmetry; not used).
        y_pred: Predicted labels, row-aligned with *X* by position. 1-D for
            a single label, 2-D for multi-label output (a row counts as
            toxic when any of its labels is positive).
        text_column_name: Name of the text column in *X*.

    Returns:
        A markdown bullet list with one line per detected subgroup, or a
        fixed message when no subgroup keyword occurs in any row.
    """
    predictions = np.asarray(y_pred)

    # subgroup -> [rows flagged toxic, rows mentioning the subgroup]
    tallies = {}
    # Walk rows by POSITION. The DataFrame index label (what iterrows yields)
    # is arbitrary after a shuffled split, so using it to index the
    # prediction array picks the wrong row or runs out of bounds.
    for position, text in enumerate(X[text_column_name]):
        subgroups = detect_subgroup(text)
        if not subgroups:
            continue
        if predictions.ndim > 1:
            is_toxic = int(predictions[position].sum() > 0)
        else:
            is_toxic = int(predictions[position] > 0)
        for subgroup in subgroups:
            tally = tallies.setdefault(subgroup, [0, 0])
            tally[0] += is_toxic
            tally[1] += 1

    if not tallies:
        return "No sensitive subgroups detected in the evaluation set."

    return "".join(
        f"- **{subgroup.capitalize()}**: {toxic}/{total} ({toxic / total:.1%}) flagged as toxic\n"
        for subgroup, (toxic, total) in tallies.items()
    )
564
+
565
+
566
+ # Streamlit app
567
+ def main():
568
+ st.set_page_config(
569
+ page_title="Toxic Comment Classifier",
570
+ page_icon="🧊",
571
+ layout="wide",
572
+ initial_sidebar_state="expanded",
573
+ )
574
+
575
+ col1, col2 = st.columns([1, 4]) # Adjust the ratio as needed
576
+
577
+ with col1:
578
+ st.image("logo.jpeg", width=100) # Smaller width fits better
579
+
580
+ with col2:
581
+ st.title("Toxic Comment Classifier")
582
+
583
+
584
+ # Initialize session state variables if they don't exist
585
+ if 'data' not in st.session_state:
586
+ st.session_state.data = None
587
+ if 'model' not in st.session_state:
588
+ st.session_state.model = None
589
+ if 'vectorizer' not in st.session_state:
590
+ st.session_state.vectorizer = None
591
+ if 'predictions' not in st.session_state:
592
+ st.session_state.predictions = None
593
+ if 'lightweight_model' not in st.session_state:
594
+ st.session_state.lightweight_model = None
595
+ if 'bert_model' not in st.session_state:
596
+ st.session_state.bert_model = None
597
+
598
+ # Create sidebar navigation
599
+ st.sidebar.title("Navigation")
600
+ page = st.sidebar.radio(
601
+ "Select a page",
602
+ ["Home", "Data Preprocessing", "Model Training", "Model Evaluation", "Prediction", "Visualization"]
603
+ )
604
+
605
+ # Home page
606
+ if page == "Home":
607
+
608
+ st.header("Home")
609
+ st.write("""
610
+ Welcome to the Toxic Comment Classifier application. This tool helps you to:
611
+ 1. Upload and preprocess data
612
+ 2. Train a machine learning model to detect toxic comments
613
+ 3. Evaluate model performance
614
+ 4. Make predictions on new data
615
+ 5. Visualize results
616
+
617
+ Please use the sidebar navigation to get started.
618
+ """)
619
+
620
+ # Add option to load BERT sentiment model
621
+ if st.sidebar.checkbox("Use BERT for Sentiment Analysis"):
622
+ st.subheader("BERT-Based Sentiment Analysis")
623
+ st.write("This option uses a pre-trained BERT model for advanced sentiment analysis.")
624
+
625
+ if st.button("Load BERT Model"):
626
+ with st.spinner("Loading BERT model..."):
627
+ st.session_state.bert_model = load_bert_model()
628
+ st.write("DEBUG: bert_model in session_state after loading:", st.session_state.bert_model)
629
+
630
+ # Sample data section
631
+ st.subheader("Sample Data Format")
632
+ show_sample_data_format()
633
+
634
+ # Single comment moderation
635
+ st.subheader("Try Comment Moderation")
636
+ comment = st.text_area("Enter a comment to moderate:")
637
+
638
+ col1, col2 = st.columns(2)
639
+ with col1:
640
+ use_default_model = st.checkbox("Use built-in model for demo", value=True)
641
+
642
+ with col2:
643
+ use_bert = st.checkbox("Use BERT model for sentiment (if loaded)", value=False)
644
+
645
+ if st.button("Moderate Comment"):
646
+ if comment:
647
+ with st.spinner("Analyzing comment..."):
648
+ st.write("DEBUG: bert_model in session_state before use:", st.session_state.bert_model)
649
+ sentiment_model = st.session_state.bert_model if use_bert and st.session_state.bert_model is not None else None
650
+
651
+ if use_default_model or st.session_state.model or st.session_state.lightweight_model:
652
+ model_to_use = None
653
+ if st.session_state.model:
654
+ model_to_use = st.session_state.model
655
+ elif st.session_state.lightweight_model:
656
+ model_to_use = st.session_state.lightweight_model
657
+
658
+ result = moderate_comment(comment, model_to_use, sentiment_model)
659
+
660
+ display_classification_result(result)
661
+ else:
662
+ st.error("No model available. Please train a model first or enable the demo model.")
663
+ else:
664
+ st.warning("Please enter a comment to moderate.")
665
+
666
+ # Data Preprocessing page
667
+ elif page == "Data Preprocessing":
668
+ st.header("Data Preprocessing")
669
+
670
+ # File upload
671
+ st.subheader("Upload Dataset")
672
+ uploaded_file = st.file_uploader("Choose a CSV file", type="csv")
673
+
674
+ if uploaded_file is not None:
675
+ try:
676
+ # Load data
677
+ data = pd.read_csv(uploaded_file)
678
+
679
+ # Display raw data
680
+ st.subheader("Raw Data")
681
+ st.write(data.head())
682
+
683
+ # Validate the dataset
684
+ validation_result = validate_dataset(data, 'comment_text', ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate'])
685
+
686
+ if not validation_result: # Empty list means no issues
687
+ st.success("Dataset is valid!")
688
+
689
+ # Store data in session state
690
+ st.session_state.data = data
691
+
692
+ # Data cleaning
693
+ st.subheader("Data Cleaning")
694
+ st.write("Select columns to include in the analysis:")
695
+
696
+ # Get all columns
697
+ all_columns = data.columns.tolist()
698
+ default_selected = ["comment_text", "toxic", "severe_toxic", "obscene", "threat", "insult",
699
+ "identity_hate"]
700
+ default_selected = [col for col in default_selected if col in all_columns]
701
+
702
+ selected_columns = st.multiselect(
703
+ "Select columns",
704
+ options=all_columns,
705
+ default=default_selected
706
+ )
707
+
708
+ if selected_columns:
709
+ # Filter data by selected columns
710
+ filtered_data = data[selected_columns]
711
+
712
+ # Display cleaned data
713
+ st.subheader("Filtered Data")
714
+ st.write(filtered_data.head())
715
+
716
+ # Display data statistics
717
+ st.subheader("Data Statistics")
718
+ st.write(filtered_data.describe())
719
+
720
+ # Check for missing values
721
+ st.subheader("Missing Values")
722
+ missing_values = filtered_data.isnull().sum()
723
+ st.write(missing_values)
724
+
725
+ # Handle missing values if any
726
+ if missing_values.sum() > 0:
727
+ st.warning("There are missing values in the dataset.")
728
+
729
+ if st.button("Handle Missing Values"):
730
+ # Fill missing text with empty string
731
+ text_columns = [col for col in selected_columns if filtered_data[col].dtype == 'object']
732
+ for col in text_columns:
733
+ filtered_data[col] = filtered_data[col].fillna("")
734
+
735
+ # Fill missing numerical values with 0
736
+ numerical_columns = [col for col in selected_columns if
737
+ filtered_data[col].dtype != 'object']
738
+ for col in numerical_columns:
739
+ filtered_data[col] = filtered_data[col].fillna(0)
740
+
741
+ st.success("Missing values handled!")
742
+ st.write(filtered_data.isnull().sum())
743
+
744
+ # Text preprocessing
745
+ st.subheader("Text Preprocessing")
746
+
747
+ # Select text column
748
+ text_columns = [col for col in selected_columns if filtered_data[col].dtype == 'object']
749
+
750
+ if text_columns:
751
+ text_column = st.selectbox("Select text column for preprocessing", text_columns)
752
+
753
+ # Show sample of original text
754
+ st.write("Sample original text:")
755
+ sample_texts = filtered_data[text_column].head(3).tolist()
756
+ for i, text in enumerate(sample_texts):
757
+ st.text(f"Text {i + 1}: {text[:200]}...")
758
+
759
+ # Preprocess text
760
+ if st.button("Preprocess Text"):
761
+ with st.spinner("Preprocessing text..."):
762
+ filtered_data['processed_text'] = filtered_data[text_column].apply(preprocess_text)
763
+
764
+ # Show sample of preprocessed text
765
+ st.write("Sample preprocessed text:")
766
+ sample_preprocessed = filtered_data['processed_text'].head(3).tolist()
767
+ for i, text in enumerate(sample_preprocessed):
768
+ st.text(f"Processed Text {i + 1}: {text[:200]}...")
769
+
770
+ # Store preprocessed data
771
+ st.session_state.data = filtered_data
772
+ st.success("Text preprocessing completed!")
773
+ else:
774
+ st.warning("No text columns found in the selected columns.")
775
+ else:
776
+ st.warning("Please select at least one column.")
777
+ else:
778
+ st.error(f"Dataset validation failed: {validation_result['reason']}")
779
+ st.warning("Please upload a valid dataset.")
780
+
781
+ except Exception as e:
782
+ st.error(f"Error loading data: {e}")
783
+ st.warning("Please upload a valid CSV file.")
784
+ else:
785
+ st.info("Please upload a CSV file to begin preprocessing.")
786
+
787
+ # Model Training page
788
+ elif page == "Model Training":
789
+ st.header("Model Training")
790
+
791
+ # Check if data is available
792
+ if st.session_state.data is not None:
793
+ # Display data info
794
+ st.subheader("Dataset Information")
795
+ st.write(f"Number of samples: {len(st.session_state.data)}")
796
+
797
+ if 'processed_text' in st.session_state.data.columns:
798
+ st.write("Text preprocessing: Done")
799
+ else:
800
+ st.warning("Text preprocessing is not done. Please preprocess the data first.")
801
+
802
+ # Model training options
803
+ st.subheader("Training Options")
804
+
805
+ # Select target column
806
+ numerical_columns = [col for col in st.session_state.data.columns if
807
+ st.session_state.data[col].dtype != 'object']
808
+
809
+ if numerical_columns:
810
+ target_column = st.selectbox("Select target column", numerical_columns)
811
+
812
+ # Set training parameters
813
+ st.write("Training Parameters:")
814
+ col1, col2 = st.columns(2)
815
+
816
+ with col1:
817
+ test_size = st.slider("Test Size", 0.1, 0.5, 0.2, 0.05)
818
+
819
+ with col2:
820
+ random_state = st.number_input("Random State", 0, 100, 42, 1)
821
+
822
+ # Model selection
823
+ model_type = st.radio(
824
+ "Select model type",
825
+ ["Standard Model", "Lightweight Model"]
826
+ )
827
+
828
+ # Train model button
829
+ if st.button("Train Model"):
830
+ with st.spinner("Training model..."):
831
+ # Check if processed text is available
832
+ if 'processed_text' in st.session_state.data.columns:
833
+ try:
834
+ if model_type == "Standard Model":
835
+ # Train standard model
836
+ X_train = st.session_state.data['processed_text']
837
+ y_train = st.session_state.data[[target_column]]
838
+ model = train_model(X_train, y_train, 'logistic_regression')
839
+
840
+ # Store model in session state
841
+ st.session_state.model = model
842
+ st.session_state.vectorizer = None # Vectorizer is part of the pipeline
843
+
844
+ st.success("Model training completed!")
845
+ else:
846
+ # Train lightweight model
847
+ lightweight_model = train_lightweight_model(
848
+ st.session_state.data,
849
+ 'processed_text',
850
+ target_column
851
+ )
852
+
853
+ # Store lightweight model in session state
854
+ st.session_state.lightweight_model = lightweight_model
855
+
856
+ st.success("Lightweight model training completed!")
857
+
858
+ except Exception as e:
859
+ st.error(f"Error training model: {e}")
860
+ else:
861
+ st.error("Processed text not found. Please preprocess the data first.")
862
+ else:
863
+ st.warning("No numerical columns found in the dataset. Please ensure you have target columns.")
864
+ else:
865
+ st.info("Please upload and preprocess data before training a model.")
866
+
867
+ # Model Evaluation page
868
+ elif page == "Model Evaluation":
869
+ st.header("Model Evaluation")
870
+
871
+ # Check if model is available
872
+ model_available = st.session_state.model is not None or st.session_state.lightweight_model is not None
873
+
874
+ if model_available:
875
+ # Display model info
876
+ st.subheader("Model Information")
877
+ if st.session_state.model is not None:
878
+ st.write("Standard model is trained and ready.")
879
+ if st.session_state.lightweight_model is not None:
880
+ st.write("Lightweight model is trained and ready.")
881
+
882
+ # Select model to evaluate
883
+ model_choice = None
884
+ if st.session_state.model is not None and st.session_state.lightweight_model is not None:
885
+ model_choice = st.radio(
886
+ "Select model to evaluate",
887
+ ["Standard Model", "Lightweight Model"]
888
+ )
889
+
890
+ # Model evaluation
891
+ st.subheader("Evaluation Options")
892
+
893
+ # Set evaluation parameters
894
+ st.write("Evaluation Parameters:")
895
+ col1, col2 = st.columns(2)
896
+
897
+ with col1:
898
+ test_size = st.slider("Test Size (Evaluation)", 0.1, 0.5, 0.2, 0.05)
899
+
900
+ with col2:
901
+ random_state = st.number_input("Random State (Evaluation)", 0, 100, 42, 1)
902
+
903
+ # Target column selection
904
+ if st.session_state.data is not None:
905
+ numerical_columns = [col for col in st.session_state.data.columns if
906
+ st.session_state.data[col].dtype != 'object']
907
+
908
+ if numerical_columns:
909
+ target_column = st.selectbox("Select target column for evaluation", numerical_columns)
910
+
911
+ # Evaluate model button
912
+ if st.button("Evaluate Model"):
913
+ with st.spinner("Evaluating model..."):
914
+ try:
915
+ # Determine which model to evaluate
916
+ model_to_evaluate = None
917
+ if model_choice == "Lightweight Model" or (
918
+ model_choice is None and st.session_state.model is None):
919
+ model_to_evaluate = st.session_state.lightweight_model
920
+ else:
921
+ model_to_evaluate = st.session_state.model
922
+
923
+ # 1️⃣ Evaluate model
924
+ X_test = st.session_state.data['processed_text']
925
+ y_test = st.session_state.data[[target_column]]
926
+
927
+ accuracy, roc_auc, predictions, pred_probs, fpr, tpr = evaluate_model(model_to_evaluate,
928
+ X_test, y_test)
929
+
930
+ # 2️⃣ Calculate additional metrics
931
+ precision = precision_score(y_test, predictions, average='weighted', zero_division=0)
932
+ recall = recall_score(y_test, predictions, average='weighted', zero_division=0)
933
+ f1 = f1_score(y_test, predictions, average='weighted', zero_division=0)
934
+ conf_matrix = confusion_matrix(y_test, predictions)
935
+ classification_rep = classification_report(y_test, predictions, zero_division=0)
936
+
937
+ # 3️⃣ Display evaluation metrics
938
+ st.subheader("Evaluation Results")
939
+ metrics_df = pd.DataFrame({
940
+ 'Metric': ['Accuracy', 'Precision', 'Recall', 'F1 Score', 'ROC AUC'],
941
+ 'Value': [accuracy, precision, recall, f1, roc_auc]
942
+ })
943
+ st.table(metrics_df)
944
+
945
+ # 4️⃣ Confusion Matrix
946
+ st.subheader("Confusion Matrix")
947
+ fig, ax = plt.subplots(figsize=(8, 6))
948
+ sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues', ax=ax, cbar=False,
949
+ annot_kws={"size": 16})
950
+ plt.xlabel('Predicted')
951
+ plt.ylabel('Actual')
952
+ plt.title('Confusion Matrix')
953
+ st.pyplot(fig)
954
+
955
+ # 5️⃣ ROC Curve
956
+ st.subheader("ROC Curve")
957
+ if fpr is not None and tpr is not None:
958
+ fig, ax = plt.subplots(figsize=(8, 6))
959
+ ax.plot(fpr, tpr, label=f'ROC Curve (AUC = {roc_auc:.2f})')
960
+ ax.plot([0, 1], [0, 1], 'k--')
961
+ ax.set_xlabel('False Positive Rate')
962
+ ax.set_ylabel('True Positive Rate')
963
+ ax.set_title('ROC Curve')
964
+ ax.legend(loc="lower right")
965
+ st.pyplot(fig)
966
+ else:
967
+ st.info("ROC curve is not available for multi-label classification.")
968
+
969
+ # 6️⃣ Classification Report
970
+ st.subheader("Classification Report")
971
+ st.text(classification_rep)
972
+
973
+ except Exception as e:
974
+ st.error(f"Error evaluating model: {e}")
975
+
976
+ # 6️⃣ Classification Report
977
+ st.subheader("Classification Report")
978
+ st.text(classification_rep)
979
+
980
+ # ✅ Always show bias detection report if possible
981
+ # Bias Detection Report
982
+ st.subheader("Bias Detection Report")
983
+ if 'comment_text' in st.session_state.data.columns:
984
+ bias_summary = bias_report(
985
+ st.session_state.data[["comment_text"]].reset_index(drop=True),
986
+ y_test.reset_index(drop=True),
987
+ predictions,
988
+ "comment_text" # Add the text_column_name parameter
989
+ )
990
+ st.markdown(bias_summary)
991
+ else:
992
+ st.info("No comment_text column found for bias analysis.")
993
+
994
+
995
+
996
+ except Exception as e:
997
+ st.error(f"Error evaluating model: {e}")
998
+ else:
999
+ st.warning("No numerical columns found in the dataset. Please ensure you have target columns.")
1000
+ else:
1001
+ st.warning("Dataset not available. Please upload and preprocess data first.")
1002
+
1003
+ # Model download
1004
+ st.subheader("Model Download")
1005
+
1006
+ # Create download button
1007
+ model_to_download = None
1008
+ if model_choice == "Lightweight Model" or (model_choice is None and st.session_state.model is None):
1009
+ model_to_download = st.session_state.lightweight_model
1010
+ else:
1011
+ model_to_download = st.session_state.model
1012
+
1013
+ if model_to_download is not None:
1014
+ # Determine the appropriate filename based on model type
1015
+ filename = "lightweight_model.pkl" if model_choice == "Lightweight Model" or (model_choice is None and st.session_state.model is None) else "standard_model.pkl"
1016
+ download_link = get_model_download_link(model_to_download, filename)
1017
+ st.markdown(download_link, unsafe_allow_html=True)
1018
+ else:
1019
+ st.info("Please train a model before evaluation.")
1020
+
1021
+ # Prediction page
1022
+ elif page == "Prediction":
1023
+ st.header("Prediction")
1024
+
1025
+ # Check if model is available
1026
+ model_available = st.session_state.model is not None or st.session_state.lightweight_model is not None
1027
+
1028
+ if model_available:
1029
+ # Display model info
1030
+ st.subheader("Model Information")
1031
+ if st.session_state.model is not None:
1032
+ st.write("Standard model is trained and ready.")
1033
+ if st.session_state.lightweight_model is not None:
1034
+ st.write("Lightweight model is trained and ready.")
1035
+
1036
+ # Select model to use
1037
+ model_choice = None
1038
+ if st.session_state.model is not None and st.session_state.lightweight_model is not None:
1039
+ model_choice = st.radio(
1040
+ "Select model for prediction",
1041
+ ["Standard Model", "Lightweight Model"]
1042
+ )
1043
+
1044
+ # Determine which model to use
1045
+ model_to_use = None
1046
+ if model_choice == "Lightweight Model" or (model_choice is None and st.session_state.model is None):
1047
+ model_to_use = st.session_state.lightweight_model
1048
+ else:
1049
+ model_to_use = st.session_state.model
1050
+
1051
+ # Prediction options
1052
+ st.subheader("Make Predictions")
1053
+
1054
+ prediction_type = st.radio(
1055
+ "Select prediction type",
1056
+ ["Single Comment", "Multiple Comments"]
1057
+ )
1058
+
1059
+ # Option to use BERT model
1060
+ use_bert = False
1061
+ if st.session_state.bert_model is not None:
1062
+ use_bert = st.checkbox("Include sentiment analysis with BERT")
1063
+
1064
+ # Single comment prediction
1065
+ if prediction_type == "Single Comment":
1066
+ comment = st.text_area("Enter a comment to classify:")
1067
+
1068
+ if st.button("Classify Comment"):
1069
+ if comment:
1070
+ with st.spinner("Classifying comment..."):
1071
+ st.write("DEBUG: bert_model in session_state before use:", st.session_state.bert_model)
1072
+ sentiment_model = st.session_state.bert_model if use_bert and st.session_state.bert_model is not None else None
1073
+ result = moderate_comment(comment, model_to_use, sentiment_model)
1074
+
1075
+ display_classification_result(result)
1076
+ else:
1077
+ st.warning("Please enter a comment to classify.")
1078
+
1079
+ # Multiple comments prediction
1080
+ else:
1081
+ # File upload for prediction
1082
+ uploaded_file = st.file_uploader("Upload a CSV file with comments", type="csv")
1083
+
1084
+ if uploaded_file is not None:
1085
+ try:
1086
+ # Load data
1087
+ pred_data = pd.read_csv(uploaded_file)
1088
+
1089
+ # Display data
1090
+ st.subheader("Uploaded Data")
1091
+ st.write(pred_data.head())
1092
+
1093
+ # Select text column
1094
+ text_columns = [col for col in pred_data.columns if pred_data[col].dtype == 'object']
1095
+
1096
+ if text_columns:
1097
+ text_column = st.selectbox("Select text column for prediction", text_columns)
1098
+
1099
+ # Batch prediction button
1100
+ if st.button("Run Batch Prediction"):
1101
+ with st.spinner("Classifying comments..."):
1102
+ try:
1103
+ # Preprocess text
1104
+ pred_data['processed_text'] = pred_data[text_column].apply(preprocess_text)
1105
+
1106
+ # Run prediction
1107
+ sentiment_model = st.session_state.bert_model if use_bert else None
1108
+ predictions = extract_predictions(pred_data, text_column, model_to_use,
1109
+ sentiment_model)
1110
+
1111
+ # Store predictions
1112
+ st.session_state.predictions = predictions
1113
+
1114
+ # Display results
1115
+ st.subheader("Prediction Results")
1116
+ st.write(predictions.head())
1117
+
1118
+ # Summary
1119
+ st.subheader("Summary")
1120
+ toxic_count = predictions['is_toxic'].sum()
1121
+ total_count = len(predictions)
1122
+ toxic_percentage = (toxic_count / total_count) * 100
1123
+
1124
+ st.write(f"Total comments: {total_count}")
1125
+ st.write(f"Toxic comments: {toxic_count} ({toxic_percentage:.2f}%)")
1126
+ st.write(
1127
+ f"Non-toxic comments: {total_count - toxic_count} ({100 - toxic_percentage:.2f}%)")
1128
+
1129
+ # Download predictions
1130
+ if not predictions.empty:
1131
+ csv = predictions.to_csv(index=False)
1132
+ b64 = base64.b64encode(csv.encode()).decode()
1133
+ href = f'<a href="data:file/csv;base64,{b64}" download="predictions.csv">Download Predictions CSV</a>'
1134
+ st.markdown(href, unsafe_allow_html=True)
1135
+
1136
+ except Exception as e:
1137
+ st.error(f"Error during prediction: {e}")
1138
+ else:
1139
+ st.warning("No text columns found in the uploaded file.")
1140
+
1141
+ except Exception as e:
1142
+ st.error(f"Error loading data: {e}")
1143
+ st.warning("Please upload a valid CSV file.")
1144
+ else:
1145
+ st.info("Please upload a CSV file with comments for batch prediction.")
1146
+ else:
1147
+ st.info("Please train a model before making predictions.")
1148
+
1149
+ # Visualization page
1150
+ elif page == "Visualization":
1151
+ st.header("Visualization")
1152
+
1153
+ # Check if data is available
1154
+ if st.session_state.data is not None:
1155
+ # Data visualization
1156
+ st.subheader("Data Visualization")
1157
+
1158
+ # Select visualization type
1159
+ viz_type = st.selectbox(
1160
+ "Select visualization type",
1161
+ ["Toxicity Distribution", "Comment Length Distribution", "Word Cloud", "Correlation Matrix"]
1162
+ )
1163
+
1164
+ # Toxicity Distribution
1165
+ if viz_type == "Toxicity Distribution":
1166
+ # Check if there are label columns
1167
+ label_columns = [col for col in st.session_state.data.columns if col in [
1168
+ "toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"
1169
+ ]]
1170
+
1171
+ if label_columns:
1172
+ st.write("Toxicity Distribution:")
1173
+
1174
+ # Plot toxicity distribution
1175
+ fig = plot_toxicity_distribution(st.session_state.data, label_columns)
1176
+ st.pyplot(fig)
1177
+ else:
1178
+ st.warning("No toxicity label columns found in the dataset.")
1179
+
1180
+ # Comment Length Distribution
1181
+ elif viz_type == "Comment Length Distribution":
1182
+ # Check if there is text column
1183
+ text_columns = [col for col in st.session_state.data.columns if
1184
+ st.session_state.data[col].dtype == 'object']
1185
+
1186
+ if text_columns:
1187
+ text_column = st.selectbox("Select text column", text_columns)
1188
+
1189
+ # Calculate comment lengths
1190
+ st.session_state.data['comment_length'] = st.session_state.data[text_column].apply(
1191
+ lambda x: len(str(x)))
1192
+
1193
+ # Plot comment length distribution
1194
+ fig, ax = plt.subplots(figsize=(10, 6))
1195
+ sns.histplot(st.session_state.data['comment_length'], bins=50, kde=True, ax=ax)
1196
+ plt.xlabel('Comment Length')
1197
+ plt.ylabel('Frequency')
1198
+ plt.title('Comment Length Distribution')
1199
+ st.pyplot(fig)
1200
+
1201
+ # Statistics
1202
+ st.write("Comment Length Statistics:")
1203
+ st.write(st.session_state.data['comment_length'].describe())
1204
+ else:
1205
+ st.warning("No text columns found in the dataset.")
1206
+
1207
+ # Word Cloud
1208
+ elif viz_type == "Word Cloud":
1209
+ # Check if there is processed text
1210
+ if 'processed_text' in st.session_state.data.columns:
1211
+ try:
1212
+ from wordcloud import WordCloud
1213
+
1214
+ # Create word cloud
1215
+ st.write("Word Cloud Visualization:")
1216
+
1217
+ # Combine all processed text
1218
+ all_text = ' '.join(st.session_state.data['processed_text'].tolist())
1219
+
1220
+ # Generate word cloud
1221
+ wordcloud = WordCloud(
1222
+ width=800,
1223
+ height=400,
1224
+ background_color='white',
1225
+ max_words=200,
1226
+ contour_width=3,
1227
+ contour_color='steelblue'
1228
+ ).generate(all_text)
1229
+
1230
+ # Display word cloud
1231
+ fig, ax = plt.subplots(figsize=(10, 6))
1232
+ ax.imshow(wordcloud, interpolation='bilinear')
1233
+ ax.axis('off')
1234
+ plt.tight_layout()
1235
+ st.pyplot(fig)
1236
+
1237
+ except ImportError:
1238
+ st.error("WordCloud package is not installed. Please install it to use this feature.")
1239
+ else:
1240
+ st.warning("Processed text not found. Please preprocess the data first.")
1241
+
1242
+ # Correlation Matrix
1243
+ elif viz_type == "Correlation Matrix":
1244
+ # Get numerical columns
1245
+ numerical_columns = [col for col in st.session_state.data.columns if
1246
+ st.session_state.data[col].dtype != 'object']
1247
+
1248
+ if len(numerical_columns) > 1:
1249
+ # Select columns for correlation
1250
+ selected_columns = st.multiselect(
1251
+ "Select columns for correlation matrix",
1252
+ options=numerical_columns,
1253
+ default=[col for col in numerical_columns if col in [
1254
+ "toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"
1255
+ ]]
1256
+ )
1257
+
1258
+ if selected_columns and len(selected_columns) > 1:
1259
+ # Calculate correlation
1260
+ correlation = st.session_state.data[selected_columns].corr()
1261
+
1262
+ # Plot correlation matrix
1263
+ fig, ax = plt.subplots(figsize=(10, 8))
1264
+ sns.heatmap(
1265
+ correlation,
1266
+ annot=True,
1267
+ cmap='coolwarm',
1268
+ ax=ax,
1269
+ cbar=True,
1270
+ fmt='.2f',
1271
+ linewidths=0.5
1272
+ )
1273
+ plt.title('Correlation Matrix')
1274
+ st.pyplot(fig)
1275
+ else:
1276
+ st.warning("Please select at least two columns for correlation matrix.")
1277
+ else:
1278
+ st.warning("Not enough numerical columns for correlation analysis.")
1279
+
1280
+ # Add more visualization types as needed
1281
+
1282
+ # Prediction visualization
1283
+ if st.session_state.predictions is not None:
1284
+ st.subheader("Prediction Visualization")
1285
+
1286
+ # Distribution of predictions
1287
+ st.write("Distribution of Predictions:")
1288
+
1289
+ # Plot prediction distribution
1290
+ fig, ax = plt.subplots(figsize=(10, 6))
1291
+
1292
+ if 'toxicity_score' in st.session_state.predictions.columns:
1293
+ sns.histplot(st.session_state.predictions['toxicity_score'], bins=20, kde=True, ax=ax)
1294
+ plt.xlabel('Toxicity Score')
1295
+ plt.ylabel('Frequency')
1296
+ plt.title('Distribution of Toxicity Scores')
1297
+ st.pyplot(fig)
1298
+
1299
+ # Toxicity threshold analysis
1300
+ st.write("Toxicity Threshold Analysis:")
1301
+
1302
+ threshold = st.slider("Toxicity Threshold", 0.0, 1.0, 0.5, 0.05)
1303
+
1304
+ # Calculate metrics at different thresholds
1305
+ st.session_state.predictions['is_toxic_at_threshold'] = st.session_state.predictions[
1306
+ 'toxicity_score'] > threshold
1307
+
1308
+ toxic_at_threshold = st.session_state.predictions['is_toxic_at_threshold'].sum()
1309
+ total_predictions = len(st.session_state.predictions)
1310
+
1311
+ st.write(f"Threshold: {threshold}")
1312
+ st.write(
1313
+ f"Toxic comments: {toxic_at_threshold} ({toxic_at_threshold / total_predictions * 100:.2f}%)")
1314
+ st.write(
1315
+ f"Non-toxic comments: {total_predictions - toxic_at_threshold} ({(total_predictions - toxic_at_threshold) / total_predictions * 100:.2f}%)")
1316
+ else:
1317
+ st.warning("Toxicity scores not found in predictions.")
1318
+
1319
+ else:
1320
+ st.info("Please upload and preprocess data for visualization.")
1321
+
1322
+
1323
if __name__ == "__main__":
    # Script entry point: build and run the Streamlit UI defined in main().
    main()