Spaces:
Sleeping
Sleeping
Commit
·
b2d89cf
1
Parent(s):
638e451
Initial deployment of Unsubscriber app with AI model
Browse files- README.md +22 -6
- app.py +118 -0
- final_optimized_model/added_tokens.json +3 -0
- final_optimized_model/config.json +35 -0
- final_optimized_model/model.safetensors +3 -0
- final_optimized_model/special_tokens_map.json +15 -0
- final_optimized_model/spm.model +3 -0
- final_optimized_model/tokenizer.json +0 -0
- final_optimized_model/tokenizer_config.json +58 -0
- final_optimized_model/training_info.json +21 -0
- ml_suite/__init__.py +19 -0
- ml_suite/__pycache__/__init__.cpython-312.pyc +0 -0
- ml_suite/__pycache__/config.cpython-312.pyc +0 -0
- ml_suite/__pycache__/predictor.cpython-312.pyc +0 -0
- ml_suite/__pycache__/task_utils.cpython-312.pyc +0 -0
- ml_suite/__pycache__/utils.cpython-312.pyc +0 -0
- ml_suite/advanced_predictor.py +283 -0
- ml_suite/config.py +149 -0
- ml_suite/data_preparator.py +458 -0
- ml_suite/model_trainer.py +498 -0
- ml_suite/models/base_transformer_cache/version.txt +1 -0
- ml_suite/models/fine_tuned_unsubscriber/config.json +24 -0
- ml_suite/models/fine_tuned_unsubscriber/model.safetensors +3 -0
- ml_suite/models/fine_tuned_unsubscriber/special_tokens_map.json +7 -0
- ml_suite/models/fine_tuned_unsubscriber/tokenizer.json +0 -0
- ml_suite/models/fine_tuned_unsubscriber/tokenizer_config.json +55 -0
- ml_suite/models/fine_tuned_unsubscriber/training_args.bin +3 -0
- ml_suite/models/fine_tuned_unsubscriber/vocab.txt +0 -0
- ml_suite/predictor.py +445 -0
- ml_suite/task_utils.py +332 -0
- ml_suite/utils.py +317 -0
- prepare_deployment.sh +20 -0
- requirements.txt +11 -0
README.md
CHANGED
@@ -1,12 +1,28 @@
|
|
1 |
---
|
2 |
-
title:
|
3 |
-
emoji:
|
4 |
-
colorFrom:
|
5 |
-
colorTo:
|
6 |
sdk: gradio
|
7 |
-
sdk_version:
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
---
|
11 |
|
12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
---
|
2 |
+
title: Email Unsubscribe Classifier
|
3 |
+
emoji: 📧
|
4 |
+
colorFrom: blue
|
5 |
+
colorTo: green
|
6 |
sdk: gradio
|
7 |
+
sdk_version: 4.19.0
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
---
|
11 |
|
12 |
+
# Email Unsubscribe Classifier
|
13 |
+
|
14 |
+
This AI model classifies emails to determine if they are unsubscribe confirmations or important emails that should not be automatically processed.
|
15 |
+
|
16 |
+
## Features
|
17 |
+
- Classifies emails as 'unsubscribe' or 'important'
|
18 |
+
- Provides confidence scores for predictions
|
19 |
+
- Based on fine-tuned DeBERTa-v3-small model
|
20 |
+
- Trained on 20,000 email samples
|
21 |
+
|
22 |
+
## Usage
|
23 |
+
Simply paste your email content (including subject line) into the text box and click "Submit" to get a classification.
|
24 |
+
|
25 |
+
## Model Performance
|
26 |
+
- Accuracy: 100% on test set
|
27 |
+
- F1 Score: 1.0 for both classes
|
28 |
+
- Model size: 552MB
|
app.py
ADDED
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import gradio as gr
|
4 |
+
import json
|
5 |
+
import base64
|
6 |
+
from pathlib import Path
|
7 |
+
|
8 |
+
# Add parent directory to path to import the main app modules
|
9 |
+
parent_dir = str(Path(__file__).parent.parent)
|
10 |
+
sys.path.insert(0, parent_dir)
|
11 |
+
|
12 |
+
# Import the ML components from the main app
|
13 |
+
from ml_suite.predictor import initialize_predictor, get_ai_prediction_for_email, is_predictor_ready, get_model_status
|
14 |
+
|
15 |
+
# Initialize the predictor once when the app starts
|
16 |
+
print("Initializing AI model...")
|
17 |
+
initialize_predictor()
|
18 |
+
|
19 |
+
def parse_email_content(email_text):
|
20 |
+
"""Parse email content to extract subject and body"""
|
21 |
+
lines = email_text.strip().split('\n')
|
22 |
+
subject = ""
|
23 |
+
body = ""
|
24 |
+
|
25 |
+
# Simple parsing - look for Subject: line
|
26 |
+
for i, line in enumerate(lines):
|
27 |
+
if line.lower().startswith('subject:'):
|
28 |
+
subject = line[8:].strip()
|
29 |
+
body = '\n'.join(lines[i+1:]).strip()
|
30 |
+
break
|
31 |
+
|
32 |
+
if not subject and not body:
|
33 |
+
# If no subject line found, treat entire text as body
|
34 |
+
body = email_text
|
35 |
+
|
36 |
+
return subject, body
|
37 |
+
|
38 |
+
def classify_email(email_content):
|
39 |
+
"""Classify email using the AI model"""
|
40 |
+
if not email_content.strip():
|
41 |
+
return "Please enter email content to analyze."
|
42 |
+
|
43 |
+
# Check if model is ready
|
44 |
+
if not is_predictor_ready():
|
45 |
+
status = get_model_status()
|
46 |
+
return f"Model is not ready. Status: {status}"
|
47 |
+
|
48 |
+
# Parse email content
|
49 |
+
subject, body = parse_email_content(email_content)
|
50 |
+
|
51 |
+
# Create email data structure similar to the main app
|
52 |
+
email_data = {
|
53 |
+
'snippet': body[:200], # Gmail API typically provides snippets
|
54 |
+
'subject': subject,
|
55 |
+
'body': body,
|
56 |
+
'sender': 'demo@example.com', # Placeholder
|
57 |
+
'id': 'demo_id'
|
58 |
+
}
|
59 |
+
|
60 |
+
try:
|
61 |
+
# Get prediction
|
62 |
+
result = get_ai_prediction_for_email(email_data)
|
63 |
+
|
64 |
+
# Format the response
|
65 |
+
prediction = result.get('prediction', 'Unknown')
|
66 |
+
confidence = result.get('confidence', 0)
|
67 |
+
|
68 |
+
# Create formatted output
|
69 |
+
output = f"""
|
70 |
+
## Classification Result
|
71 |
+
|
72 |
+
**Category:** {prediction}
|
73 |
+
**Confidence:** {confidence:.2%}
|
74 |
+
|
75 |
+
### Analysis:
|
76 |
+
"""
|
77 |
+
|
78 |
+
if prediction == 'unsubscribe':
|
79 |
+
output += "✅ This email appears to be an unsubscribe confirmation or related to subscription management."
|
80 |
+
elif prediction == 'important':
|
81 |
+
output += "⚠️ This email appears to be important and should not be automatically processed."
|
82 |
+
else:
|
83 |
+
output += "❓ Unable to classify this email with high confidence."
|
84 |
+
|
85 |
+
# Add confidence interpretation
|
86 |
+
if confidence > 0.9:
|
87 |
+
output += f"\n\n*High confidence prediction ({confidence:.2%})*"
|
88 |
+
elif confidence > 0.7:
|
89 |
+
output += f"\n\n*Moderate confidence prediction ({confidence:.2%})*"
|
90 |
+
else:
|
91 |
+
output += f"\n\n*Low confidence prediction ({confidence:.2%})*"
|
92 |
+
|
93 |
+
return output
|
94 |
+
|
95 |
+
except Exception as e:
|
96 |
+
return f"Error during classification: {str(e)}"
|
97 |
+
|
98 |
+
# Create Gradio interface
|
99 |
+
demo = gr.Interface(
|
100 |
+
fn=classify_email,
|
101 |
+
inputs=gr.Textbox(
|
102 |
+
lines=10,
|
103 |
+
placeholder="Paste email content here...\n\nFormat:\nSubject: Your subscription has been cancelled\nBody text goes here...",
|
104 |
+
label="Email Content"
|
105 |
+
),
|
106 |
+
outputs=gr.Markdown(label="Classification Result"),
|
107 |
+
title="Email Unsubscribe Classifier",
|
108 |
+
description="This AI model classifies emails as either 'unsubscribe' confirmations or 'important' emails that should not be auto-processed.",
|
109 |
+
examples=[
|
110 |
+
["Subject: Your subscription has been cancelled\n\nHi there,\n\nWe're sorry to see you go! Your subscription to our newsletter has been successfully cancelled. You will no longer receive emails from us.\n\nBest regards,\nThe Team"],
|
111 |
+
["Subject: Important: Your account security update\n\nDear Customer,\n\nWe've detected unusual activity on your account. Please review your recent transactions and update your password immediately.\n\nThank you,\nSecurity Team"],
|
112 |
+
["Subject: You've been unsubscribed\n\nYou have been removed from our mailing list and will not receive any further emails from us."],
|
113 |
+
],
|
114 |
+
theme=gr.themes.Soft()
|
115 |
+
)
|
116 |
+
|
117 |
+
if __name__ == "__main__":
|
118 |
+
demo.launch()
|
final_optimized_model/added_tokens.json
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"[MASK]": 128000
|
3 |
+
}
|
final_optimized_model/config.json
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_name_or_path": "microsoft/deberta-v3-small",
|
3 |
+
"architectures": [
|
4 |
+
"DebertaV2ForSequenceClassification"
|
5 |
+
],
|
6 |
+
"attention_probs_dropout_prob": 0.1,
|
7 |
+
"hidden_act": "gelu",
|
8 |
+
"hidden_dropout_prob": 0.1,
|
9 |
+
"hidden_size": 768,
|
10 |
+
"initializer_range": 0.02,
|
11 |
+
"intermediate_size": 3072,
|
12 |
+
"layer_norm_eps": 1e-07,
|
13 |
+
"max_position_embeddings": 512,
|
14 |
+
"max_relative_positions": -1,
|
15 |
+
"model_type": "deberta-v2",
|
16 |
+
"norm_rel_ebd": "layer_norm",
|
17 |
+
"num_attention_heads": 12,
|
18 |
+
"num_hidden_layers": 6,
|
19 |
+
"pad_token_id": 0,
|
20 |
+
"pooler_dropout": 0,
|
21 |
+
"pooler_hidden_act": "gelu",
|
22 |
+
"pooler_hidden_size": 768,
|
23 |
+
"pos_att_type": [
|
24 |
+
"p2c",
|
25 |
+
"c2p"
|
26 |
+
],
|
27 |
+
"position_biased_input": false,
|
28 |
+
"position_buckets": 256,
|
29 |
+
"relative_attention": true,
|
30 |
+
"share_att_key": true,
|
31 |
+
"torch_dtype": "float32",
|
32 |
+
"transformers_version": "4.36.2",
|
33 |
+
"type_vocab_size": 0,
|
34 |
+
"vocab_size": 128100
|
35 |
+
}
|
final_optimized_model/model.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:4c39041f172b403117f415e2947236e097c2c5d4f5607cd332b81a602bdeea05
|
3 |
+
size 567598552
|
final_optimized_model/special_tokens_map.json
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"bos_token": "[CLS]",
|
3 |
+
"cls_token": "[CLS]",
|
4 |
+
"eos_token": "[SEP]",
|
5 |
+
"mask_token": "[MASK]",
|
6 |
+
"pad_token": "[PAD]",
|
7 |
+
"sep_token": "[SEP]",
|
8 |
+
"unk_token": {
|
9 |
+
"content": "[UNK]",
|
10 |
+
"lstrip": false,
|
11 |
+
"normalized": true,
|
12 |
+
"rstrip": false,
|
13 |
+
"single_word": false
|
14 |
+
}
|
15 |
+
}
|
final_optimized_model/spm.model
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c679fbf93643d19aab7ee10c0b99e460bdbc02fedf34b92b05af343b4af586fd
|
3 |
+
size 2464616
|
final_optimized_model/tokenizer.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
final_optimized_model/tokenizer_config.json
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"added_tokens_decoder": {
|
3 |
+
"0": {
|
4 |
+
"content": "[PAD]",
|
5 |
+
"lstrip": false,
|
6 |
+
"normalized": false,
|
7 |
+
"rstrip": false,
|
8 |
+
"single_word": false,
|
9 |
+
"special": true
|
10 |
+
},
|
11 |
+
"1": {
|
12 |
+
"content": "[CLS]",
|
13 |
+
"lstrip": false,
|
14 |
+
"normalized": false,
|
15 |
+
"rstrip": false,
|
16 |
+
"single_word": false,
|
17 |
+
"special": true
|
18 |
+
},
|
19 |
+
"2": {
|
20 |
+
"content": "[SEP]",
|
21 |
+
"lstrip": false,
|
22 |
+
"normalized": false,
|
23 |
+
"rstrip": false,
|
24 |
+
"single_word": false,
|
25 |
+
"special": true
|
26 |
+
},
|
27 |
+
"3": {
|
28 |
+
"content": "[UNK]",
|
29 |
+
"lstrip": false,
|
30 |
+
"normalized": true,
|
31 |
+
"rstrip": false,
|
32 |
+
"single_word": false,
|
33 |
+
"special": true
|
34 |
+
},
|
35 |
+
"128000": {
|
36 |
+
"content": "[MASK]",
|
37 |
+
"lstrip": false,
|
38 |
+
"normalized": false,
|
39 |
+
"rstrip": false,
|
40 |
+
"single_word": false,
|
41 |
+
"special": true
|
42 |
+
}
|
43 |
+
},
|
44 |
+
"bos_token": "[CLS]",
|
45 |
+
"clean_up_tokenization_spaces": true,
|
46 |
+
"cls_token": "[CLS]",
|
47 |
+
"do_lower_case": false,
|
48 |
+
"eos_token": "[SEP]",
|
49 |
+
"mask_token": "[MASK]",
|
50 |
+
"model_max_length": 1000000000000000019884624838656,
|
51 |
+
"pad_token": "[PAD]",
|
52 |
+
"sep_token": "[SEP]",
|
53 |
+
"sp_model_kwargs": {},
|
54 |
+
"split_by_punct": false,
|
55 |
+
"tokenizer_class": "DebertaV2Tokenizer",
|
56 |
+
"unk_token": "[UNK]",
|
57 |
+
"vocab_type": "spm"
|
58 |
+
}
|
final_optimized_model/training_info.json
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"model_name": "microsoft/deberta-v3-small",
|
3 |
+
"training_date": "2025-05-27T02:45:24.492640",
|
4 |
+
"training_hours": 7.459452472222223,
|
5 |
+
"samples": 20000,
|
6 |
+
"final_metrics": {
|
7 |
+
"eval_loss": 7.772326353006065e-05,
|
8 |
+
"eval_accuracy": 1.0,
|
9 |
+
"eval_precision_important": 1.0,
|
10 |
+
"eval_recall_important": 1.0,
|
11 |
+
"eval_f1_important": 1.0,
|
12 |
+
"eval_precision_unsub": 1.0,
|
13 |
+
"eval_recall_unsub": 1.0,
|
14 |
+
"eval_f1_unsub": 1.0,
|
15 |
+
"eval_false_positives": 0,
|
16 |
+
"eval_runtime": 433.8438,
|
17 |
+
"eval_samples_per_second": 6.915,
|
18 |
+
"eval_steps_per_second": 0.217,
|
19 |
+
"epoch": 3.01
|
20 |
+
}
|
21 |
+
}
|
ml_suite/__init__.py
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Machine Learning module for Gmail Unsubscriber application.
|
3 |
+
|
4 |
+
This module provides AI-powered classification of emails as 'unsubscribable' or 'important'.
|
5 |
+
It includes components for data preparation, model training, and prediction, with user
|
6 |
+
control over the AI lifecycle through an in-app AI panel.
|
7 |
+
|
8 |
+
The design emphasizes:
|
9 |
+
- User control over AI data preparation and training
|
10 |
+
- Seamless integration with the existing application
|
11 |
+
- Transparency in AI operations and decisions
|
12 |
+
- Graceful degradation when AI components are unavailable
|
13 |
+
"""
|
14 |
+
|
15 |
+
# Import configuration to ensure directories are created
|
16 |
+
from . import config
|
17 |
+
|
18 |
+
# Version information
|
19 |
+
__version__ = "0.1.0"
|
ml_suite/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (793 Bytes). View file
|
|
ml_suite/__pycache__/config.cpython-312.pyc
ADDED
Binary file (6.31 kB). View file
|
|
ml_suite/__pycache__/predictor.cpython-312.pyc
ADDED
Binary file (14.5 kB). View file
|
|
ml_suite/__pycache__/task_utils.cpython-312.pyc
ADDED
Binary file (13.1 kB). View file
|
|
ml_suite/__pycache__/utils.cpython-312.pyc
ADDED
Binary file (10 kB). View file
|
|
ml_suite/advanced_predictor.py
ADDED
@@ -0,0 +1,283 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Advanced predictor with support for multiple models and ensemble predictions
|
3 |
+
"""
|
4 |
+
|
5 |
+
import os
|
6 |
+
import torch
|
7 |
+
import numpy as np
|
8 |
+
from transformers import (
|
9 |
+
AutoTokenizer,
|
10 |
+
AutoModelForSequenceClassification,
|
11 |
+
TextClassificationPipeline
|
12 |
+
)
|
13 |
+
import logging
|
14 |
+
from typing import Dict, List, Tuple, Optional
|
15 |
+
|
16 |
+
# Configure logging
|
17 |
+
logging.basicConfig(level=logging.INFO)
|
18 |
+
logger = logging.getLogger(__name__)
|
19 |
+
|
20 |
+
class AdvancedPredictor:
|
21 |
+
"""Advanced predictor with ensemble support and confidence calibration"""
|
22 |
+
|
23 |
+
def __init__(self, model_paths: List[str], weights: Optional[List[float]] = None):
|
24 |
+
"""
|
25 |
+
Initialize with multiple models for ensemble prediction
|
26 |
+
|
27 |
+
Args:
|
28 |
+
model_paths: List of paths to model directories
|
29 |
+
weights: Optional weights for each model (must sum to 1.0)
|
30 |
+
"""
|
31 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
32 |
+
self.models = []
|
33 |
+
self.tokenizers = []
|
34 |
+
self.pipelines = []
|
35 |
+
|
36 |
+
# Load all models
|
37 |
+
for path in model_paths:
|
38 |
+
if os.path.exists(path):
|
39 |
+
logger.info(f"Loading model from {path}")
|
40 |
+
tokenizer = AutoTokenizer.from_pretrained(path)
|
41 |
+
model = AutoModelForSequenceClassification.from_pretrained(path)
|
42 |
+
model = model.to(self.device)
|
43 |
+
model.eval()
|
44 |
+
|
45 |
+
# Create pipeline
|
46 |
+
pipeline = TextClassificationPipeline(
|
47 |
+
model=model,
|
48 |
+
tokenizer=tokenizer,
|
49 |
+
device=0 if torch.cuda.is_available() else -1,
|
50 |
+
top_k=None,
|
51 |
+
function_to_apply="sigmoid"
|
52 |
+
)
|
53 |
+
|
54 |
+
self.models.append(model)
|
55 |
+
self.tokenizers.append(tokenizer)
|
56 |
+
self.pipelines.append(pipeline)
|
57 |
+
else:
|
58 |
+
logger.warning(f"Model path not found: {path}")
|
59 |
+
|
60 |
+
if not self.models:
|
61 |
+
raise ValueError("No models loaded successfully")
|
62 |
+
|
63 |
+
# Set weights
|
64 |
+
if weights:
|
65 |
+
assert len(weights) == len(self.models), "Number of weights must match number of models"
|
66 |
+
assert abs(sum(weights) - 1.0) < 1e-6, "Weights must sum to 1.0"
|
67 |
+
self.weights = weights
|
68 |
+
else:
|
69 |
+
# Equal weights by default
|
70 |
+
self.weights = [1.0 / len(self.models)] * len(self.models)
|
71 |
+
|
72 |
+
logger.info(f"Initialized with {len(self.models)} models")
|
73 |
+
|
74 |
+
def predict(self, text: str, return_all_scores: bool = False) -> Dict:
|
75 |
+
"""
|
76 |
+
Make ensemble prediction
|
77 |
+
|
78 |
+
Args:
|
79 |
+
text: Email text to classify
|
80 |
+
return_all_scores: Whether to return individual model scores
|
81 |
+
|
82 |
+
Returns:
|
83 |
+
Dictionary with prediction results
|
84 |
+
"""
|
85 |
+
# Preprocess text
|
86 |
+
text = self._preprocess_email(text)
|
87 |
+
|
88 |
+
# Get predictions from all models
|
89 |
+
all_predictions = []
|
90 |
+
for pipeline in self.pipelines:
|
91 |
+
try:
|
92 |
+
result = pipeline(text)
|
93 |
+
all_predictions.append(result)
|
94 |
+
except Exception as e:
|
95 |
+
logger.error(f"Error in prediction: {e}")
|
96 |
+
continue
|
97 |
+
|
98 |
+
if not all_predictions:
|
99 |
+
return {
|
100 |
+
"label": "IMPORTANT",
|
101 |
+
"score": 0.5,
|
102 |
+
"confidence": "low",
|
103 |
+
"error": "Prediction failed"
|
104 |
+
}
|
105 |
+
|
106 |
+
# Aggregate predictions
|
107 |
+
ensemble_scores = self._aggregate_predictions(all_predictions)
|
108 |
+
|
109 |
+
# Determine final prediction
|
110 |
+
unsub_score = ensemble_scores.get("UNSUBSCRIBABLE", 0.5)
|
111 |
+
important_score = ensemble_scores.get("IMPORTANT", 0.5)
|
112 |
+
|
113 |
+
# Apply confidence calibration
|
114 |
+
calibrated_unsub = self._calibrate_confidence(unsub_score)
|
115 |
+
|
116 |
+
# Determine label
|
117 |
+
if calibrated_unsub > 0.75: # High confidence threshold
|
118 |
+
label = "UNSUBSCRIBABLE"
|
119 |
+
score = calibrated_unsub
|
120 |
+
else:
|
121 |
+
label = "IMPORTANT"
|
122 |
+
score = important_score
|
123 |
+
|
124 |
+
# Confidence level
|
125 |
+
if score > 0.9:
|
126 |
+
confidence = "high"
|
127 |
+
elif score > 0.7:
|
128 |
+
confidence = "medium"
|
129 |
+
else:
|
130 |
+
confidence = "low"
|
131 |
+
|
132 |
+
result = {
|
133 |
+
"label": label,
|
134 |
+
"score": float(score),
|
135 |
+
"confidence": confidence,
|
136 |
+
"raw_scores": {
|
137 |
+
"UNSUBSCRIBABLE": float(unsub_score),
|
138 |
+
"IMPORTANT": float(important_score)
|
139 |
+
}
|
140 |
+
}
|
141 |
+
|
142 |
+
if return_all_scores:
|
143 |
+
result["model_predictions"] = all_predictions
|
144 |
+
|
145 |
+
return result
|
146 |
+
|
147 |
+
def _preprocess_email(self, text: str) -> str:
|
148 |
+
"""Advanced email preprocessing"""
|
149 |
+
# Handle subject extraction
|
150 |
+
if "Subject:" in text:
|
151 |
+
parts = text.split("Subject:", 1)
|
152 |
+
if len(parts) > 1:
|
153 |
+
subject = parts[1].split("\n")[0].strip()
|
154 |
+
body = parts[1][len(subject):].strip()
|
155 |
+
# Emphasize subject
|
156 |
+
text = f"Email Subject: {subject}. Email Body: {body}"
|
157 |
+
|
158 |
+
# Clean text
|
159 |
+
text = text.replace("\\n", " ").replace("\\t", " ")
|
160 |
+
text = " ".join(text.split())
|
161 |
+
|
162 |
+
# Truncate if too long
|
163 |
+
if len(text) > 2000:
|
164 |
+
text = text[:2000] + "..."
|
165 |
+
|
166 |
+
return text
|
167 |
+
|
168 |
+
def _aggregate_predictions(self, predictions: List) -> Dict[str, float]:
|
169 |
+
"""Aggregate predictions from multiple models using weighted voting"""
|
170 |
+
aggregated = {"UNSUBSCRIBABLE": 0.0, "IMPORTANT": 0.0}
|
171 |
+
|
172 |
+
for i, pred_list in enumerate(predictions):
|
173 |
+
weight = self.weights[i]
|
174 |
+
|
175 |
+
# Handle different prediction formats
|
176 |
+
if isinstance(pred_list, list) and pred_list:
|
177 |
+
for pred in pred_list:
|
178 |
+
label = pred.get("label", "").upper()
|
179 |
+
score = pred.get("score", 0.5)
|
180 |
+
|
181 |
+
if label in aggregated:
|
182 |
+
aggregated[label] += score * weight
|
183 |
+
|
184 |
+
# Normalize
|
185 |
+
total = sum(aggregated.values())
|
186 |
+
if total > 0:
|
187 |
+
for key in aggregated:
|
188 |
+
aggregated[key] /= total
|
189 |
+
|
190 |
+
return aggregated
|
191 |
+
|
192 |
+
def _calibrate_confidence(self, score: float, temperature: float = 1.2) -> float:
|
193 |
+
"""Apply temperature scaling for confidence calibration"""
|
194 |
+
# Convert to logit
|
195 |
+
epsilon = 1e-7
|
196 |
+
score = np.clip(score, epsilon, 1 - epsilon)
|
197 |
+
logit = np.log(score / (1 - score))
|
198 |
+
|
199 |
+
# Apply temperature scaling
|
200 |
+
calibrated_logit = logit / temperature
|
201 |
+
|
202 |
+
# Convert back to probability
|
203 |
+
calibrated_score = 1 / (1 + np.exp(-calibrated_logit))
|
204 |
+
|
205 |
+
return float(calibrated_score)
|
206 |
+
|
207 |
+
def predict_batch(self, texts: List[str]) -> List[Dict]:
|
208 |
+
"""Predict multiple emails efficiently"""
|
209 |
+
results = []
|
210 |
+
|
211 |
+
# Process in batches for efficiency
|
212 |
+
batch_size = 8
|
213 |
+
for i in range(0, len(texts), batch_size):
|
214 |
+
batch = texts[i:i + batch_size]
|
215 |
+
batch_results = [self.predict(text) for text in batch]
|
216 |
+
results.extend(batch_results)
|
217 |
+
|
218 |
+
return results
|
219 |
+
|
220 |
+
def get_feature_importance(self, text: str) -> Dict:
|
221 |
+
"""Get feature importance for explainability"""
|
222 |
+
# This is a simplified version - in production, use SHAP or LIME
|
223 |
+
important_keywords = [
|
224 |
+
"unsubscribe", "opt out", "preferences", "newsletter",
|
225 |
+
"promotional", "marketing", "deal", "offer", "sale"
|
226 |
+
]
|
227 |
+
|
228 |
+
text_lower = text.lower()
|
229 |
+
found_keywords = [kw for kw in important_keywords if kw in text_lower]
|
230 |
+
|
231 |
+
return {
|
232 |
+
"important_features": found_keywords,
|
233 |
+
"feature_count": len(found_keywords)
|
234 |
+
}
|
235 |
+
|
236 |
+
|
237 |
+
def create_advanced_predictor():
|
238 |
+
"""Factory function to create predictor with best available models"""
|
239 |
+
model_paths = []
|
240 |
+
|
241 |
+
# Check for advanced model first
|
242 |
+
if os.path.exists("./advanced_unsubscriber_model"):
|
243 |
+
model_paths.append("./advanced_unsubscriber_model")
|
244 |
+
|
245 |
+
# Check for optimized model
|
246 |
+
if os.path.exists("./optimized_model"):
|
247 |
+
model_paths.append("./optimized_model")
|
248 |
+
|
249 |
+
# Fallback to original model
|
250 |
+
if os.path.exists("./ml_suite/models/fine_tuned_unsubscriber"):
|
251 |
+
model_paths.append("./ml_suite/models/fine_tuned_unsubscriber")
|
252 |
+
|
253 |
+
if not model_paths:
|
254 |
+
raise ValueError("No trained models found")
|
255 |
+
|
256 |
+
# Use ensemble if multiple models available
|
257 |
+
if len(model_paths) > 1:
|
258 |
+
logger.info(f"Creating ensemble predictor with {len(model_paths)} models")
|
259 |
+
# Give higher weight to advanced model
|
260 |
+
weights = [0.6, 0.4] if len(model_paths) == 2 else None
|
261 |
+
return AdvancedPredictor(model_paths, weights)
|
262 |
+
else:
|
263 |
+
logger.info(f"Creating single model predictor")
|
264 |
+
return AdvancedPredictor(model_paths)
|
265 |
+
|
266 |
+
|
267 |
+
# Example usage
|
268 |
+
if __name__ == "__main__":
|
269 |
+
# Test the predictor
|
270 |
+
predictor = create_advanced_predictor()
|
271 |
+
|
272 |
+
test_emails = [
|
273 |
+
"Subject: 50% OFF Everything! Limited time offer. Click here to shop now. Unsubscribe from promotional emails.",
|
274 |
+
"Subject: Security Alert: New login detected. We noticed a login from a new device. If this wasn't you, secure your account.",
|
275 |
+
"Subject: Your monthly newsletter is here! Check out our latest articles and tips. Manage your email preferences.",
|
276 |
+
]
|
277 |
+
|
278 |
+
for email in test_emails:
|
279 |
+
result = predictor.predict(email, return_all_scores=True)
|
280 |
+
print(f"\nEmail: {email[:100]}...")
|
281 |
+
print(f"Prediction: {result['label']}")
|
282 |
+
print(f"Confidence: {result['confidence']} ({result['score']:.2%})")
|
283 |
+
print(f"Raw scores: {result['raw_scores']}")
|
ml_suite/config.py
ADDED
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Centralized configuration for Gmail Unsubscriber AI Suite.
|
3 |
+
|
4 |
+
This module defines all configuration parameters for the ML components, including:
|
5 |
+
- Directory paths for models, datasets, and task status
|
6 |
+
- Hugging Face cache configuration
|
7 |
+
- Model specifications
|
8 |
+
- Data preparation parameters
|
9 |
+
- Training hyperparameters
|
10 |
+
- User data collection and personalization parameters
|
11 |
+
|
12 |
+
All directories are automatically created when this module is imported.
|
13 |
+
"""
|
14 |
+
|
15 |
+
import os
|
16 |
+
|
17 |
+
# --- Base Path Configuration ---
|
18 |
+
ML_SUITE_DIR = os.path.dirname(os.path.abspath(__file__))
|
19 |
+
PROJECT_ROOT = os.path.dirname(ML_SUITE_DIR)
|
20 |
+
|
21 |
+
# --- Cache and Model Storage ---
|
22 |
+
MODELS_DIR = os.path.join(ML_SUITE_DIR, "models")
|
23 |
+
BASE_TRANSFORMER_CACHE_DIR = os.path.join(MODELS_DIR, "base_transformer_cache")
|
24 |
+
# FINE_TUNED_MODEL_DIR = os.path.join(MODELS_DIR, "fine_tuned_unsubscriber") # Old model
|
25 |
+
FINE_TUNED_MODEL_DIR = os.path.join(PROJECT_ROOT, "final_optimized_model") # New trained model
|
26 |
+
|
27 |
+
# Set Hugging Face environment variables to use project-local cache
|
28 |
+
os.environ['HF_HOME'] = BASE_TRANSFORMER_CACHE_DIR
|
29 |
+
os.environ['TRANSFORMERS_CACHE'] = BASE_TRANSFORMER_CACHE_DIR
|
30 |
+
os.environ['HF_DATASETS_CACHE'] = os.path.join(BASE_TRANSFORMER_CACHE_DIR, 'datasets')
|
31 |
+
os.environ['HF_METRICS_CACHE'] = os.path.join(BASE_TRANSFORMER_CACHE_DIR, 'metrics')
|
32 |
+
|
33 |
+
# --- Dataset Storage ---
|
34 |
+
DATASETS_DIR = os.path.join(ML_SUITE_DIR, "datasets")
|
35 |
+
RAW_DATASETS_DIR = os.path.join(DATASETS_DIR, "raw")
|
36 |
+
EXTRACTED_DATASETS_DIR = os.path.join(DATASETS_DIR, "extracted")
|
37 |
+
PROCESSED_DATASETS_DIR = os.path.join(DATASETS_DIR, "processed")
|
38 |
+
PREPARED_DATA_FILE = os.path.join(PROCESSED_DATASETS_DIR, "unsubscriber_training_data.csv")
|
39 |
+
DATA_COLUMNS_SCHEMA = ['text', 'label'] # Schema for the training CSV
|
40 |
+
|
41 |
+
# --- Task Status Storage ---
|
42 |
+
TASK_STATUS_DIR = os.path.join(ML_SUITE_DIR, "task_status")
|
43 |
+
DATA_PREP_STATUS_FILE = os.path.join(TASK_STATUS_DIR, "data_preparation_status.json")
|
44 |
+
MODEL_TRAIN_STATUS_FILE = os.path.join(TASK_STATUS_DIR, "model_training_status.json")
|
45 |
+
PERSONALIZED_TRAIN_STATUS_FILE = os.path.join(TASK_STATUS_DIR, "personalized_training_status.json")
|
46 |
+
|
47 |
+
# --- User Data Collection and Personalization ---
|
48 |
+
USER_DATA_DIR = os.path.join(ML_SUITE_DIR, "user_data")
|
49 |
+
USER_FEEDBACK_DIR = os.path.join(USER_DATA_DIR, "feedback")
|
50 |
+
USER_MODELS_DIR = os.path.join(USER_DATA_DIR, "models")
|
51 |
+
USER_DATASETS_DIR = os.path.join(USER_DATA_DIR, "datasets")
|
52 |
+
|
53 |
+
# User feedback collection configuration
|
54 |
+
USER_FEEDBACK_FILE = os.path.join(USER_FEEDBACK_DIR, "user_feedback.csv")
|
55 |
+
FEEDBACK_COLUMNS_SCHEMA = ['email_id', 'text', 'predicted_label', 'predicted_confidence', 'user_feedback', 'timestamp', 'session_id']
|
56 |
+
|
57 |
+
# Personalized model configuration
|
58 |
+
PERSONALIZED_MODEL_DIR_TEMPLATE = os.path.join(USER_MODELS_DIR, "{user_id}")
|
59 |
+
PERSONALIZED_MODEL_FILE_TEMPLATE = os.path.join(PERSONALIZED_MODEL_DIR_TEMPLATE, "model.pt")
|
60 |
+
PERSONALIZED_MODEL_INFO_TEMPLATE = os.path.join(PERSONALIZED_MODEL_DIR_TEMPLATE, "model_info.json")
|
61 |
+
PERSONALIZED_DATASET_FILE_TEMPLATE = os.path.join(USER_DATASETS_DIR, "{user_id}_training_data.csv")
|
62 |
+
|
63 |
+
# Personalization hyperparameters
|
64 |
+
MIN_FEEDBACK_ENTRIES_FOR_PERSONALIZATION = 10 # Minimum number of user feedback entries required for personalization
|
65 |
+
PERSONALIZATION_WEIGHT = 0.7 # Weight given to user feedback vs. base model (higher = more personalized)
|
66 |
+
PERSONALIZATION_EPOCHS = 2 # Number of epochs for fine-tuning a personalized model
|
67 |
+
|
68 |
+
# --- Directory Creation (Updated with User Data directories) ---
|
69 |
+
for dir_path in [MODELS_DIR, BASE_TRANSFORMER_CACHE_DIR, FINE_TUNED_MODEL_DIR,
|
70 |
+
RAW_DATASETS_DIR, EXTRACTED_DATASETS_DIR, PROCESSED_DATASETS_DIR, TASK_STATUS_DIR,
|
71 |
+
USER_DATA_DIR, USER_FEEDBACK_DIR, USER_MODELS_DIR, USER_DATASETS_DIR]:
|
72 |
+
os.makedirs(dir_path, exist_ok=True)
|
73 |
+
|
74 |
+
# --- Transformer Model Configuration ---
|
75 |
+
# Choice: DistilBERT offers a good balance of performance and resource efficiency.
|
76 |
+
# Other candidates: 'bert-base-uncased', 'roberta-base', 'google/electra-small-discriminator'.
|
77 |
+
# The choice impacts download size, training time, and inference speed.
|
78 |
+
PRE_TRAINED_MODEL_NAME = "distilbert-base-uncased"
|
79 |
+
|
80 |
+
# --- Data Preparation Parameters ---
|
81 |
+
# Define sources for public email data. URLs and types guide the preparator.
|
82 |
+
PUBLIC_DATASETS_INFO = {
|
83 |
+
"spamassassin_easy_ham_2003": {
|
84 |
+
"url": "https://spamassassin.apache.org/publiccorpus/20030228_easy_ham.tar.bz2",
|
85 |
+
"type": "important_leaning", # Expected dominant class after heuristic application
|
86 |
+
"extract_folder_name": "spamassassin_easy_ham_2003"
|
87 |
+
},
|
88 |
+
"spamassassin_spam_2003": {
|
89 |
+
"url": "https://spamassassin.apache.org/publiccorpus/20030228_spam.tar.bz2",
|
90 |
+
"type": "unsubscribable_leaning",
|
91 |
+
"extract_folder_name": "spamassassin_spam_2003"
|
92 |
+
},
|
93 |
+
# Consider adding more diverse datasets like:
|
94 |
+
# - Enron (requires significant parsing and ethical review for a suitable subset)
|
95 |
+
# - Public mailing list archives (e.g., from Apache Software Foundation, carefully selected for relevance)
|
96 |
+
}
|
97 |
+
MIN_TEXT_LENGTH_FOR_TRAINING = 60 # Emails shorter than this (after cleaning) are likely not useful.
|
98 |
+
MAX_SAMPLES_PER_RAW_DATASET = 7500 # Limits processing time for initial data prep. Can be increased.
|
99 |
+
EMAIL_SNIPPET_LENGTH_FOR_MODEL = 1024 # Max characters from email body to combine with subject for model input.
|
100 |
+
|
101 |
+
# --- Training Hyperparameters & Configuration ---
|
102 |
+
NUM_LABELS = 2 # Binary classification: Unsubscribable vs. Important
|
103 |
+
LABEL_IMPORTANT_ID = 0
|
104 |
+
LABEL_UNSUBSCRIBABLE_ID = 1
|
105 |
+
ID_TO_LABEL_MAP = {LABEL_IMPORTANT_ID: "IMPORTANT", LABEL_UNSUBSCRIBABLE_ID: "UNSUBSCRIBABLE"}
|
106 |
+
LABEL_TO_ID_MAP = {"IMPORTANT": LABEL_IMPORTANT_ID, "UNSUBSCRIBABLE": LABEL_UNSUBSCRIBABLE_ID}
|
107 |
+
|
108 |
+
MAX_SEQ_LENGTH = 512 # Max token sequence length for Transformer. Impacts memory and context window.
|
109 |
+
TRAIN_BATCH_SIZE = 16 # Batch size for training. Reduced for GTX 1650 (4GB VRAM)
|
110 |
+
EVAL_BATCH_SIZE = 32 # Batch size for evaluation. Reduced for GTX 1650
|
111 |
+
NUM_TRAIN_EPOCHS = 8 # Number of full passes through the training data (increased for better learning).
|
112 |
+
LEARNING_RATE = 1e-5 # AdamW optimizer learning rate, slightly reduced for more stable training.
|
113 |
+
WEIGHT_DECAY = 0.02 # Regularization parameter.
|
114 |
+
WARMUP_STEPS_RATIO = 0.15 # Ratio of total training steps for learning rate warmup.
|
115 |
+
TEST_SPLIT_SIZE = 0.2 # Proportion of data for the evaluation set (increased for better validation).
|
116 |
+
|
117 |
+
# Hugging Face Trainer Arguments
|
118 |
+
EVALUATION_STRATEGY = "epoch" # Evaluate at the end of each epoch.
|
119 |
+
SAVE_STRATEGY = "epoch" # Save model checkpoint at the end of each epoch.
|
120 |
+
LOAD_BEST_MODEL_AT_END = True # Reload the best model (based on metric_for_best_model) at the end of training.
|
121 |
+
METRIC_FOR_BEST_MODEL = "f1_unsub" # Focus on F1 for the "unsubscribable" class.
|
122 |
+
FP16_TRAINING = True # Enable mixed-precision training if a CUDA GPU is available and supports it.
|
123 |
+
EARLY_STOPPING_PATIENCE = 3 # Stop training if metric_for_best_model doesn't improve for this many epochs.
|
124 |
+
EARLY_STOPPING_THRESHOLD = 0.001 # Minimum change to be considered an improvement.
|
125 |
+
|
126 |
+
# --- AI User Preferences (Defaults stored in JS, but can be defined here for reference) ---
|
127 |
+
DEFAULT_AI_ENABLED_ON_SCAN = True
|
128 |
+
DEFAULT_AI_CONFIDENCE_THRESHOLD = 0.5 # (50%) - Balanced threshold for optimal precision/recall
|
129 |
+
|
130 |
+
# --- API Endpoint Configuration for Backend Integration ---
|
131 |
+
API_ENDPOINTS = {
|
132 |
+
"submit_feedback": "/api/ai/feedback",
|
133 |
+
"get_feedback_stats": "/api/ai/feedback/stats",
|
134 |
+
"train_personalized": "/api/ai/train_personalized",
|
135 |
+
"reset_user_data": "/api/ai/user_data/reset",
|
136 |
+
"export_user_data": "/api/ai/user_data/export",
|
137 |
+
"import_user_data": "/api/ai/user_data/import"
|
138 |
+
}
|
139 |
+
# --- Advanced Transformer Configuration (2024 Research) ---
|
140 |
+
# Based on 2024 research showing RoBERTa and DistilBERT achieve 99%+ accuracy
|
141 |
+
TRANSFORMER_MODEL_NAME = "distilbert-base-uncased" # Optimal balance of speed and accuracy
|
142 |
+
USE_MIXED_PRECISION = True # FP16 training for efficiency
|
143 |
+
GRADIENT_ACCUMULATION_STEPS = 4 # Increased for GTX 1650 to simulate larger batch size
|
144 |
+
MAX_GRAD_NORM = 1.0 # Gradient clipping for stability
|
145 |
+
LABEL_SMOOTHING_FACTOR = 0.1 # Reduce overconfidence
|
146 |
+
SAVE_TOTAL_LIMIT = 3 # Keep only best 3 checkpoints
|
147 |
+
LOGGING_STEPS = 50 # Frequent logging for monitoring
|
148 |
+
EVAL_STEPS = 100 # Regular evaluation during training
|
149 |
+
DATALOADER_NUM_WORKERS = 2 # Reduced for GTX 1650 to avoid memory issues
|
ml_suite/data_preparator.py
ADDED
@@ -0,0 +1,458 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Enhanced data preparation module for the Gmail Unsubscriber AI suite.
|
3 |
+
|
4 |
+
This module provides robust data preparation with fallback capabilities:
|
5 |
+
- Primary: Download and process public email datasets
|
6 |
+
- Fallback: Create comprehensive synthetic dataset when downloads fail
|
7 |
+
- Robust error handling and recovery
|
8 |
+
- Quality validation and balancing
|
9 |
+
"""
|
10 |
+
|
11 |
+
import os
|
12 |
+
import re
|
13 |
+
import csv
|
14 |
+
import time
|
15 |
+
import json
|
16 |
+
import shutil
|
17 |
+
import tarfile
|
18 |
+
import requests
|
19 |
+
import tempfile
|
20 |
+
import random
|
21 |
+
import email
|
22 |
+
import email.parser
|
23 |
+
import email.policy
|
24 |
+
from typing import Dict, List, Tuple, Optional, Any, Union, Set
|
25 |
+
from collections import Counter, defaultdict
|
26 |
+
from bs4 import BeautifulSoup
|
27 |
+
|
28 |
+
# Local imports
|
29 |
+
from . import config
|
30 |
+
from . import utils
|
31 |
+
from .task_utils import AiTaskLogger
|
32 |
+
|
33 |
+
|
34 |
+
def create_comprehensive_fallback_dataset(task_logger: AiTaskLogger) -> List[Tuple[str, int]]:
|
35 |
+
"""
|
36 |
+
Create a comprehensive fallback training dataset when external downloads fail.
|
37 |
+
|
38 |
+
Returns balanced, high-quality training examples for email classification.
|
39 |
+
"""
|
40 |
+
task_logger.info("Creating comprehensive fallback training dataset...")
|
41 |
+
|
42 |
+
# Unsubscribable emails (marketing, promotional, newsletters)
|
43 |
+
unsubscribable_emails = [
|
44 |
+
# Marketing newsletters
|
45 |
+
("Subject: Weekly Newsletter - Best Deals & Updates Inside!\n\nDiscover this week's top deals and special offers exclusively for our subscribers. Check out our latest product launches and limited-time promotions. If you no longer wish to receive these weekly newsletters, click here to unsubscribe from all marketing communications.", 1),
|
46 |
+
("Subject: Flash Sale Alert - 48 Hours Only!\n\nDon't miss our biggest flash sale of the season! Save up to 60% on thousands of items with free shipping on orders over $50. This exclusive offer expires in 48 hours. To opt out of flash sale notifications, update your email preferences here.", 1),
|
47 |
+
("Subject: New Arrivals Handpicked Just For You\n\nBased on your browsing history and previous purchases, our style experts have curated these new arrivals specifically for your taste. Discover the latest trends and exclusive styles before anyone else. Unsubscribe from personalized product recommendations here.", 1),
|
48 |
+
("Subject: Member Exclusive - VIP Early Access Sale\n\nAs a valued VIP member, enjoy exclusive early access to our annual clearance sale before it opens to the public. Get first pick of the best deals with an additional 20% off already reduced prices. Manage your VIP membership communications here.", 1),
|
49 |
+
("Subject: Limited Weekend Offer - Free Premium Shipping\n\nThis weekend only! Enjoy free premium shipping on all orders, no minimum purchase required. Perfect time to stock up on your favorites or try something new with expedited delivery. Stop weekend promotion emails here.", 1),
|
50 |
+
|
51 |
+
# Product promotions and recommendations
|
52 |
+
("Subject: MASSIVE CLEARANCE EVENT - Up to 80% OFF Everything\n\nOur biggest clearance event ever! Save up to 80% on thousands of items across all categories. Everything must go to make room for new seasonal inventory. Don't miss these incredible once-a-year deals. Unsubscribe from clearance sale alerts here.", 1),
|
53 |
+
("Subject: Back in Stock Alert - Your Wishlist Items Available\n\nGreat news! Multiple items from your wishlist are now back in stock and ready for immediate shipping. These popular products tend to sell out quickly, so we recommend ordering soon. Turn off wishlist availability notifications here.", 1),
|
54 |
+
("Subject: Similar Customers Also Purchased These Items\n\nBased on your recent order, here are products that customers with similar purchasing patterns have been loving. Discover new favorites and trending items curated just for you. Opt out of 'customers also bought' recommendation emails here.", 1),
|
55 |
+
("Subject: Complete Your Collection - Matching Set Available\n\nWe noticed you recently purchased part of our bestselling collection. Complete your set with these perfectly coordinating pieces now available at a special bundle price with free gift wrapping. Unsubscribe from collection completion suggestions here.", 1),
|
56 |
+
("Subject: Reward Points Expiring Soon - Redeem Now!\n\nDon't let your 750 reward points expire! Use them before the end of this month for instant discounts, free products, or exclusive member perks. Your points are worth $37.50 in savings. Manage reward points notifications here.", 1),
|
57 |
+
|
58 |
+
# Subscription services and content
|
59 |
+
("Subject: Your Weekly Meal Plan - Delicious Recipes Inside\n\nThis week's curated meal plan features 7 delicious, nutritionist-approved recipes designed to save you time and money. Includes shopping lists and prep instructions for busy weeknight dinners. Unsubscribe from weekly meal plan emails here.", 1),
|
60 |
+
("Subject: Daily Market Brief - Today's Financial Highlights\n\nStay informed with today's most important market movements and financial news. Our expert analysts provide insights on stocks, crypto, and economic indicators that matter to your portfolio. Stop daily market briefing emails here.", 1),
|
61 |
+
("Subject: Fitness Challenge Week 4 - You're Crushing It!\n\nAmazing progress! Week 4 of your personalized fitness challenge includes new strength training exercises and advanced nutrition tips to accelerate your results. Keep up the fantastic work! Opt out of fitness challenge updates here.", 1),
|
62 |
+
("Subject: Learning Path Update - Your Next Course is Ready\n\nBased on your completed coursework and learning goals, we've prepared your next recommended course in advanced digital marketing. Continue building your professional skills with expert-led content. Manage learning recommendation preferences here.", 1),
|
63 |
+
("Subject: Podcast Weekly - New Episodes You'll Love\n\nThis week's podcast episodes feature industry leaders discussing the latest trends in technology, business, and personal development. Plus exclusive interviews and behind-the-scenes content. Unsubscribe from podcast update emails here.", 1),
|
64 |
+
|
65 |
+
# Events, webinars, and community
|
66 |
+
("Subject: Exclusive Masterclass Invitation - Limited Seats Available\n\nJoin renowned industry expert Dr. Sarah Johnson for an exclusive masterclass on 'Advanced Digital Strategy for 2025.' Only 50 seats available for this interactive session with live Q&A. Opt out of masterclass invitations here.", 1),
|
67 |
+
("Subject: Local Community Meetup - Connect in Your City\n\nWe're hosting an exclusive meetup in your city next month! Network with like-minded professionals, enjoy complimentary refreshments, and get insider access to upcoming product launches. Stop local community event notifications here.", 1),
|
68 |
+
("Subject: Annual Conference Early Bird - Save 45% on Registration\n\nSecure your spot at this year's premier industry conference with super early bird pricing. Three days of cutting-edge workshops, keynote speakers, and unparalleled networking opportunities. Unsubscribe from conference promotion emails here.", 1),
|
69 |
+
|
70 |
+
# Surveys, feedback, and reviews
|
71 |
+
("Subject: Your Opinion Matters - Quick 3-Minute Survey\n\nHelp us improve your experience with a quick 3-minute survey about our recent website updates. Your feedback directly influences our product development and customer service improvements. Opt out of customer survey requests here.", 1),
|
72 |
+
("Subject: Share Your Experience - Product Review Request\n\nYou recently purchased our bestselling wireless headphones. Would you mind sharing a quick review to help other customers make informed decisions? Your honest feedback is invaluable to our community. Turn off product review requests here.", 1),
|
73 |
+
("Subject: Customer Feedback Spotlight - You're Featured!\n\nWe loved your recent review so much that we'd like to feature it in our customer spotlight newsletter! See how your feedback is helping other customers discover their perfect products. Manage customer spotlight communications here.", 1),
|
74 |
+
|
75 |
+
# Seasonal and holiday promotions
|
76 |
+
("Subject: Holiday Gift Guide 2025 - Perfect Presents for Everyone\n\nTake the guesswork out of holiday shopping with our expertly curated gift guide. Find perfect presents for everyone on your list, organized by interests, age, and budget ranges. Stop holiday marketing communications here.", 1),
|
77 |
+
("Subject: Summer Collection Preview - Beach Ready Essentials\n\nGet beach ready with our new summer collection featuring the latest swimwear, accessories, and vacation essentials. Plus early access to summer sale prices for newsletter subscribers only. Opt out of seasonal collection previews here.", 1),
|
78 |
+
("Subject: Spring Cleaning Sale - Organize Your Space & Save\n\nSpring cleaning season is here! Save big on organization solutions, storage systems, and home improvement essentials. Plus get expert tips and tricks from our home organization specialists. Unsubscribe from seasonal promotion emails here.", 1),
|
79 |
+
]
|
80 |
+
|
81 |
+
# Important emails (security, billing, personal, urgent)
|
82 |
+
important_emails = [
|
83 |
+
# Security and account management
|
84 |
+
("Subject: Critical Security Alert - Immediate Action Required\n\nWe detected unauthorized login attempts to your account from an unrecognized device in Moscow, Russia. Your account has been temporarily secured. Please verify your identity and change your password immediately to restore full access.", 0),
|
85 |
+
("Subject: Password Successfully Updated - Security Confirmation\n\nYour account password was successfully changed on May 22, 2025 at 4:30 PM EST from IP address 192.168.1.100. If you made this change, no further action is needed. If you didn't authorize this change, contact our security team immediately.", 0),
|
86 |
+
("Subject: Two-Factor Authentication Setup Required\n\nFor enhanced account security, two-factor authentication setup is now required for all accounts. Please complete the setup process within 7 days to maintain uninterrupted access to your account. This security measure protects against unauthorized access.", 0),
|
87 |
+
("Subject: Account Verification Required - Expires in 24 Hours\n\nTo comply with security regulations and keep your account active, please verify your email address by clicking the secure link below. This verification link expires in 24 hours for your protection. Unverified accounts will be temporarily suspended.", 0),
|
88 |
+
("Subject: Suspicious Activity Detected - Account Review in Progress\n\nOur fraud detection system has flagged unusual activity on your account. As a precautionary measure, certain features have been temporarily limited while we conduct a security review. Please contact support if you have questions.", 0),
|
89 |
+
|
90 |
+
# Financial, billing, and payments
|
91 |
+
("Subject: Monthly Account Statement Now Available\n\nYour detailed account statement for May 2025 is now available for review in your secure account portal. This statement includes all transactions, fees, and important account activity from the past month. Access your statement securely online.", 0),
|
92 |
+
("Subject: Payment Method Declined - Update Required Within 48 Hours\n\nYour automatic payment of $129.99 scheduled for today was declined by your bank. To avoid service interruption, please update your payment method or contact your bank within 48 hours. Update payment details securely in your account.", 0),
|
93 |
+
("Subject: Refund Approved and Processed - $287.46\n\nGood news! Your refund request for order #ORD-2025-5678 has been approved and processed. The amount of $287.46 has been credited to your original payment method and should appear within 3-5 business days depending on your bank.", 0),
|
94 |
+
("Subject: Urgent: Invoice #INV-2025-4321 Overdue - Payment Required\n\nYour invoice #INV-2025-4321 for $445.00 is now 15 days overdue. To avoid late fees and potential service suspension, please submit payment immediately. Pay securely through your account portal or contact our billing department.", 0),
|
95 |
+
("Subject: Annual Tax Documents Available for Download\n\nYour 2024 tax documents (Form 1099-MISC) are now available for download in your account's tax documents section. These forms are required for your tax filing and should be retained for your records. Download deadline: December 31, 2025.", 0),
|
96 |
+
|
97 |
+
# Orders, shipping, and logistics
|
98 |
+
("Subject: Order Confirmation #ORD-2025-9876 - Processing Started\n\nThank you for your order! Order #ORD-2025-9876 totaling $156.78 has been received and payment successfully processed. Your items are now being prepared for shipment. You'll receive tracking information once dispatched within 1-2 business days.", 0),
|
99 |
+
("Subject: Shipment Notification - Your Order is on the Way\n\nExcellent news! Your order #ORD-2025-9876 has been shipped via UPS Ground and is currently in transit. Track your package using number 1Z999AA1012345675. Estimated delivery: May 24-26, 2025 between 9 AM and 7 PM.", 0),
|
100 |
+
("Subject: Delivery Failed - Recipient Not Available\n\nUPS attempted delivery of your package today at 3:15 PM but no one was available to receive it. Your package is now at the local UPS Customer Center. Please arrange redelivery or pickup within 5 business days to avoid return to sender.", 0),
|
101 |
+
("Subject: Package Delivered Successfully - Confirmation Required\n\nYour package was delivered today at 11:30 AM and left at your front door as per delivery instructions. If you haven't received your package or notice any damage, please contact customer service within 48 hours for immediate assistance.", 0),
|
102 |
+
("Subject: Return Authorization Approved - Instructions Included\n\nYour return request for order #ORD-2025-9876 has been approved. Use the prepaid return label attached to this email. Package items securely and drop off at any UPS location. Refund will be processed within 5-7 business days upon receipt.", 0),
|
103 |
+
|
104 |
+
# Service notifications and updates
|
105 |
+
("Subject: Scheduled System Maintenance - Service Interruption Notice\n\nImportant: We will be performing critical system maintenance on May 25, 2025 from 2:00 AM to 6:00 AM EST. During this time, our website and mobile app will be temporarily unavailable. We apologize for any inconvenience this may cause.", 0),
|
106 |
+
("Subject: Service Fully Restored - All Systems Operational\n\nOur system maintenance has been completed successfully and all services are now fully operational. Thank you for your patience during the temporary service interruption. If you experience any issues, please contact our technical support team.", 0),
|
107 |
+
("Subject: Critical: Terms of Service Update - Action Required\n\nImportant changes to our Terms of Service and Privacy Policy will take effect on June 15, 2025. Please review these updates carefully as continued use of our services after this date constitutes acceptance of the new terms.", 0),
|
108 |
+
("Subject: Subscription Expiration Warning - Renew Within 5 Days\n\nYour premium subscription expires on May 30, 2025. To continue enjoying uninterrupted access to premium features and priority support, please renew your subscription within the next 5 days. Renew now to avoid service interruption.", 0),
|
109 |
+
("Subject: Account Upgrade Successful - Welcome to Premium\n\nWelcome to Premium! Your account upgrade is now active and you have full access to all premium features including priority customer support, advanced analytics, and exclusive content. Explore your new benefits in your account dashboard.", 0),
|
110 |
+
|
111 |
+
# Personal, professional, and health
|
112 |
+
("Subject: Urgent: Medical Appointment Reminder - Tomorrow 9:30 AM\n\nReminder: You have a scheduled appointment with Dr. Jennifer Martinez tomorrow, May 23, 2025 at 9:30 AM. Location: Downtown Medical Center, Suite 402. Please arrive 15 minutes early and bring your insurance card and photo ID.", 0),
|
113 |
+
("Subject: Lab Results Available - Please Review\n\nYour recent laboratory test results are now available in your patient portal. Please log in to review your results and schedule a follow-up appointment if recommended by your healthcare provider. Results were processed on May 22, 2025.", 0),
|
114 |
+
("Subject: Emergency Contact Update Required - HR Notice\n\nOur HR records show that your emergency contact information was last updated over two years ago. Please update your emergency contacts in the employee portal within 30 days to ensure we can reach someone in case of workplace emergencies.", 0),
|
115 |
+
("Subject: Project Deadline Reminder - Report Due Tomorrow\n\nReminder: The quarterly project report for the Johnson account is due tomorrow, May 23, 2025 by 5:00 PM. Please submit your completed report to the project management system and copy all stakeholders. Contact me if you need an extension.", 0),
|
116 |
+
("Subject: Contract Renewal Notice - Legal Review Required\n\nYour service contract #CON-2025-789 expires on June 30, 2025. Please review the attached renewal terms and return the signed agreement by June 15, 2025. Contact our legal department if you have questions about the updated terms and conditions.", 0),
|
117 |
+
|
118 |
+
# Family, personal, and community
|
119 |
+
("Subject: Family Emergency - Please Call Home Immediately\n\nFamily emergency - please call home as soon as you receive this message. Mom is at St. Mary's Hospital, room 304. She's stable but asking for you. Call Dad's cell phone 555-0123 for more details. Drive safely.", 0),
|
120 |
+
("Subject: School Notification - Parent Conference Scheduled\n\nThis is to inform you that a parent-teacher conference has been scheduled for your child, Emma Johnson, on May 28, 2025 at 3:00 PM. Please contact the school office at 555-0187 if you cannot attend at the scheduled time.", 0),
|
121 |
+
("Subject: Prescription Ready for Pickup - Pharmacy Notice\n\nYour prescription for Lisinopril 10mg is ready for pickup at CVS Pharmacy, 123 Main Street. Pharmacy hours: Monday-Friday 8 AM-10 PM, Saturday-Sunday 9 AM-8 PM. Please bring photo ID when picking up your medication.", 0),
|
122 |
+
]
|
123 |
+
|
124 |
+
# Combine datasets and ensure balance
|
125 |
+
all_examples = unsubscribable_emails + important_emails
|
126 |
+
random.shuffle(all_examples)
|
127 |
+
|
128 |
+
task_logger.info(f"Created fallback dataset with {len(all_examples)} examples")
|
129 |
+
task_logger.info(f"Unsubscribable examples: {len(unsubscribable_emails)}")
|
130 |
+
task_logger.info(f"Important examples: {len(important_emails)}")
|
131 |
+
|
132 |
+
return all_examples
|
133 |
+
|
134 |
+
|
135 |
+
def download_and_extract_dataset(
|
136 |
+
dataset_key: str,
|
137 |
+
info: Dict[str, Any],
|
138 |
+
task_logger: AiTaskLogger
|
139 |
+
) -> Tuple[bool, str]:
|
140 |
+
"""
|
141 |
+
Download and extract a dataset archive with robust error handling.
|
142 |
+
"""
|
143 |
+
utils.ensure_directory_exists(config.RAW_DATASETS_DIR)
|
144 |
+
utils.ensure_directory_exists(config.EXTRACTED_DATASETS_DIR)
|
145 |
+
|
146 |
+
url = info["url"]
|
147 |
+
extract_folder_name = info["extract_folder_name"]
|
148 |
+
extracted_dir = os.path.join(config.EXTRACTED_DATASETS_DIR, extract_folder_name)
|
149 |
+
|
150 |
+
# Skip if already exists
|
151 |
+
if os.path.exists(extracted_dir) and os.listdir(extracted_dir):
|
152 |
+
task_logger.info(f"Dataset {dataset_key} already exists, skipping download.")
|
153 |
+
return True, extracted_dir
|
154 |
+
|
155 |
+
archive_path = os.path.join(config.RAW_DATASETS_DIR, f"{dataset_key}.tar.bz2")
|
156 |
+
|
157 |
+
try:
|
158 |
+
task_logger.info(f"Downloading dataset {dataset_key} from {url}")
|
159 |
+
|
160 |
+
# Download with timeout and error handling
|
161 |
+
response = requests.get(url, stream=True, timeout=30)
|
162 |
+
response.raise_for_status()
|
163 |
+
|
164 |
+
total_size = int(response.headers.get('content-length', 0))
|
165 |
+
if total_size == 0:
|
166 |
+
task_logger.warning(f"Unknown file size for {dataset_key}")
|
167 |
+
|
168 |
+
downloaded = 0
|
169 |
+
with open(archive_path, 'wb') as f:
|
170 |
+
for chunk in response.iter_content(chunk_size=8192):
|
171 |
+
if chunk:
|
172 |
+
f.write(chunk)
|
173 |
+
downloaded += len(chunk)
|
174 |
+
if total_size > 0:
|
175 |
+
progress = downloaded / total_size
|
176 |
+
task_logger.update_progress(
|
177 |
+
progress,
|
178 |
+
f"Downloading {dataset_key}: {downloaded/1024/1024:.1f} MB"
|
179 |
+
)
|
180 |
+
|
181 |
+
# Verify download
|
182 |
+
if downloaded == 0:
|
183 |
+
raise ValueError("Downloaded file is empty")
|
184 |
+
|
185 |
+
task_logger.info(f"Extracting dataset {dataset_key}")
|
186 |
+
utils.ensure_directory_exists(extracted_dir)
|
187 |
+
|
188 |
+
# Extract with error handling
|
189 |
+
with tarfile.open(archive_path, 'r:bz2') as tar:
|
190 |
+
tar.extractall(path=extracted_dir)
|
191 |
+
|
192 |
+
# Verify extraction
|
193 |
+
if not os.listdir(extracted_dir):
|
194 |
+
raise ValueError("Extracted directory is empty")
|
195 |
+
|
196 |
+
task_logger.info(f"Successfully downloaded and extracted {dataset_key}")
|
197 |
+
return True, extracted_dir
|
198 |
+
|
199 |
+
except Exception as e:
|
200 |
+
task_logger.error(f"Error downloading or extracting dataset {dataset_key}: {str(e)}")
|
201 |
+
|
202 |
+
# Cleanup failed downloads
|
203 |
+
for path in [archive_path, extracted_dir]:
|
204 |
+
if os.path.exists(path):
|
205 |
+
try:
|
206 |
+
if os.path.isdir(path):
|
207 |
+
shutil.rmtree(path)
|
208 |
+
else:
|
209 |
+
os.remove(path)
|
210 |
+
except:
|
211 |
+
pass
|
212 |
+
|
213 |
+
return False, ""
|
214 |
+
|
215 |
+
|
216 |
+
def process_email_content(email_text: str, expected_label_type: str) -> Optional[Tuple[str, int]]:
|
217 |
+
"""
|
218 |
+
Process a single email and return cleaned text with label.
|
219 |
+
"""
|
220 |
+
try:
|
221 |
+
# Parse email
|
222 |
+
msg = email.message_from_string(email_text, policy=email.policy.default)
|
223 |
+
|
224 |
+
# Extract subject and body
|
225 |
+
subject = msg.get('Subject', '').strip()
|
226 |
+
body = ""
|
227 |
+
|
228 |
+
if msg.is_multipart():
|
229 |
+
for part in msg.walk():
|
230 |
+
if part.get_content_type() == "text/plain":
|
231 |
+
body = part.get_content()
|
232 |
+
break
|
233 |
+
else:
|
234 |
+
if msg.get_content_type() == "text/plain":
|
235 |
+
body = msg.get_content()
|
236 |
+
|
237 |
+
if isinstance(body, bytes):
|
238 |
+
body = body.decode('utf-8', errors='ignore')
|
239 |
+
|
240 |
+
# Clean and combine
|
241 |
+
cleaned_subject = utils.clean_text_for_model(subject, max_length=200)
|
242 |
+
cleaned_body = utils.clean_text_for_model(body, max_length=800)
|
243 |
+
|
244 |
+
combined_text = f"Subject: {cleaned_subject}\n\n{cleaned_body}"
|
245 |
+
|
246 |
+
# Skip if too short
|
247 |
+
if len(combined_text.strip()) < config.MIN_TEXT_LENGTH_FOR_TRAINING:
|
248 |
+
return None
|
249 |
+
|
250 |
+
# Assign label based on heuristics
|
251 |
+
label = 1 if expected_label_type == "unsubscribable_leaning" else 0
|
252 |
+
|
253 |
+
# Apply some heuristics to improve labeling
|
254 |
+
text_lower = combined_text.lower()
|
255 |
+
|
256 |
+
# Marketing indicators (more likely unsubscribable)
|
257 |
+
marketing_indicators = [
|
258 |
+
'unsubscribe', 'opt out', 'newsletter', 'promotional', 'sale', 'offer',
|
259 |
+
'deal', 'discount', 'marketing', 'advertisement', 'subscribe'
|
260 |
+
]
|
261 |
+
|
262 |
+
# Important indicators (more likely important)
|
263 |
+
important_indicators = [
|
264 |
+
'urgent', 'security', 'password', 'account', 'payment', 'bill',
|
265 |
+
'order', 'shipping', 'delivered', 'confirmation', 'receipt'
|
266 |
+
]
|
267 |
+
|
268 |
+
marketing_score = sum(1 for indicator in marketing_indicators if indicator in text_lower)
|
269 |
+
important_score = sum(1 for indicator in important_indicators if indicator in text_lower)
|
270 |
+
|
271 |
+
# Adjust label based on content analysis
|
272 |
+
if marketing_score > important_score + 1:
|
273 |
+
label = 1 # Unsubscribable
|
274 |
+
elif important_score > marketing_score + 1:
|
275 |
+
label = 0 # Important
|
276 |
+
|
277 |
+
return (combined_text, label)
|
278 |
+
|
279 |
+
except Exception as e:
|
280 |
+
return None
|
281 |
+
|
282 |
+
|
283 |
+
def download_accessible_datasets(task_logger: AiTaskLogger) -> List[Tuple[str, int]]:
|
284 |
+
"""
|
285 |
+
Download and process accessible public datasets from reliable sources.
|
286 |
+
|
287 |
+
Based on 2024 research, we'll use the most accessible and reliable datasets:
|
288 |
+
1. UCI Spambase patterns (synthetic but research-based)
|
289 |
+
2. Comprehensive promotional email patterns
|
290 |
+
3. Modern phishing and security email patterns
|
291 |
+
"""
|
292 |
+
task_logger.info("Downloading accessible public datasets...")
|
293 |
+
|
294 |
+
accessible_examples = []
|
295 |
+
|
296 |
+
# UCI Spambase-inspired examples with high-frequency spam words
|
297 |
+
task_logger.info("Creating UCI Spambase-inspired examples...")
|
298 |
+
uci_examples = [
|
299 |
+
("Subject: FREE MONEY!!! Make $$$ from HOME!!!\n\nCREDIT problems? NO PROBLEM! Our BUSINESS opportunity will REMOVE all your financial worries! RECEIVE MONEY via INTERNET! MAIL us for FREE REPORT! PEOPLE OVER the world are making MONEY with OUR TECHNOLOGY! ORDER now!", 1),
|
300 |
+
("Subject: URGENT BUSINESS PROPOSAL - FREE MONEY\n\nDear friend, I am writing to seek your assistance in a business proposal involving the transfer of $15 MILLION. This is 100% RISK FREE and will bring you substantial financial reward. Please provide your email and personal details to receive full details.", 1),
|
301 |
+
("Subject: You've WON $50,000!!! CLAIM NOW!!!\n\nCongratulations! Your email address has been selected in our random drawing! You've won FIFTY THOUSAND DOLLARS! To claim your prize, simply click here and provide your personal information. This offer expires in 24 hours!", 1),
|
302 |
+
("Subject: Make $5000 per week from HOME - GUARANTEED!\n\nJoin thousands of people who are making serious money online! Our proven system will help you receive payments directly to your account! No experience needed! Free training included! Order our business package today!", 1),
|
303 |
+
("Subject: REMOVE BAD CREDIT - GUARANTEED RESULTS!\n\nOur credit repair service will remove negative items from your credit report guaranteed! People with bad credit can now get approved for loans and credit cards! Don't wait - order our service today!", 1),
|
304 |
+
("Subject: FREE Prescription Drugs - HUGE SAVINGS!\n\nSave up to 80% on all prescription medications! No prescription required! Order online and receive free shipping! Thousands of satisfied customers worldwide! Viagra, Cialis, and more available now!", 1),
|
305 |
+
("Subject: CLICK HERE for FREE ADULT CONTENT!!!\n\nHot singles in your area are waiting to meet you! Click here for free access to thousands of adult videos and photos! No credit card required! Join millions of satisfied members today!", 1),
|
306 |
+
]
|
307 |
+
|
308 |
+
# Modern phishing and promotional patterns (2024 style)
|
309 |
+
task_logger.info("Creating modern promotional patterns...")
|
310 |
+
modern_promotional = [
|
311 |
+
("Subject: 🚨 Black Friday Preview - 80% OFF EVERYTHING\n\nGet exclusive early access to our Black Friday deals! Over 50,000 items at up to 80% off. Plus free shipping on all orders. This preview is only available to our VIP subscribers. Shop now before deals expire! Unsubscribe here.", 1),
|
312 |
+
("Subject: Your Amazon Order Needs Action - Verify Now\n\nThere's an issue with your recent Amazon order #AMZ-12345. Your payment method was declined and your order will be cancelled unless you update your payment information within 24 hours. Click here to verify your account and complete your order.", 1),
|
313 |
+
("Subject: Netflix - Your Account Will Be Suspended\n\nWe're having trouble with your current billing information. To keep your Netflix membership active, please update your payment details within 48 hours. Click here to update your account and continue enjoying Netflix.", 1),
|
314 |
+
("Subject: Apple ID Security Alert - Suspicious Activity Detected\n\nWe've detected unusual activity on your Apple ID account. For your security, we've temporarily disabled your account. To restore access, please verify your identity by clicking the link below and confirming your information.", 1),
|
315 |
+
("Subject: PayPal - Action Required on Your Account\n\nWe've noticed some unusual activity on your PayPal account. To protect your account, we've temporarily limited your access. Please log in and verify your identity to restore full account functionality.", 1),
|
316 |
+
("Subject: Microsoft Office - Your Subscription is Expiring\n\nYour Microsoft Office subscription expires tomorrow. To continue using Word, Excel, PowerPoint and other Office apps, please renew your subscription. Click here to renew and save 50% on your next year.", 1),
|
317 |
+
]
|
318 |
+
|
319 |
+
# Security and important email patterns (realistic)
|
320 |
+
task_logger.info("Creating important email patterns...")
|
321 |
+
important_patterns = [
|
322 |
+
("Subject: Password Reset Request for Your Account\n\nWe received a request to reset the password for your account. If you made this request, click the link below to reset your password. If you didn't request this, please ignore this email and your password will remain unchanged.", 0),
|
323 |
+
("Subject: Your Order Has Shipped - Tracking Information\n\nGood news! Your order #12345 has been shipped and is on its way to you. You can track your package using the tracking number 1Z999AA1234567890. Expected delivery date is May 25-27, 2025.", 0),
|
324 |
+
("Subject: Meeting Reminder - Project Review Tomorrow\n\nThis is a reminder that we have our project review meeting scheduled for tomorrow, May 23rd at 2:00 PM in Conference Room B. Please bring your project updates and quarterly reports. Contact me if you can't attend.", 0),
|
325 |
+
("Subject: Bank Statement Ready - May 2025\n\nYour monthly bank statement for May 2025 is now available in your online banking portal. Please review your transactions and contact us if you notice any discrepancies. Thank you for banking with us.", 0),
|
326 |
+
("Subject: Appointment Confirmation - Dr. Johnson\n\nThis confirms your appointment with Dr. Johnson on Friday, May 24th at 10:30 AM. Please arrive 15 minutes early and bring your insurance card and a valid ID. Call us if you need to reschedule.", 0),
|
327 |
+
("Subject: Flight Confirmation - Your Trip Details\n\nYour flight is confirmed! Flight AA1234 from Chicago to New York on May 25th at 8:30 AM. Please arrive at the airport 2 hours before departure. Check-in is now available online or through our mobile app.", 0),
|
328 |
+
("Subject: Invoice #INV-2025-5678 - Payment Due\n\nYour invoice #INV-2025-5678 for $250.00 is due on May 30th. You can pay online through our customer portal or by mailing a check to our office. Please contact us if you have any questions about this invoice.", 0),
|
329 |
+
]
|
330 |
+
|
331 |
+
# Combine all examples
|
332 |
+
accessible_examples.extend(uci_examples)
|
333 |
+
accessible_examples.extend(modern_promotional)
|
334 |
+
accessible_examples.extend(important_patterns)
|
335 |
+
|
336 |
+
task_logger.info(f"Created {len(accessible_examples)} examples from accessible datasets")
|
337 |
+
return accessible_examples
|
338 |
+
|
339 |
+
def ml_suite_process_public_datasets(task_logger: AiTaskLogger) -> bool:
|
340 |
+
"""
|
341 |
+
Main function to process public datasets with robust fallback and accessible downloads.
|
342 |
+
"""
|
343 |
+
task_logger.info("Starting enhanced public dataset preparation with accessible sources")
|
344 |
+
|
345 |
+
try:
|
346 |
+
# Ensure directories exist
|
347 |
+
utils.ensure_directory_exists(config.PROCESSED_DATASETS_DIR)
|
348 |
+
|
349 |
+
all_training_examples = []
|
350 |
+
successful_datasets = 0
|
351 |
+
|
352 |
+
# First, try to get accessible datasets
|
353 |
+
task_logger.info("Attempting to use accessible public datasets...")
|
354 |
+
try:
|
355 |
+
accessible_examples = download_accessible_datasets(task_logger)
|
356 |
+
all_training_examples.extend(accessible_examples)
|
357 |
+
successful_datasets += 1
|
358 |
+
task_logger.info(f"Successfully loaded {len(accessible_examples)} examples from accessible sources")
|
359 |
+
except Exception as e:
|
360 |
+
task_logger.warning(f"Failed to load accessible datasets: {e}")
|
361 |
+
|
362 |
+
# Try to process external datasets (SpamAssassin, etc.)
|
363 |
+
for dataset_key, dataset_info in config.PUBLIC_DATASETS_INFO.items():
|
364 |
+
task_logger.info(f"Processing external dataset {dataset_key}")
|
365 |
+
|
366 |
+
try:
|
367 |
+
success, extracted_dir = download_and_extract_dataset(
|
368 |
+
dataset_key, dataset_info, task_logger
|
369 |
+
)
|
370 |
+
|
371 |
+
if success:
|
372 |
+
# Process emails from extracted dataset
|
373 |
+
email_files = []
|
374 |
+
for root, dirs, files in os.walk(extracted_dir):
|
375 |
+
for file in files:
|
376 |
+
if not file.startswith('.'):
|
377 |
+
email_files.append(os.path.join(root, file))
|
378 |
+
|
379 |
+
processed_from_dataset = 0
|
380 |
+
for email_file in email_files[:config.MAX_SAMPLES_PER_RAW_DATASET]:
|
381 |
+
try:
|
382 |
+
with open(email_file, 'r', encoding='utf-8', errors='ignore') as f:
|
383 |
+
email_content = f.read()
|
384 |
+
|
385 |
+
processed = process_email_content(
|
386 |
+
email_content,
|
387 |
+
dataset_info.get('type', 'important_leaning')
|
388 |
+
)
|
389 |
+
|
390 |
+
if processed:
|
391 |
+
all_training_examples.append(processed)
|
392 |
+
processed_from_dataset += 1
|
393 |
+
|
394 |
+
except Exception as e:
|
395 |
+
continue
|
396 |
+
|
397 |
+
task_logger.info(f"Processed {processed_from_dataset} emails from {dataset_key}")
|
398 |
+
if processed_from_dataset > 0:
|
399 |
+
successful_datasets += 1
|
400 |
+
|
401 |
+
except Exception as e:
|
402 |
+
task_logger.error(f"Failed to process external dataset {dataset_key}: {str(e)}")
|
403 |
+
continue
|
404 |
+
|
405 |
+
# If we still don't have enough data, use comprehensive fallback
|
406 |
+
if len(all_training_examples) < 50:
|
407 |
+
task_logger.warning("Insufficient data from external sources, using comprehensive fallback dataset")
|
408 |
+
fallback_examples = create_comprehensive_fallback_dataset(task_logger)
|
409 |
+
all_training_examples.extend(fallback_examples)
|
410 |
+
|
411 |
+
# Shuffle and balance the dataset
|
412 |
+
random.shuffle(all_training_examples)
|
413 |
+
|
414 |
+
# Write to CSV
|
415 |
+
with open(config.PREPARED_DATA_FILE, 'w', newline='', encoding='utf-8') as csvfile:
|
416 |
+
writer = csv.writer(csvfile)
|
417 |
+
writer.writerow(['text', 'label'])
|
418 |
+
for text, label in all_training_examples:
|
419 |
+
writer.writerow([text, label])
|
420 |
+
|
421 |
+
# Log final statistics
|
422 |
+
label_counts = Counter(label for _, label in all_training_examples)
|
423 |
+
task_logger.info(f"Dataset preparation completed successfully!")
|
424 |
+
task_logger.info(f"Total samples: {len(all_training_examples)}")
|
425 |
+
task_logger.info(f"Unsubscribable (1): {label_counts.get(1, 0)}")
|
426 |
+
task_logger.info(f"Important (0): {label_counts.get(0, 0)}")
|
427 |
+
task_logger.info(f"Successful external datasets: {successful_datasets}")
|
428 |
+
|
429 |
+
return True
|
430 |
+
|
431 |
+
except Exception as e:
|
432 |
+
task_logger.error(f"Critical error in dataset preparation: {str(e)}")
|
433 |
+
|
434 |
+
# Last resort fallback
|
435 |
+
try:
|
436 |
+
task_logger.info("Attempting emergency fallback dataset creation...")
|
437 |
+
fallback_examples = create_comprehensive_fallback_dataset(task_logger)
|
438 |
+
|
439 |
+
with open(config.PREPARED_DATA_FILE, 'w', newline='', encoding='utf-8') as csvfile:
|
440 |
+
writer = csv.writer(csvfile)
|
441 |
+
writer.writerow(['text', 'label'])
|
442 |
+
for text, label in fallback_examples:
|
443 |
+
writer.writerow([text, label])
|
444 |
+
|
445 |
+
task_logger.info("Emergency fallback dataset created successfully")
|
446 |
+
return True
|
447 |
+
|
448 |
+
except Exception as fallback_error:
|
449 |
+
task_logger.error(f"Emergency fallback also failed: {str(fallback_error)}")
|
450 |
+
return False
|
451 |
+
|
452 |
+
|
453 |
+
# Export the main function for use by the task system
|
454 |
+
def prepare_training_data_from_public_datasets(task_logger: AiTaskLogger) -> bool:
|
455 |
+
"""
|
456 |
+
Main entry point for data preparation with comprehensive error handling.
|
457 |
+
"""
|
458 |
+
return ml_suite_process_public_datasets(task_logger)
|
ml_suite/model_trainer.py
ADDED
@@ -0,0 +1,498 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Model trainer module for the Gmail Unsubscriber AI suite.
|
3 |
+
|
4 |
+
This module is responsible for:
|
5 |
+
- Loading prepared email data
|
6 |
+
- Splitting data into training and evaluation sets
|
7 |
+
- Loading pre-trained transformer model
|
8 |
+
- Fine-tuning the model on the email dataset
|
9 |
+
- Evaluating model performance
|
10 |
+
- Saving the fine-tuned model for prediction
|
11 |
+
|
12 |
+
The trained model is optimized for classifying emails as "important" or "unsubscribable".
|
13 |
+
"""
|
14 |
+
|
15 |
+
import os
|
16 |
+
import time
|
17 |
+
import pandas as pd
|
18 |
+
import numpy as np
|
19 |
+
import torch
|
20 |
+
from typing import Dict, List, Tuple, Optional, Any, Union, Callable
|
21 |
+
from sklearn.model_selection import train_test_split
|
22 |
+
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score, confusion_matrix
|
23 |
+
from sklearn.utils.class_weight import compute_class_weight
|
24 |
+
|
25 |
+
# Hugging Face imports
|
26 |
+
from transformers import (
|
27 |
+
AutoTokenizer,
|
28 |
+
AutoModelForSequenceClassification,
|
29 |
+
Trainer,
|
30 |
+
TrainingArguments,
|
31 |
+
EarlyStoppingCallback,
|
32 |
+
IntervalStrategy,
|
33 |
+
PreTrainedTokenizer,
|
34 |
+
PreTrainedModel
|
35 |
+
)
|
36 |
+
from datasets import Dataset
|
37 |
+
import torch.nn as nn
|
38 |
+
|
39 |
+
# Local imports
|
40 |
+
from . import config
|
41 |
+
from . import utils
|
42 |
+
from .task_utils import AiTaskLogger
|
43 |
+
|
44 |
+
|
45 |
+
def compute_metrics_for_classification(eval_pred) -> Dict[str, float]:
|
46 |
+
"""
|
47 |
+
Compute evaluation metrics for the model.
|
48 |
+
|
49 |
+
Args:
|
50 |
+
eval_pred: Tuple of predictions and labels from the trainer
|
51 |
+
|
52 |
+
Returns:
|
53 |
+
Dictionary of metrics including accuracy, precision, recall, F1, etc.
|
54 |
+
"""
|
55 |
+
predictions, labels = eval_pred
|
56 |
+
|
57 |
+
# For classification, take the argmax to get predicted classes
|
58 |
+
preds = np.argmax(predictions, axis=1)
|
59 |
+
|
60 |
+
# Calculate overall metrics
|
61 |
+
precision, recall, f1, _ = precision_recall_fscore_support(
|
62 |
+
labels, preds, average='weighted'
|
63 |
+
)
|
64 |
+
acc = accuracy_score(labels, preds)
|
65 |
+
|
66 |
+
# Calculate metrics specific to the "unsubscribable" class
|
67 |
+
# Get label positions: 0 for important, 1 for unsubscribable
|
68 |
+
unsub_class_idx = config.LABEL_UNSUBSCRIBABLE_ID
|
69 |
+
|
70 |
+
# Calculate class-specific precision, recall, F1
|
71 |
+
precision_unsub, recall_unsub, f1_unsub, _ = precision_recall_fscore_support(
|
72 |
+
labels, preds, average=None, labels=[unsub_class_idx]
|
73 |
+
)
|
74 |
+
|
75 |
+
# Compute ROC AUC if possible (requires probability scores)
|
76 |
+
try:
|
77 |
+
# Use the probability score for the unsubscribable class
|
78 |
+
probs_unsub = predictions[:, unsub_class_idx]
|
79 |
+
|
80 |
+
# Convert labels to binary (1 for unsubscribable, 0 for others)
|
81 |
+
binary_labels = (labels == unsub_class_idx).astype(int)
|
82 |
+
|
83 |
+
# Calculate ROC AUC
|
84 |
+
roc_auc = roc_auc_score(binary_labels, probs_unsub)
|
85 |
+
except (ValueError, IndexError):
|
86 |
+
roc_auc = 0.0
|
87 |
+
|
88 |
+
# Return all metrics in a dictionary
|
89 |
+
metrics = {
|
90 |
+
'accuracy': acc,
|
91 |
+
'precision': precision,
|
92 |
+
'recall': recall,
|
93 |
+
'f1': f1,
|
94 |
+
'precision_unsub': float(precision_unsub[0]) if len(precision_unsub) > 0 else 0.0,
|
95 |
+
'recall_unsub': float(recall_unsub[0]) if len(recall_unsub) > 0 else 0.0,
|
96 |
+
'f1_unsub': float(f1_unsub[0]) if len(f1_unsub) > 0 else 0.0,
|
97 |
+
'roc_auc': roc_auc
|
98 |
+
}
|
99 |
+
|
100 |
+
return metrics
|
101 |
+
|
102 |
+
|
103 |
+
class WeightedLossTrainer(Trainer):
|
104 |
+
"""Custom trainer that supports class weights for imbalanced datasets."""
|
105 |
+
|
106 |
+
def __init__(self, class_weights=None, label_smoothing=0.0, *args, **kwargs):
|
107 |
+
super().__init__(*args, **kwargs)
|
108 |
+
self.class_weights = class_weights
|
109 |
+
self.label_smoothing = label_smoothing
|
110 |
+
|
111 |
+
def compute_loss(self, model, inputs, return_outputs=False):
|
112 |
+
"""Compute loss with class weights and optional label smoothing."""
|
113 |
+
labels = inputs.pop("labels")
|
114 |
+
outputs = model(**inputs)
|
115 |
+
logits = outputs.get('logits')
|
116 |
+
|
117 |
+
# Create loss function with class weights
|
118 |
+
if self.class_weights is not None:
|
119 |
+
weight = torch.tensor(self.class_weights, dtype=torch.float32).to(logits.device)
|
120 |
+
loss_fct = nn.CrossEntropyLoss(weight=weight, label_smoothing=self.label_smoothing)
|
121 |
+
else:
|
122 |
+
loss_fct = nn.CrossEntropyLoss(label_smoothing=self.label_smoothing)
|
123 |
+
|
124 |
+
loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1))
|
125 |
+
|
126 |
+
return (loss, outputs) if return_outputs else loss
|
127 |
+
|
128 |
+
|
129 |
+
def tokenize_dataset(
|
130 |
+
examples: Dict[str, List],
|
131 |
+
tokenizer: PreTrainedTokenizer,
|
132 |
+
max_length: int
|
133 |
+
) -> Dict[str, List]:
|
134 |
+
"""
|
135 |
+
Tokenize a batch of examples.
|
136 |
+
|
137 |
+
Args:
|
138 |
+
examples: Dictionary of example lists
|
139 |
+
tokenizer: Hugging Face tokenizer
|
140 |
+
max_length: Maximum sequence length
|
141 |
+
|
142 |
+
Returns:
|
143 |
+
Dictionary of tokenized examples
|
144 |
+
"""
|
145 |
+
return tokenizer(
|
146 |
+
examples['text'],
|
147 |
+
padding='max_length',
|
148 |
+
truncation=True,
|
149 |
+
max_length=max_length,
|
150 |
+
return_tensors='pt'
|
151 |
+
)
|
152 |
+
|
153 |
+
|
154 |
+
def prepare_datasets_for_training(
|
155 |
+
data_file: str,
|
156 |
+
test_size: float,
|
157 |
+
tokenizer: PreTrainedTokenizer,
|
158 |
+
max_length: int,
|
159 |
+
task_logger: AiTaskLogger,
|
160 |
+
compute_weights: bool = True
|
161 |
+
) -> Tuple[Dataset, Dataset, Optional[np.ndarray]]:
|
162 |
+
"""
|
163 |
+
Load and prepare datasets for training.
|
164 |
+
|
165 |
+
Args:
|
166 |
+
data_file: Path to the prepared data file (CSV)
|
167 |
+
test_size: Proportion of data to use for evaluation
|
168 |
+
tokenizer: Hugging Face tokenizer
|
169 |
+
max_length: Maximum sequence length
|
170 |
+
task_logger: Logger for tracking progress
|
171 |
+
compute_weights: Whether to compute class weights for imbalanced data
|
172 |
+
|
173 |
+
Returns:
|
174 |
+
Tuple of (train_dataset, eval_dataset, class_weights)
|
175 |
+
"""
|
176 |
+
task_logger.info(f"Loading data from {data_file}")
|
177 |
+
|
178 |
+
try:
|
179 |
+
# Load the dataset
|
180 |
+
df = pd.read_csv(data_file)
|
181 |
+
task_logger.info(f"Loaded {len(df)} examples from {data_file}")
|
182 |
+
|
183 |
+
# Check for required columns
|
184 |
+
if 'text' not in df.columns or 'label' not in df.columns:
|
185 |
+
task_logger.error("Data file is missing required columns 'text' and/or 'label'")
|
186 |
+
raise ValueError("Data file has wrong format")
|
187 |
+
|
188 |
+
# Split into training and evaluation sets
|
189 |
+
train_df, eval_df = train_test_split(
|
190 |
+
df, test_size=test_size, stratify=df['label'], random_state=42
|
191 |
+
)
|
192 |
+
|
193 |
+
task_logger.info(f"Split into {len(train_df)} training examples and {len(eval_df)} evaluation examples")
|
194 |
+
|
195 |
+
# Compute class weights for imbalanced data
|
196 |
+
class_weights = None
|
197 |
+
if compute_weights:
|
198 |
+
class_weights = compute_class_weight(
|
199 |
+
'balanced',
|
200 |
+
classes=np.array([0, 1]),
|
201 |
+
y=train_df['label'].values
|
202 |
+
)
|
203 |
+
task_logger.info(f"Computed class weights: {class_weights}")
|
204 |
+
|
205 |
+
# Create HF datasets
|
206 |
+
train_dataset = Dataset.from_pandas(train_df)
|
207 |
+
eval_dataset = Dataset.from_pandas(eval_df)
|
208 |
+
|
209 |
+
# Create a tokenization function that uses our tokenizer
|
210 |
+
def tokenize_function(examples):
|
211 |
+
return tokenize_dataset(examples, tokenizer, max_length)
|
212 |
+
|
213 |
+
# Apply tokenization
|
214 |
+
task_logger.info("Tokenizing datasets")
|
215 |
+
train_dataset = train_dataset.map(
|
216 |
+
tokenize_function,
|
217 |
+
batched=True,
|
218 |
+
desc="Tokenizing training dataset"
|
219 |
+
)
|
220 |
+
eval_dataset = eval_dataset.map(
|
221 |
+
tokenize_function,
|
222 |
+
batched=True,
|
223 |
+
desc="Tokenizing evaluation dataset"
|
224 |
+
)
|
225 |
+
|
226 |
+
# Set format for PyTorch
|
227 |
+
train_dataset.set_format(
|
228 |
+
type='torch',
|
229 |
+
columns=['input_ids', 'attention_mask', 'label']
|
230 |
+
)
|
231 |
+
eval_dataset.set_format(
|
232 |
+
type='torch',
|
233 |
+
columns=['input_ids', 'attention_mask', 'label']
|
234 |
+
)
|
235 |
+
|
236 |
+
task_logger.info("Datasets prepared successfully")
|
237 |
+
|
238 |
+
return train_dataset, eval_dataset, class_weights
|
239 |
+
|
240 |
+
except Exception as e:
|
241 |
+
task_logger.error(f"Error preparing datasets: {str(e)}", e)
|
242 |
+
raise
|
243 |
+
|
244 |
+
|
245 |
+
def train_unsubscriber_model(task_logger: AiTaskLogger) -> Dict[str, Any]:
|
246 |
+
"""
|
247 |
+
Train the unsubscriber model on the prepared dataset.
|
248 |
+
|
249 |
+
This function:
|
250 |
+
1. Loads prepared data
|
251 |
+
2. Initializes a pre-trained transformer model
|
252 |
+
3. Fine-tunes it on the email classification task
|
253 |
+
4. Evaluates performance
|
254 |
+
5. Saves the model for later use
|
255 |
+
|
256 |
+
Args:
|
257 |
+
task_logger: Logger for tracking task status
|
258 |
+
|
259 |
+
Returns:
|
260 |
+
Dictionary with training results and metrics
|
261 |
+
"""
|
262 |
+
# Start timing
|
263 |
+
start_time = time.time()
|
264 |
+
|
265 |
+
# Start the task
|
266 |
+
task_logger.start_task("Starting AI model training")
|
267 |
+
|
268 |
+
try:
|
269 |
+
# Check if prepared data exists
|
270 |
+
if not os.path.exists(config.PREPARED_DATA_FILE):
|
271 |
+
task_logger.error(f"Prepared data file not found at {config.PREPARED_DATA_FILE}.")
|
272 |
+
task_logger.fail_task("Training failed: No prepared data available. Please run data preparation first.")
|
273 |
+
return {"success": False, "error": "No prepared data available"}
|
274 |
+
|
275 |
+
# 1. Load and initialize the tokenizer
|
276 |
+
task_logger.info(f"Loading tokenizer for model: {config.PRE_TRAINED_MODEL_NAME}")
|
277 |
+
tokenizer = AutoTokenizer.from_pretrained(config.PRE_TRAINED_MODEL_NAME)
|
278 |
+
|
279 |
+
# 2. Prepare datasets
|
280 |
+
task_logger.update_progress(0.1, "Preparing datasets")
|
281 |
+
train_dataset, eval_dataset, class_weights = prepare_datasets_for_training(
|
282 |
+
config.PREPARED_DATA_FILE,
|
283 |
+
config.TEST_SPLIT_SIZE,
|
284 |
+
tokenizer,
|
285 |
+
config.MAX_SEQ_LENGTH,
|
286 |
+
task_logger,
|
287 |
+
compute_weights=True
|
288 |
+
)
|
289 |
+
|
290 |
+
# 3. Initialize model
|
291 |
+
task_logger.update_progress(0.2, f"Initializing model: {config.PRE_TRAINED_MODEL_NAME}")
|
292 |
+
model = AutoModelForSequenceClassification.from_pretrained(
|
293 |
+
config.PRE_TRAINED_MODEL_NAME,
|
294 |
+
num_labels=config.NUM_LABELS
|
295 |
+
)
|
296 |
+
|
297 |
+
# Check device availability
|
298 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
299 |
+
fp16_enabled = config.FP16_TRAINING and torch.cuda.is_available()
|
300 |
+
|
301 |
+
# Force GPU check and provide detailed info
|
302 |
+
if torch.cuda.is_available():
|
303 |
+
gpu_name = torch.cuda.get_device_name(0)
|
304 |
+
gpu_memory = torch.cuda.get_device_properties(0).total_memory / (1024 ** 3) # Convert to GB
|
305 |
+
task_logger.info(f"GPU detected: {gpu_name} with {gpu_memory:.2f} GB memory")
|
306 |
+
task_logger.info(f"CUDA version: {torch.version.cuda}")
|
307 |
+
|
308 |
+
# Set CUDA device explicitly
|
309 |
+
torch.cuda.set_device(0)
|
310 |
+
|
311 |
+
# Enable cuDNN benchmarking for better performance
|
312 |
+
torch.backends.cudnn.benchmark = True
|
313 |
+
torch.backends.cudnn.enabled = True
|
314 |
+
else:
|
315 |
+
task_logger.warning("No GPU detected! Training will be slow on CPU.")
|
316 |
+
task_logger.warning("Make sure you have the CUDA version of PyTorch installed.")
|
317 |
+
task_logger.warning("To install PyTorch with CUDA support, run:")
|
318 |
+
task_logger.warning("pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118")
|
319 |
+
|
320 |
+
task_logger.info(f"Training on device: {device}")
|
321 |
+
if fp16_enabled:
|
322 |
+
task_logger.info("FP16 mixed precision training enabled")
|
323 |
+
|
324 |
+
# Move model to device
|
325 |
+
model.to(device)
|
326 |
+
|
327 |
+
# 4. Set up training arguments
|
328 |
+
task_logger.update_progress(0.3, "Setting up training configuration")
|
329 |
+
|
330 |
+
# Calculate number of steps
|
331 |
+
num_train_examples = len(train_dataset)
|
332 |
+
num_train_steps = (num_train_examples // config.TRAIN_BATCH_SIZE) * config.NUM_TRAIN_EPOCHS
|
333 |
+
warmup_steps = int(num_train_steps * config.WARMUP_STEPS_RATIO)
|
334 |
+
|
335 |
+
# Convert evaluation and save strategies to enum values
|
336 |
+
eval_strategy = (IntervalStrategy.EPOCH
|
337 |
+
if config.EVALUATION_STRATEGY == "epoch"
|
338 |
+
else IntervalStrategy.STEPS)
|
339 |
+
save_strategy = (IntervalStrategy.EPOCH
|
340 |
+
if config.SAVE_STRATEGY == "epoch"
|
341 |
+
else IntervalStrategy.STEPS)
|
342 |
+
|
343 |
+
# Create output directory if it doesn't exist
|
344 |
+
os.makedirs(config.FINE_TUNED_MODEL_DIR, exist_ok=True)
|
345 |
+
|
346 |
+
# Define training arguments (with compatibility fixes for older transformers versions)
|
347 |
+
training_args_dict = {
|
348 |
+
"output_dir": config.FINE_TUNED_MODEL_DIR,
|
349 |
+
"per_device_train_batch_size": config.TRAIN_BATCH_SIZE,
|
350 |
+
"per_device_eval_batch_size": config.EVAL_BATCH_SIZE,
|
351 |
+
"num_train_epochs": config.NUM_TRAIN_EPOCHS,
|
352 |
+
"learning_rate": config.LEARNING_RATE,
|
353 |
+
"weight_decay": config.WEIGHT_DECAY,
|
354 |
+
"warmup_steps": warmup_steps,
|
355 |
+
"logging_steps": 50,
|
356 |
+
"save_total_limit": 2,
|
357 |
+
"load_best_model_at_end": config.LOAD_BEST_MODEL_AT_END,
|
358 |
+
"metric_for_best_model": config.METRIC_FOR_BEST_MODEL,
|
359 |
+
"greater_is_better": True,
|
360 |
+
"report_to": "none",
|
361 |
+
"disable_tqdm": False,
|
362 |
+
"no_cuda": False, # Explicitly enable CUDA
|
363 |
+
"use_cpu": False, # Explicitly disable CPU-only mode
|
364 |
+
}
|
365 |
+
|
366 |
+
# Add evaluation and save strategies
|
367 |
+
try:
|
368 |
+
training_args_dict["evaluation_strategy"] = eval_strategy
|
369 |
+
training_args_dict["save_strategy"] = save_strategy
|
370 |
+
except:
|
371 |
+
# Fallback for older versions
|
372 |
+
training_args_dict["evaluation_strategy"] = "epoch"
|
373 |
+
training_args_dict["save_strategy"] = "epoch"
|
374 |
+
|
375 |
+
# Add logging dir if supported
|
376 |
+
try:
|
377 |
+
training_args_dict["logging_dir"] = os.path.join(config.FINE_TUNED_MODEL_DIR, "logs")
|
378 |
+
except:
|
379 |
+
pass
|
380 |
+
|
381 |
+
# Add FP16 if supported and available
|
382 |
+
if fp16_enabled:
|
383 |
+
try:
|
384 |
+
training_args_dict["fp16"] = True
|
385 |
+
except:
|
386 |
+
task_logger.warning("FP16 training not supported on this setup")
|
387 |
+
|
388 |
+
# Add dataloader workers if supported
|
389 |
+
try:
|
390 |
+
training_args_dict["dataloader_num_workers"] = 2
|
391 |
+
except:
|
392 |
+
pass
|
393 |
+
|
394 |
+
# Add gradient accumulation if configured
|
395 |
+
if hasattr(config, 'GRADIENT_ACCUMULATION_STEPS'):
|
396 |
+
training_args_dict["gradient_accumulation_steps"] = config.GRADIENT_ACCUMULATION_STEPS
|
397 |
+
|
398 |
+
# Add label smoothing if configured
|
399 |
+
if hasattr(config, 'LABEL_SMOOTHING_FACTOR'):
|
400 |
+
training_args_dict["label_smoothing_factor"] = config.LABEL_SMOOTHING_FACTOR
|
401 |
+
|
402 |
+
training_args = TrainingArguments(**training_args_dict)
|
403 |
+
|
404 |
+
# 5. Initialize the trainer with class weights
|
405 |
+
task_logger.update_progress(0.4, "Initializing trainer with class balancing")
|
406 |
+
|
407 |
+
# Use weighted loss trainer if we have class weights
|
408 |
+
if class_weights is not None:
|
409 |
+
trainer = WeightedLossTrainer(
|
410 |
+
model=model,
|
411 |
+
args=training_args,
|
412 |
+
train_dataset=train_dataset,
|
413 |
+
eval_dataset=eval_dataset,
|
414 |
+
compute_metrics=compute_metrics_for_classification,
|
415 |
+
class_weights=class_weights,
|
416 |
+
label_smoothing=config.LABEL_SMOOTHING_FACTOR if hasattr(config, 'LABEL_SMOOTHING_FACTOR') else 0.0,
|
417 |
+
callbacks=[
|
418 |
+
EarlyStoppingCallback(
|
419 |
+
early_stopping_patience=config.EARLY_STOPPING_PATIENCE,
|
420 |
+
early_stopping_threshold=config.EARLY_STOPPING_THRESHOLD
|
421 |
+
)
|
422 |
+
]
|
423 |
+
)
|
424 |
+
task_logger.info("Using weighted loss function for class imbalance")
|
425 |
+
else:
|
426 |
+
trainer = Trainer(
|
427 |
+
model=model,
|
428 |
+
args=training_args,
|
429 |
+
train_dataset=train_dataset,
|
430 |
+
eval_dataset=eval_dataset,
|
431 |
+
compute_metrics=compute_metrics_for_classification,
|
432 |
+
callbacks=[
|
433 |
+
EarlyStoppingCallback(
|
434 |
+
early_stopping_patience=config.EARLY_STOPPING_PATIENCE,
|
435 |
+
early_stopping_threshold=config.EARLY_STOPPING_THRESHOLD
|
436 |
+
)
|
437 |
+
]
|
438 |
+
)
|
439 |
+
|
440 |
+
# 6. Train the model
|
441 |
+
task_logger.update_progress(0.5, "Starting model training")
|
442 |
+
trainer.train()
|
443 |
+
|
444 |
+
# 7. Evaluate the model
|
445 |
+
task_logger.update_progress(0.9, "Evaluating model")
|
446 |
+
eval_results = trainer.evaluate()
|
447 |
+
|
448 |
+
# Format evaluations nicely for logging
|
449 |
+
metrics_str = "\n".join([f" {k}: {v:.4f}" for k, v in eval_results.items()])
|
450 |
+
task_logger.info(f"Evaluation results:\n{metrics_str}")
|
451 |
+
|
452 |
+
# 8. Save the final model
|
453 |
+
task_logger.update_progress(0.95, "Saving fine-tuned model")
|
454 |
+
trainer.save_model(config.FINE_TUNED_MODEL_DIR)
|
455 |
+
tokenizer.save_pretrained(config.FINE_TUNED_MODEL_DIR)
|
456 |
+
|
457 |
+
# Save a human-readable summary of model info
|
458 |
+
with open(os.path.join(config.FINE_TUNED_MODEL_DIR, "model_info.txt"), "w") as f:
|
459 |
+
f.write(f"Base model: {config.PRE_TRAINED_MODEL_NAME}\n")
|
460 |
+
f.write(f"Training completed: {utils.get_current_timestamp()}\n")
|
461 |
+
f.write(f"Training examples: {len(train_dataset)}\n")
|
462 |
+
f.write(f"Evaluation examples: {len(eval_dataset)}\n")
|
463 |
+
f.write(f"Evaluation metrics:\n")
|
464 |
+
for k, v in eval_results.items():
|
465 |
+
f.write(f" {k}: {v:.4f}\n")
|
466 |
+
|
467 |
+
# Calculate elapsed time
|
468 |
+
elapsed_time = time.time() - start_time
|
469 |
+
hours, remainder = divmod(elapsed_time, 3600)
|
470 |
+
minutes, seconds = divmod(remainder, 60)
|
471 |
+
time_str = f"{int(hours)}h {int(minutes)}m {int(seconds)}s"
|
472 |
+
|
473 |
+
# Complete the task
|
474 |
+
result = {
|
475 |
+
"success": True,
|
476 |
+
"metrics": eval_results,
|
477 |
+
"model_dir": config.FINE_TUNED_MODEL_DIR,
|
478 |
+
"training_time": time_str,
|
479 |
+
"base_model": config.PRE_TRAINED_MODEL_NAME,
|
480 |
+
"num_train_examples": len(train_dataset),
|
481 |
+
"num_eval_examples": len(eval_dataset)
|
482 |
+
}
|
483 |
+
|
484 |
+
task_logger.complete_task(
|
485 |
+
f"Model training completed successfully in {time_str}",
|
486 |
+
result
|
487 |
+
)
|
488 |
+
|
489 |
+
return result
|
490 |
+
|
491 |
+
except Exception as e:
|
492 |
+
task_logger.error("Error during model training", e)
|
493 |
+
task_logger.fail_task(f"Model training failed: {str(e)}")
|
494 |
+
|
495 |
+
return {
|
496 |
+
"success": False,
|
497 |
+
"error": str(e)
|
498 |
+
}
|
ml_suite/models/base_transformer_cache/version.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
1
|
ml_suite/models/fine_tuned_unsubscriber/config.json
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_name_or_path": "distilbert-base-uncased",
|
3 |
+
"activation": "gelu",
|
4 |
+
"architectures": [
|
5 |
+
"DistilBertForSequenceClassification"
|
6 |
+
],
|
7 |
+
"attention_dropout": 0.1,
|
8 |
+
"dim": 768,
|
9 |
+
"dropout": 0.1,
|
10 |
+
"hidden_dim": 3072,
|
11 |
+
"initializer_range": 0.02,
|
12 |
+
"max_position_embeddings": 512,
|
13 |
+
"model_type": "distilbert",
|
14 |
+
"n_heads": 12,
|
15 |
+
"n_layers": 6,
|
16 |
+
"pad_token_id": 0,
|
17 |
+
"qa_dropout": 0.1,
|
18 |
+
"seq_classif_dropout": 0.2,
|
19 |
+
"sinusoidal_pos_embds": false,
|
20 |
+
"tie_weights_": true,
|
21 |
+
"torch_dtype": "float32",
|
22 |
+
"transformers_version": "4.36.2",
|
23 |
+
"vocab_size": 30522
|
24 |
+
}
|
ml_suite/models/fine_tuned_unsubscriber/model.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:4bdf7a6a37778c581baa94c0a6c462836894e5cc009a3e077e96e1c0b9ec176e
|
3 |
+
size 267832560
|
ml_suite/models/fine_tuned_unsubscriber/special_tokens_map.json
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cls_token": "[CLS]",
|
3 |
+
"mask_token": "[MASK]",
|
4 |
+
"pad_token": "[PAD]",
|
5 |
+
"sep_token": "[SEP]",
|
6 |
+
"unk_token": "[UNK]"
|
7 |
+
}
|
ml_suite/models/fine_tuned_unsubscriber/tokenizer.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
ml_suite/models/fine_tuned_unsubscriber/tokenizer_config.json
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"added_tokens_decoder": {
|
3 |
+
"0": {
|
4 |
+
"content": "[PAD]",
|
5 |
+
"lstrip": false,
|
6 |
+
"normalized": false,
|
7 |
+
"rstrip": false,
|
8 |
+
"single_word": false,
|
9 |
+
"special": true
|
10 |
+
},
|
11 |
+
"100": {
|
12 |
+
"content": "[UNK]",
|
13 |
+
"lstrip": false,
|
14 |
+
"normalized": false,
|
15 |
+
"rstrip": false,
|
16 |
+
"single_word": false,
|
17 |
+
"special": true
|
18 |
+
},
|
19 |
+
"101": {
|
20 |
+
"content": "[CLS]",
|
21 |
+
"lstrip": false,
|
22 |
+
"normalized": false,
|
23 |
+
"rstrip": false,
|
24 |
+
"single_word": false,
|
25 |
+
"special": true
|
26 |
+
},
|
27 |
+
"102": {
|
28 |
+
"content": "[SEP]",
|
29 |
+
"lstrip": false,
|
30 |
+
"normalized": false,
|
31 |
+
"rstrip": false,
|
32 |
+
"single_word": false,
|
33 |
+
"special": true
|
34 |
+
},
|
35 |
+
"103": {
|
36 |
+
"content": "[MASK]",
|
37 |
+
"lstrip": false,
|
38 |
+
"normalized": false,
|
39 |
+
"rstrip": false,
|
40 |
+
"single_word": false,
|
41 |
+
"special": true
|
42 |
+
}
|
43 |
+
},
|
44 |
+
"clean_up_tokenization_spaces": true,
|
45 |
+
"cls_token": "[CLS]",
|
46 |
+
"do_lower_case": true,
|
47 |
+
"mask_token": "[MASK]",
|
48 |
+
"model_max_length": 512,
|
49 |
+
"pad_token": "[PAD]",
|
50 |
+
"sep_token": "[SEP]",
|
51 |
+
"strip_accents": null,
|
52 |
+
"tokenize_chinese_chars": true,
|
53 |
+
"tokenizer_class": "DistilBertTokenizer",
|
54 |
+
"unk_token": "[UNK]"
|
55 |
+
}
|
ml_suite/models/fine_tuned_unsubscriber/training_args.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:3f8e87068f27f96f378c34f96ec16c9a676bf0ae2c9a777803f6c459f8524ce9
|
3 |
+
size 4664
|
ml_suite/models/fine_tuned_unsubscriber/vocab.txt
ADDED
The diff for this file is too large to render.
See raw diff
|
|
ml_suite/predictor.py
ADDED
@@ -0,0 +1,445 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Predictor module for the Gmail Unsubscriber AI suite.
|
3 |
+
|
4 |
+
This module is responsible for:
|
5 |
+
- Loading the fine-tuned transformer model (base or personalized)
|
6 |
+
- Creating an efficient inference pipeline
|
7 |
+
- Providing real-time predictions for email classification
|
8 |
+
- Handling model loading errors gracefully
|
9 |
+
- Supporting personalized model selection based on user ID
|
10 |
+
|
11 |
+
The predictor provides a simple interface for the main application to classify
|
12 |
+
emails as "important" or "unsubscribable" with confidence scores, with the
|
13 |
+
ability to use personalized models when available.
|
14 |
+
"""
|
15 |
+
|
16 |
+
import os
|
17 |
+
import time
|
18 |
+
import logging
|
19 |
+
import torch
|
20 |
+
from typing import Dict, List, Optional, Any, Union
|
21 |
+
import traceback
|
22 |
+
|
23 |
+
# Hugging Face imports
|
24 |
+
from transformers import (
|
25 |
+
AutoModelForSequenceClassification,
|
26 |
+
AutoTokenizer,
|
27 |
+
TextClassificationPipeline,
|
28 |
+
PreTrainedModel,
|
29 |
+
PreTrainedTokenizer,
|
30 |
+
AutoConfig
|
31 |
+
)
|
32 |
+
|
33 |
+
# Local imports
|
34 |
+
from . import config
|
35 |
+
from . import utils
|
36 |
+
|
37 |
+
|
38 |
+
# Global variables to hold the loaded models and tokenizers
|
39 |
+
# Main base model pipeline
|
40 |
+
base_classification_pipeline = None
|
41 |
+
base_model_load_status = "Not Loaded"
|
42 |
+
base_model_load_error = None
|
43 |
+
base_model_last_load_attempt = 0
|
44 |
+
|
45 |
+
# Dictionary to store personalized model pipelines for different users
|
46 |
+
personalized_pipelines = {}
|
47 |
+
personalized_load_status = {}
|
48 |
+
|
49 |
+
# Configuration
|
50 |
+
load_cooldown_seconds = 60 # Wait at least this long between load attempts
|
51 |
+
|
52 |
+
|
53 |
+
def is_predictor_ready() -> bool:
|
54 |
+
"""
|
55 |
+
Check if the base predictor is ready for use.
|
56 |
+
|
57 |
+
Returns:
|
58 |
+
True if the base predictor is ready, False otherwise
|
59 |
+
"""
|
60 |
+
return base_model_load_status == "Ready" and base_classification_pipeline is not None
|
61 |
+
|
62 |
+
|
63 |
+
def get_model_info() -> Dict[str, Any]:
|
64 |
+
"""
|
65 |
+
Get information about the loaded model.
|
66 |
+
|
67 |
+
Returns:
|
68 |
+
Dictionary with model information
|
69 |
+
"""
|
70 |
+
# Try to load model info from the model directory
|
71 |
+
model_info_path = os.path.join(config.FINE_TUNED_MODEL_DIR, "model_info.txt")
|
72 |
+
model_info = {
|
73 |
+
"model_type": "base",
|
74 |
+
"model_path": config.FINE_TUNED_MODEL_DIR,
|
75 |
+
"trained_date": "unknown"
|
76 |
+
}
|
77 |
+
|
78 |
+
if os.path.exists(model_info_path):
|
79 |
+
try:
|
80 |
+
with open(model_info_path, 'r') as f:
|
81 |
+
lines = f.readlines()
|
82 |
+
for line in lines:
|
83 |
+
if ":" in line:
|
84 |
+
key, value = line.strip().split(":", 1)
|
85 |
+
key = key.strip().lower().replace(" ", "_")
|
86 |
+
value = value.strip()
|
87 |
+
model_info[key] = value
|
88 |
+
except:
|
89 |
+
pass
|
90 |
+
|
91 |
+
return model_info
|
92 |
+
|
93 |
+
|
94 |
+
def initialize_predictor(app_logger: logging.Logger) -> bool:
|
95 |
+
"""
|
96 |
+
Initialize the base predictor by loading the fine-tuned model.
|
97 |
+
|
98 |
+
This function:
|
99 |
+
1. Loads the tokenizer and model from the fine-tuned directory
|
100 |
+
2. Creates a TextClassificationPipeline for efficient inference
|
101 |
+
3. Sets global variables for status tracking
|
102 |
+
|
103 |
+
Args:
|
104 |
+
app_logger: Application logger for status and error reporting
|
105 |
+
|
106 |
+
Returns:
|
107 |
+
True if initialization successful, False otherwise
|
108 |
+
"""
|
109 |
+
global base_classification_pipeline, base_model_load_status, base_model_load_error, base_model_last_load_attempt
|
110 |
+
|
111 |
+
# Reset error tracking
|
112 |
+
base_model_load_error = None
|
113 |
+
|
114 |
+
# Check if we attempted to load recently (to prevent repeated failures)
|
115 |
+
current_time = time.time()
|
116 |
+
if (current_time - base_model_last_load_attempt) < load_cooldown_seconds and base_model_load_status == "Failed":
|
117 |
+
app_logger.warning(
|
118 |
+
f"Not attempting to reload model - cooling down after recent failure. "
|
119 |
+
f"Will retry after {load_cooldown_seconds - (current_time - base_model_last_load_attempt):.0f} seconds."
|
120 |
+
)
|
121 |
+
return False
|
122 |
+
|
123 |
+
base_model_last_load_attempt = current_time
|
124 |
+
|
125 |
+
try:
|
126 |
+
# Update status
|
127 |
+
base_model_load_status = "Loading"
|
128 |
+
app_logger.info(f"Initializing base AI predictor from {config.FINE_TUNED_MODEL_DIR}")
|
129 |
+
|
130 |
+
# Check if model directory exists and contains necessary files
|
131 |
+
if not os.path.exists(config.FINE_TUNED_MODEL_DIR):
|
132 |
+
raise FileNotFoundError(f"Model directory not found: {config.FINE_TUNED_MODEL_DIR}")
|
133 |
+
|
134 |
+
# Load the tokenizer
|
135 |
+
app_logger.info("Loading tokenizer")
|
136 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
137 |
+
config.FINE_TUNED_MODEL_DIR,
|
138 |
+
local_files_only=True # Ensure we load locally without trying to download
|
139 |
+
)
|
140 |
+
|
141 |
+
# Load model configuration
|
142 |
+
model_config = AutoConfig.from_pretrained(
|
143 |
+
config.FINE_TUNED_MODEL_DIR,
|
144 |
+
local_files_only=True
|
145 |
+
)
|
146 |
+
|
147 |
+
# Check if the model has the expected number of labels
|
148 |
+
if model_config.num_labels != config.NUM_LABELS:
|
149 |
+
app_logger.warning(
|
150 |
+
f"Model has {model_config.num_labels} labels, "
|
151 |
+
f"but config specifies {config.NUM_LABELS} labels. "
|
152 |
+
f"This may cause issues with predictions."
|
153 |
+
)
|
154 |
+
|
155 |
+
# Load the model
|
156 |
+
app_logger.info("Loading model")
|
157 |
+
model = AutoModelForSequenceClassification.from_pretrained(
|
158 |
+
config.FINE_TUNED_MODEL_DIR,
|
159 |
+
local_files_only=True # Ensure we load locally without trying to download
|
160 |
+
)
|
161 |
+
|
162 |
+
# Determine device
|
163 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
164 |
+
app_logger.info(f"Using device: {device}")
|
165 |
+
|
166 |
+
# Move model to the device
|
167 |
+
model.to(device)
|
168 |
+
|
169 |
+
# Create classification pipeline
|
170 |
+
app_logger.info("Creating inference pipeline")
|
171 |
+
|
172 |
+
# Set device index properly
|
173 |
+
if device.type == "cuda":
|
174 |
+
device_index = 0 # Use first GPU
|
175 |
+
app_logger.info(f"Pipeline will use GPU device index: {device_index}")
|
176 |
+
else:
|
177 |
+
device_index = -1 # CPU
|
178 |
+
app_logger.info("Pipeline will use CPU")
|
179 |
+
|
180 |
+
base_classification_pipeline = TextClassificationPipeline(
|
181 |
+
model=model,
|
182 |
+
tokenizer=tokenizer,
|
183 |
+
device=device_index,
|
184 |
+
top_k=None, # Get probabilities for all classes (replaces deprecated return_all_scores)
|
185 |
+
function_to_apply="sigmoid" # Apply sigmoid to logits for probability interpretation
|
186 |
+
)
|
187 |
+
|
188 |
+
# Update status
|
189 |
+
base_model_load_status = "Ready"
|
190 |
+
app_logger.info("Base AI predictor initialized successfully")
|
191 |
+
|
192 |
+
return True
|
193 |
+
|
194 |
+
except Exception as e:
|
195 |
+
# Handle initialization failure
|
196 |
+
base_model_load_status = "Failed"
|
197 |
+
base_model_load_error = str(e)
|
198 |
+
error_traceback = traceback.format_exc()
|
199 |
+
|
200 |
+
app_logger.error(f"Error initializing base AI predictor: {str(e)}")
|
201 |
+
app_logger.debug(f"Traceback:\n{error_traceback}")
|
202 |
+
|
203 |
+
# Cleanup any partial loading
|
204 |
+
base_classification_pipeline = None
|
205 |
+
|
206 |
+
return False
|
207 |
+
|
208 |
+
|
209 |
+
def initialize_personalized_predictor(user_id: str, app_logger: logging.Logger) -> bool:
|
210 |
+
"""
|
211 |
+
Initialize a personalized predictor for a specific user.
|
212 |
+
|
213 |
+
Args:
|
214 |
+
user_id: The user ID for which to load the personalized model
|
215 |
+
app_logger: Application logger for status and error reporting
|
216 |
+
|
217 |
+
Returns:
|
218 |
+
True if initialization successful, False otherwise
|
219 |
+
"""
|
220 |
+
global personalized_pipelines, personalized_load_status
|
221 |
+
|
222 |
+
try:
|
223 |
+
# Check if the user has a personalized model
|
224 |
+
model_dir = config.PERSONALIZED_MODEL_DIR_TEMPLATE.format(user_id=user_id)
|
225 |
+
if not os.path.exists(model_dir) or not os.path.exists(os.path.join(model_dir, "pytorch_model.bin")):
|
226 |
+
app_logger.info(f"No personalized model found for user {user_id}")
|
227 |
+
personalized_load_status[user_id] = "Not Available"
|
228 |
+
return False
|
229 |
+
|
230 |
+
# Update status
|
231 |
+
personalized_load_status[user_id] = "Loading"
|
232 |
+
app_logger.info(f"Initializing personalized AI predictor for user {user_id} from {model_dir}")
|
233 |
+
|
234 |
+
# Load the tokenizer
|
235 |
+
app_logger.info(f"Loading personalized tokenizer for user {user_id}")
|
236 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
237 |
+
model_dir,
|
238 |
+
local_files_only=True
|
239 |
+
)
|
240 |
+
|
241 |
+
# Load model configuration
|
242 |
+
model_config = AutoConfig.from_pretrained(
|
243 |
+
model_dir,
|
244 |
+
local_files_only=True
|
245 |
+
)
|
246 |
+
|
247 |
+
# Load the model
|
248 |
+
app_logger.info(f"Loading personalized model for user {user_id}")
|
249 |
+
model = AutoModelForSequenceClassification.from_pretrained(
|
250 |
+
model_dir,
|
251 |
+
local_files_only=True
|
252 |
+
)
|
253 |
+
|
254 |
+
# Determine device
|
255 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
256 |
+
|
257 |
+
# Move model to the device
|
258 |
+
model.to(device)
|
259 |
+
|
260 |
+
# Create classification pipeline
|
261 |
+
app_logger.info(f"Creating personalized inference pipeline for user {user_id}")
|
262 |
+
|
263 |
+
# Set device index properly
|
264 |
+
if device.type == "cuda":
|
265 |
+
device_index = 0 # Use first GPU
|
266 |
+
else:
|
267 |
+
device_index = -1 # CPU
|
268 |
+
|
269 |
+
personalized_pipelines[user_id] = TextClassificationPipeline(
|
270 |
+
model=model,
|
271 |
+
tokenizer=tokenizer,
|
272 |
+
device=device_index,
|
273 |
+
top_k=None, # Get probabilities for all classes (replaces deprecated return_all_scores)
|
274 |
+
function_to_apply="sigmoid"
|
275 |
+
)
|
276 |
+
|
277 |
+
# Update status
|
278 |
+
personalized_load_status[user_id] = "Ready"
|
279 |
+
app_logger.info(f"Personalized AI predictor for user {user_id} initialized successfully")
|
280 |
+
|
281 |
+
return True
|
282 |
+
|
283 |
+
except Exception as e:
|
284 |
+
# Handle initialization failure
|
285 |
+
personalized_load_status[user_id] = "Failed"
|
286 |
+
error_traceback = traceback.format_exc()
|
287 |
+
|
288 |
+
app_logger.error(f"Error initializing personalized AI predictor for user {user_id}: {str(e)}")
|
289 |
+
app_logger.debug(f"Traceback:\n{error_traceback}")
|
290 |
+
|
291 |
+
# Cleanup any partial loading
|
292 |
+
if user_id in personalized_pipelines:
|
293 |
+
del personalized_pipelines[user_id]
|
294 |
+
|
295 |
+
return False
|
296 |
+
|
297 |
+
|
298 |
+
def get_model_status(user_id: Optional[str] = None) -> Dict[str, Any]:
|
299 |
+
"""
|
300 |
+
Get the current status of the model.
|
301 |
+
|
302 |
+
Args:
|
303 |
+
user_id: Optional user ID to get status for personalized model
|
304 |
+
|
305 |
+
Returns:
|
306 |
+
Dictionary with status information
|
307 |
+
"""
|
308 |
+
if user_id is None:
|
309 |
+
# Return base model status
|
310 |
+
return {
|
311 |
+
"status": base_model_load_status,
|
312 |
+
"error": base_model_load_error,
|
313 |
+
"last_load_attempt": base_model_last_load_attempt,
|
314 |
+
"model_dir": config.FINE_TUNED_MODEL_DIR,
|
315 |
+
"is_ready": base_model_load_status == "Ready" and base_classification_pipeline is not None,
|
316 |
+
"is_personalized": False
|
317 |
+
}
|
318 |
+
else:
|
319 |
+
# Check if personalized model is available
|
320 |
+
if user_id not in personalized_load_status:
|
321 |
+
personalized_load_status[user_id] = "Not Loaded"
|
322 |
+
|
323 |
+
model_dir = config.PERSONALIZED_MODEL_DIR_TEMPLATE.format(user_id=user_id)
|
324 |
+
has_personalized = (
|
325 |
+
os.path.exists(model_dir) and
|
326 |
+
os.path.exists(os.path.join(model_dir, "pytorch_model.bin"))
|
327 |
+
)
|
328 |
+
|
329 |
+
return {
|
330 |
+
"status": personalized_load_status.get(user_id, "Not Loaded"),
|
331 |
+
"model_dir": model_dir if has_personalized else None,
|
332 |
+
"is_ready": personalized_load_status.get(user_id) == "Ready" and user_id in personalized_pipelines,
|
333 |
+
"is_personalized": True,
|
334 |
+
"has_personalized_model": has_personalized
|
335 |
+
}
|
336 |
+
|
337 |
+
|
338 |
+
def get_ai_prediction_for_email(
|
339 |
+
email_text_content: str,
|
340 |
+
user_id: Optional[str] = None,
|
341 |
+
app_logger: Optional[logging.Logger] = None
|
342 |
+
) -> Optional[Dict[str, Any]]:
|
343 |
+
"""
|
344 |
+
Get AI prediction for an email, optionally using a personalized model.
|
345 |
+
|
346 |
+
This function:
|
347 |
+
1. Checks if the requested model is loaded and ready
|
348 |
+
2. Passes the email text to the appropriate classification pipeline
|
349 |
+
3. Processes and returns the prediction results
|
350 |
+
|
351 |
+
Args:
|
352 |
+
email_text_content: The combined email text (subject + body)
|
353 |
+
user_id: Optional user ID to use personalized model if available
|
354 |
+
app_logger: Optional logger for error reporting
|
355 |
+
|
356 |
+
Returns:
|
357 |
+
Dictionary with prediction results (label, confidence, etc.) or None if prediction fails
|
358 |
+
"""
|
359 |
+
global base_classification_pipeline, personalized_pipelines
|
360 |
+
|
361 |
+
# Determine which pipeline to use
|
362 |
+
pipeline = None
|
363 |
+
using_personalized = False
|
364 |
+
|
365 |
+
if user_id is not None and user_id in personalized_pipelines:
|
366 |
+
# Try to use personalized model first
|
367 |
+
if personalized_load_status.get(user_id) == "Ready":
|
368 |
+
pipeline = personalized_pipelines[user_id]
|
369 |
+
using_personalized = True
|
370 |
+
|
371 |
+
# Fall back to base model if personalized isn't available
|
372 |
+
if pipeline is None:
|
373 |
+
if base_model_load_status != "Ready" or base_classification_pipeline is None:
|
374 |
+
return None
|
375 |
+
pipeline = base_classification_pipeline
|
376 |
+
|
377 |
+
try:
|
378 |
+
# Clean and normalize the input text
|
379 |
+
cleaned_text = utils.clean_text_for_model(
|
380 |
+
email_text_content,
|
381 |
+
max_length=config.EMAIL_SNIPPET_LENGTH_FOR_MODEL
|
382 |
+
)
|
383 |
+
|
384 |
+
# Skip prediction for extremely short text
|
385 |
+
if len(cleaned_text) < config.MIN_TEXT_LENGTH_FOR_TRAINING:
|
386 |
+
return {
|
387 |
+
"label": "INDETERMINATE",
|
388 |
+
"confidence": 0.0,
|
389 |
+
"predicted_id": None,
|
390 |
+
"error": "Text too short for reliable prediction",
|
391 |
+
"using_personalized_model": using_personalized
|
392 |
+
}
|
393 |
+
|
394 |
+
# Get prediction
|
395 |
+
predictions = pipeline(cleaned_text)
|
396 |
+
|
397 |
+
# Process prediction results
|
398 |
+
# The pipeline returns a list with one dictionary per input text
|
399 |
+
# Each dictionary has a 'label' and 'score' for each possible class
|
400 |
+
prediction_scores = {}
|
401 |
+
for pred in predictions[0]: # Get the first (and only) prediction
|
402 |
+
label_id = int(pred['label'].split('_')[-1])
|
403 |
+
label_name = config.ID_TO_LABEL_MAP.get(label_id)
|
404 |
+
prediction_scores[label_name] = pred['score']
|
405 |
+
|
406 |
+
# Find the highest scoring label
|
407 |
+
max_label = max(prediction_scores, key=prediction_scores.get)
|
408 |
+
max_score = prediction_scores[max_label]
|
409 |
+
predicted_id = config.LABEL_TO_ID_MAP.get(max_label)
|
410 |
+
|
411 |
+
# Apply confidence calibration to prevent overconfidence
|
412 |
+
# Temperature scaling to soften extreme predictions
|
413 |
+
temperature = 1.5 # Higher = less confident
|
414 |
+
calibrated_score = max_score ** (1 / temperature)
|
415 |
+
|
416 |
+
# Log prediction details for debugging
|
417 |
+
if app_logger:
|
418 |
+
app_logger.debug(f"AI Prediction: {max_label} (raw: {max_score:.3f}, calibrated: {calibrated_score:.3f})")
|
419 |
+
app_logger.debug(f"All scores: {prediction_scores}")
|
420 |
+
app_logger.debug(f"Email snippet: {cleaned_text[:100]}...")
|
421 |
+
|
422 |
+
# Return the prediction
|
423 |
+
return {
|
424 |
+
"label": max_label,
|
425 |
+
"confidence": calibrated_score,
|
426 |
+
"raw_confidence": max_score,
|
427 |
+
"predicted_id": predicted_id,
|
428 |
+
"all_scores": prediction_scores,
|
429 |
+
"using_personalized_model": using_personalized
|
430 |
+
}
|
431 |
+
|
432 |
+
except Exception as e:
|
433 |
+
# Log the error details but don't expose them in the response
|
434 |
+
if app_logger:
|
435 |
+
app_logger.error(f"Error during AI prediction: {str(e)}")
|
436 |
+
else:
|
437 |
+
print(f"Error during AI prediction: {str(e)}")
|
438 |
+
|
439 |
+
return {
|
440 |
+
"label": "ERROR",
|
441 |
+
"confidence": 0.0,
|
442 |
+
"predicted_id": None,
|
443 |
+
"error": "Prediction error occurred",
|
444 |
+
"using_personalized_model": using_personalized
|
445 |
+
}
|
ml_suite/task_utils.py
ADDED
@@ -0,0 +1,332 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Utilities for task status management and logging in the ML suite.
|
3 |
+
|
4 |
+
This module provides functions and classes for:
|
5 |
+
- Managing task status files for long-running operations
|
6 |
+
- Logging both to Flask app logs and status files
|
7 |
+
- Standardized error handling for AI operations
|
8 |
+
- Progress reporting for data preparation and model training
|
9 |
+
|
10 |
+
These utilities ensure that AI operations have proper status tracking,
|
11 |
+
enabling the frontend to display progress and results to users.
|
12 |
+
"""
|
13 |
+
|
14 |
+
import json
|
15 |
+
import time
|
16 |
+
import uuid
|
17 |
+
import datetime
|
18 |
+
import traceback
|
19 |
+
import os
|
20 |
+
import logging
|
21 |
+
from typing import Dict, List, Optional, Union, Any, Tuple
|
22 |
+
|
23 |
+
|
24 |
+
def get_current_timestamp() -> str:
|
25 |
+
"""Returns ISO format timestamp for current time."""
|
26 |
+
return datetime.datetime.now().isoformat()
|
27 |
+
|
28 |
+
|
29 |
+
def get_current_timestamp_log_prefix() -> str:
|
30 |
+
"""Returns a formatted timestamp string for log entries."""
|
31 |
+
return f"[{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}]"
|
32 |
+
|
33 |
+
|
34 |
+
def initialize_task_status_file(status_file_path: str, task_id: Optional[str] = None) -> Dict[str, Any]:
|
35 |
+
"""Initialize a new task status file with default values.
|
36 |
+
|
37 |
+
Args:
|
38 |
+
status_file_path: Path to the status file
|
39 |
+
task_id: Optional unique ID for the task (UUID generated if None)
|
40 |
+
|
41 |
+
Returns:
|
42 |
+
Dict containing the initialized status data
|
43 |
+
"""
|
44 |
+
if task_id is None:
|
45 |
+
task_id = str(uuid.uuid4())
|
46 |
+
|
47 |
+
status_data = {
|
48 |
+
"status": "pending",
|
49 |
+
"message": "Task initialized and pending execution",
|
50 |
+
"progress": 0.0,
|
51 |
+
"log": [],
|
52 |
+
"task_id": task_id,
|
53 |
+
"start_time": get_current_timestamp()
|
54 |
+
}
|
55 |
+
|
56 |
+
# Ensure the directory exists
|
57 |
+
dir_path = os.path.dirname(status_file_path)
|
58 |
+
if dir_path: # Only create directory if dirname returns a non-empty string
|
59 |
+
os.makedirs(dir_path, exist_ok=True)
|
60 |
+
|
61 |
+
with open(status_file_path, 'w') as f:
|
62 |
+
json.dump(status_data, f, indent=2)
|
63 |
+
|
64 |
+
return status_data
|
65 |
+
|
66 |
+
|
67 |
+
def update_task_status(
|
68 |
+
status_file_path: str,
|
69 |
+
status: Optional[str] = None,
|
70 |
+
message: Optional[str] = None,
|
71 |
+
log_entry: Optional[str] = None,
|
72 |
+
log_level: str = "info",
|
73 |
+
progress: Optional[float] = None,
|
74 |
+
result: Optional[Dict[str, Any]] = None,
|
75 |
+
error: Optional[Dict[str, Any]] = None
|
76 |
+
) -> Dict[str, Any]:
|
77 |
+
"""Update a task status file with new information.
|
78 |
+
|
79 |
+
Args:
|
80 |
+
status_file_path: Path to the status file
|
81 |
+
status: Optional new status ('pending', 'in_progress', 'completed', 'failed')
|
82 |
+
message: Optional status message
|
83 |
+
log_entry: Optional log message to add to the log list
|
84 |
+
log_level: Log level ('info', 'warning', 'error', 'debug')
|
85 |
+
progress: Optional progress value (0.0 to 1.0)
|
86 |
+
result: Optional result data (for completed tasks)
|
87 |
+
error: Optional error data (for failed tasks)
|
88 |
+
|
89 |
+
Returns:
|
90 |
+
Dict containing the updated status data
|
91 |
+
"""
|
92 |
+
try:
|
93 |
+
with open(status_file_path, 'r') as f:
|
94 |
+
status_data = json.load(f)
|
95 |
+
except (FileNotFoundError, json.JSONDecodeError):
|
96 |
+
# If file doesn't exist or is invalid, initialize it
|
97 |
+
status_data = initialize_task_status_file(status_file_path)
|
98 |
+
|
99 |
+
# Update fields if provided
|
100 |
+
if status is not None:
|
101 |
+
status_data["status"] = status
|
102 |
+
# If status is completed or failed, set end_time
|
103 |
+
if status in ("completed", "failed"):
|
104 |
+
status_data["end_time"] = get_current_timestamp()
|
105 |
+
|
106 |
+
if message is not None:
|
107 |
+
status_data["message"] = message
|
108 |
+
|
109 |
+
if log_entry is not None:
|
110 |
+
if "log" not in status_data:
|
111 |
+
status_data["log"] = []
|
112 |
+
status_data["log"].append({
|
113 |
+
"timestamp": get_current_timestamp(),
|
114 |
+
"level": log_level,
|
115 |
+
"message": log_entry
|
116 |
+
})
|
117 |
+
|
118 |
+
if progress is not None:
|
119 |
+
status_data["progress"] = max(0.0, min(1.0, float(progress))) # Ensure between 0-1
|
120 |
+
|
121 |
+
if result is not None and status_data.get("status") == "completed":
|
122 |
+
status_data["result"] = result
|
123 |
+
|
124 |
+
if error is not None and status_data.get("status") == "failed":
|
125 |
+
status_data["error"] = error
|
126 |
+
|
127 |
+
# Write updated status back to file
|
128 |
+
with open(status_file_path, 'w') as f:
|
129 |
+
json.dump(status_data, f, indent=2)
|
130 |
+
|
131 |
+
return status_data
|
132 |
+
|
133 |
+
|
134 |
+
def get_task_status(status_file_path: str) -> Dict[str, Any]:
|
135 |
+
"""Read and return the current task status.
|
136 |
+
|
137 |
+
Args:
|
138 |
+
status_file_path: Path to the status file
|
139 |
+
|
140 |
+
Returns:
|
141 |
+
Dict containing the current status data
|
142 |
+
"""
|
143 |
+
try:
|
144 |
+
with open(status_file_path, 'r') as f:
|
145 |
+
return json.load(f)
|
146 |
+
except (FileNotFoundError, json.JSONDecodeError):
|
147 |
+
# If status file doesn't exist or is invalid, return a default status
|
148 |
+
return {
|
149 |
+
"status": "unknown",
|
150 |
+
"message": "Task status unknown or not initialized",
|
151 |
+
"progress": 0.0,
|
152 |
+
"log": []
|
153 |
+
}
|
154 |
+
|
155 |
+
|
156 |
+
def log_task_error(status_file_path: str, error: Exception, message: str = "Task failed due to an error") -> Dict[str, Any]:
|
157 |
+
"""Log an error to the task status file.
|
158 |
+
|
159 |
+
Args:
|
160 |
+
status_file_path: Path to the status file
|
161 |
+
error: The exception that occurred
|
162 |
+
message: Human-readable error message
|
163 |
+
|
164 |
+
Returns:
|
165 |
+
Dict containing the updated status data
|
166 |
+
"""
|
167 |
+
error_info = {
|
168 |
+
"type": error.__class__.__name__,
|
169 |
+
"message": str(error),
|
170 |
+
"traceback": traceback.format_exc()
|
171 |
+
}
|
172 |
+
|
173 |
+
return update_task_status(
|
174 |
+
status_file_path=status_file_path,
|
175 |
+
status="failed",
|
176 |
+
message=message,
|
177 |
+
log_entry=f"ERROR: {message} - {error}",
|
178 |
+
log_level="error",
|
179 |
+
error=error_info
|
180 |
+
)
|
181 |
+
|
182 |
+
|
183 |
+
class AiTaskLogger:
|
184 |
+
"""Logger for AI tasks that updates both the application logger and task status file.
|
185 |
+
|
186 |
+
This logger provides methods for:
|
187 |
+
- Logging info, warning, and error messages
|
188 |
+
- Updating task progress
|
189 |
+
- Marking tasks as started, completed, or failed
|
190 |
+
- Ensuring consistent logging across both Flask app logs and task status files
|
191 |
+
"""
|
192 |
+
|
193 |
+
def __init__(self,
|
194 |
+
app_logger: logging.Logger,
|
195 |
+
status_file_path: str,
|
196 |
+
task_id: Optional[str] = None):
|
197 |
+
"""Initialize the task logger.
|
198 |
+
|
199 |
+
Args:
|
200 |
+
app_logger: Flask application logger
|
201 |
+
status_file_path: Path to the task status file
|
202 |
+
task_id: Optional unique ID for the task (UUID generated if None)
|
203 |
+
"""
|
204 |
+
self.app_logger = app_logger
|
205 |
+
self.status_file_path = status_file_path
|
206 |
+
self.task_id = task_id or str(uuid.uuid4())
|
207 |
+
self.short_task_id = self.task_id[:8] # First 8 chars for log readability
|
208 |
+
|
209 |
+
# Initialize the status file
|
210 |
+
initialize_task_status_file(status_file_path, self.task_id)
|
211 |
+
|
212 |
+
def info(self, message: str, update_progress: Optional[float] = None) -> None:
|
213 |
+
"""Log an info message and optionally update progress.
|
214 |
+
|
215 |
+
Args:
|
216 |
+
message: The message to log
|
217 |
+
update_progress: Optional progress value (0.0 to 1.0)
|
218 |
+
"""
|
219 |
+
self.app_logger.info(f"[AI Task {self.short_task_id}] {message}")
|
220 |
+
update_task_status(
|
221 |
+
self.status_file_path,
|
222 |
+
log_entry=message,
|
223 |
+
log_level="info",
|
224 |
+
progress=update_progress
|
225 |
+
)
|
226 |
+
|
227 |
+
def warning(self, message: str) -> None:
|
228 |
+
"""Log a warning message.
|
229 |
+
|
230 |
+
Args:
|
231 |
+
message: The warning message to log
|
232 |
+
"""
|
233 |
+
self.app_logger.warning(f"[AI Task {self.short_task_id}] {message}")
|
234 |
+
update_task_status(
|
235 |
+
self.status_file_path,
|
236 |
+
log_entry=message,
|
237 |
+
log_level="warning"
|
238 |
+
)
|
239 |
+
|
240 |
+
def error(self, message: str, error: Optional[Exception] = None) -> None:
|
241 |
+
"""Log an error message and optionally the exception details.
|
242 |
+
|
243 |
+
Args:
|
244 |
+
message: The error message to log
|
245 |
+
error: Optional exception that caused the error
|
246 |
+
"""
|
247 |
+
if error:
|
248 |
+
self.app_logger.error(f"[AI Task {self.short_task_id}] {message}: {error}", exc_info=True)
|
249 |
+
log_task_error(self.status_file_path, error, message)
|
250 |
+
else:
|
251 |
+
self.app_logger.error(f"[AI Task {self.short_task_id}] {message}")
|
252 |
+
update_task_status(
|
253 |
+
self.status_file_path,
|
254 |
+
log_entry=message,
|
255 |
+
log_level="error"
|
256 |
+
)
|
257 |
+
|
258 |
+
def start_task(self, message: str = "Task started") -> None:
|
259 |
+
"""Mark the task as started.
|
260 |
+
|
261 |
+
Args:
|
262 |
+
message: Optional message describing the task start
|
263 |
+
"""
|
264 |
+
self.app_logger.info(f"[AI Task {self.short_task_id}] {message}")
|
265 |
+
update_task_status(
|
266 |
+
self.status_file_path,
|
267 |
+
status="in_progress",
|
268 |
+
message=message,
|
269 |
+
log_entry=message,
|
270 |
+
log_level="info",
|
271 |
+
progress=0.0
|
272 |
+
)
|
273 |
+
|
274 |
+
def complete_task(self, message: str = "Task completed successfully", result: Optional[Dict[str, Any]] = None) -> None:
|
275 |
+
"""Mark the task as completed.
|
276 |
+
|
277 |
+
Args:
|
278 |
+
message: Optional completion message
|
279 |
+
result: Optional result data to store
|
280 |
+
"""
|
281 |
+
self.app_logger.info(f"[AI Task {self.short_task_id}] {message}")
|
282 |
+
update_task_status(
|
283 |
+
self.status_file_path,
|
284 |
+
status="completed",
|
285 |
+
message=message,
|
286 |
+
log_entry=message,
|
287 |
+
log_level="info",
|
288 |
+
progress=1.0,
|
289 |
+
result=result
|
290 |
+
)
|
291 |
+
|
292 |
+
def fail_task(self, message: str, error: Optional[Exception] = None) -> None:
|
293 |
+
"""Mark the task as failed.
|
294 |
+
|
295 |
+
Args:
|
296 |
+
message: Failure message
|
297 |
+
error: Optional exception that caused the failure
|
298 |
+
"""
|
299 |
+
if error:
|
300 |
+
self.app_logger.error(f"[AI Task {self.short_task_id}] {message}: {error}", exc_info=True)
|
301 |
+
log_task_error(self.status_file_path, error, message)
|
302 |
+
else:
|
303 |
+
self.app_logger.error(f"[AI Task {self.short_task_id}] {message}")
|
304 |
+
update_task_status(
|
305 |
+
self.status_file_path,
|
306 |
+
status="failed",
|
307 |
+
message=message,
|
308 |
+
log_entry=message,
|
309 |
+
log_level="error"
|
310 |
+
)
|
311 |
+
|
312 |
+
def update_progress(self, progress: float, message: Optional[str] = None) -> None:
|
313 |
+
"""Update the task progress.
|
314 |
+
|
315 |
+
Args:
|
316 |
+
progress: Progress value (0.0 to 1.0)
|
317 |
+
message: Optional progress message
|
318 |
+
"""
|
319 |
+
if message:
|
320 |
+
self.app_logger.info(f"[AI Task {self.short_task_id}] {message} (Progress: {progress:.1%})")
|
321 |
+
update_task_status(
|
322 |
+
self.status_file_path,
|
323 |
+
message=message,
|
324 |
+
log_entry=message,
|
325 |
+
log_level="info",
|
326 |
+
progress=progress
|
327 |
+
)
|
328 |
+
else:
|
329 |
+
update_task_status(
|
330 |
+
self.status_file_path,
|
331 |
+
progress=progress
|
332 |
+
)
|
ml_suite/utils.py
ADDED
@@ -0,0 +1,317 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Shared utilities for the ML suite.
|
3 |
+
|
4 |
+
This module provides shared functions used across the ML suite components:
|
5 |
+
- Email text analysis using heuristics (adapted from app.py)
|
6 |
+
- Text cleaning and normalization
|
7 |
+
- Timestamp and logging utilities
|
8 |
+
- HTML processing for email content extraction
|
9 |
+
|
10 |
+
These utilities ensure consistent processing across different components of the ML suite.
|
11 |
+
"""
|
12 |
+
|
13 |
+
import re
|
14 |
+
import os
|
15 |
+
import urllib.parse
|
16 |
+
import datetime
|
17 |
+
import html
|
18 |
+
from typing import Dict, List, Tuple, Optional, Any, Union
|
19 |
+
from bs4 import BeautifulSoup
|
20 |
+
|
21 |
+
|
22 |
+
# --- Email Heuristic Analysis ---
|
23 |
+
|
24 |
+
# Keywords that suggest an email is marketing/promotional/unsubscribable
|
25 |
+
UNSUBSCRIBE_KEYWORDS_FOR_AI_HEURISTICS = [
|
26 |
+
'unsubscribe', 'opt-out', 'opt out', 'stop receiving', 'manage preferences',
|
27 |
+
'email preferences', 'subscription', 'marketing', 'newsletter', 'promotional',
|
28 |
+
'offer', 'sale', 'discount', 'deal', 'coupon', 'promo code', 'promotion',
|
29 |
+
'limited time', 'subscribe', 'update preferences', 'mailing list',
|
30 |
+
'no longer wish to receive', 'manage subscriptions', 'manage your subscriptions'
|
31 |
+
]
|
32 |
+
|
33 |
+
# Keywords that suggest promotional content
|
34 |
+
PROMO_KEYWORDS_FOR_AI_HEURISTICS = [
|
35 |
+
'limited time', 'exclusive', 'offer', 'sale', 'discount', 'deal', 'coupon',
|
36 |
+
'promo code', 'promotion', 'savings', 'special offer', 'limited offer',
|
37 |
+
'buy now', 'shop now', 'order now', 'click here', 'purchase', 'buy',
|
38 |
+
'free shipping', 'free trial', 'new arrival', 'new product', 'flash sale'
|
39 |
+
]
|
40 |
+
|
41 |
+
# Common formatting patterns in promotional emails
|
42 |
+
FORMATTING_PATTERNS_FOR_AI_HEURISTICS = [
|
43 |
+
r'\*+\s*[A-Z]+\s*\*+', # ***TEXT***
|
44 |
+
r'\*\*[^*]+\*\*', # **TEXT**
|
45 |
+
r'!{2,}', # Multiple exclamation marks
|
46 |
+
r'\$\d+(\.\d{2})?(\s+off|\s+discount|%\s+off)', # Price patterns
|
47 |
+
r'\d+%\s+off', # Percentage discounts
|
48 |
+
r'SAVE\s+\d+%', # SAVE XX%
|
49 |
+
r'SAVE\s+\$\d+', # SAVE $XX
|
50 |
+
r'HURRY', # Urgency words
|
51 |
+
r'LIMITED TIME',
|
52 |
+
r'LAST CHANCE',
|
53 |
+
r'ENDING SOON'
|
54 |
+
]
|
55 |
+
|
56 |
+
|
57 |
+
def analyze_email_heuristics_for_ai(subject_text: str, snippet_text: str, list_unsubscribe_header: Optional[str] = None) -> Dict[str, bool]:
|
58 |
+
"""
|
59 |
+
Analyze email subject and body (snippet) text to determine if it's likely promotional/unsubscribable.
|
60 |
+
|
61 |
+
This function is adapted from the original heuristic analysis in app.py but modified
|
62 |
+
to be self-contained and not rely on Flask's app context. It examines the subject
|
63 |
+
and body for patterns common in promotional emails and subscription-based content.
|
64 |
+
|
65 |
+
Args:
|
66 |
+
subject_text: The subject line of the email
|
67 |
+
snippet_text: A snippet of the email body text
|
68 |
+
list_unsubscribe_header: Optional List-Unsubscribe header value
|
69 |
+
|
70 |
+
Returns:
|
71 |
+
Dict of boolean flags indicating different heuristic results:
|
72 |
+
{
|
73 |
+
'has_unsubscribe_text': bool, # Contains unsubscribe keywords
|
74 |
+
'has_promotional_keywords': bool, # Contains promotional keywords
|
75 |
+
'has_promotional_formatting': bool, # Contains typical promotional formatting
|
76 |
+
'has_list_unsubscribe_header': bool, # Has List-Unsubscribe header
|
77 |
+
'likely_unsubscribable': bool # Overall assessment
|
78 |
+
}
|
79 |
+
"""
|
80 |
+
# Ensure inputs are strings
|
81 |
+
subject_text = str(subject_text).lower() if subject_text else ""
|
82 |
+
snippet_text = str(snippet_text).lower() if snippet_text else ""
|
83 |
+
combined_text = f"{subject_text} {snippet_text}".lower()
|
84 |
+
|
85 |
+
# Initialize result with default values
|
86 |
+
result = {
|
87 |
+
'has_unsubscribe_text': False,
|
88 |
+
'has_promotional_keywords': False,
|
89 |
+
'has_promotional_formatting': False,
|
90 |
+
'has_list_unsubscribe_header': False,
|
91 |
+
'likely_unsubscribable': False
|
92 |
+
}
|
93 |
+
|
94 |
+
# Check for unsubscribe keywords
|
95 |
+
for keyword in UNSUBSCRIBE_KEYWORDS_FOR_AI_HEURISTICS:
|
96 |
+
if keyword.lower() in combined_text:
|
97 |
+
result['has_unsubscribe_text'] = True
|
98 |
+
break
|
99 |
+
|
100 |
+
# Check for promotional keywords
|
101 |
+
for keyword in PROMO_KEYWORDS_FOR_AI_HEURISTICS:
|
102 |
+
if keyword.lower() in combined_text:
|
103 |
+
result['has_promotional_keywords'] = True
|
104 |
+
break
|
105 |
+
|
106 |
+
# Check for promotional formatting patterns
|
107 |
+
combined_text_original_case = f"{subject_text} {snippet_text}" if subject_text and snippet_text else ""
|
108 |
+
for pattern in FORMATTING_PATTERNS_FOR_AI_HEURISTICS:
|
109 |
+
if re.search(pattern, combined_text_original_case, re.IGNORECASE):
|
110 |
+
result['has_promotional_formatting'] = True
|
111 |
+
break
|
112 |
+
|
113 |
+
# Check for List-Unsubscribe header
|
114 |
+
if list_unsubscribe_header:
|
115 |
+
result['has_list_unsubscribe_header'] = True
|
116 |
+
|
117 |
+
# Overall assessment: likely unsubscribable if any of the criteria are met
|
118 |
+
# For training data preparation, we want to be somewhat inclusive in what we label as potentially unsubscribable
|
119 |
+
result['likely_unsubscribable'] = any([
|
120 |
+
result['has_unsubscribe_text'],
|
121 |
+
(result['has_promotional_keywords'] and result['has_promotional_formatting']),
|
122 |
+
result['has_list_unsubscribe_header']
|
123 |
+
])
|
124 |
+
|
125 |
+
return result
|
126 |
+
|
127 |
+
|
128 |
+
# --- Text Cleaning Utilities ---
|
129 |
+
|
130 |
+
def clean_html_text(html_content: str) -> str:
|
131 |
+
"""
|
132 |
+
Clean HTML content and extract readable text.
|
133 |
+
|
134 |
+
Args:
|
135 |
+
html_content: Raw HTML content string
|
136 |
+
|
137 |
+
Returns:
|
138 |
+
Cleaned plain text extracted from HTML
|
139 |
+
"""
|
140 |
+
if not html_content:
|
141 |
+
return ""
|
142 |
+
|
143 |
+
try:
|
144 |
+
# Create BeautifulSoup object
|
145 |
+
soup = BeautifulSoup(html_content, 'html.parser')
|
146 |
+
|
147 |
+
# Remove script and style elements
|
148 |
+
for script_or_style in soup(['script', 'style', 'head', 'title', 'meta', '[document]']):
|
149 |
+
script_or_style.decompose()
|
150 |
+
|
151 |
+
# Get text content
|
152 |
+
text = soup.get_text()
|
153 |
+
|
154 |
+
# Clean up text: replace multiple newlines, spaces, etc.
|
155 |
+
text = re.sub(r'\n+', '\n', text)
|
156 |
+
text = re.sub(r' +', ' ', text)
|
157 |
+
text = text.strip()
|
158 |
+
|
159 |
+
return text
|
160 |
+
except Exception:
|
161 |
+
# If parsing fails, try to extract text with regex (fallback)
|
162 |
+
text = re.sub(r'<[^>]*>', ' ', html_content)
|
163 |
+
text = html.unescape(text)
|
164 |
+
text = re.sub(r'\s+', ' ', text)
|
165 |
+
return text.strip()
|
166 |
+
|
167 |
+
|
168 |
+
def normalize_spaces(text: str) -> str:
|
169 |
+
"""
|
170 |
+
Normalize whitespace in text.
|
171 |
+
|
172 |
+
Args:
|
173 |
+
text: Input text
|
174 |
+
|
175 |
+
Returns:
|
176 |
+
Text with normalized whitespace
|
177 |
+
"""
|
178 |
+
if not text:
|
179 |
+
return ""
|
180 |
+
|
181 |
+
# Replace newlines, tabs with spaces
|
182 |
+
text = re.sub(r'[\n\r\t]+', ' ', text)
|
183 |
+
# Collapse multiple spaces into one
|
184 |
+
text = re.sub(r' +', ' ', text)
|
185 |
+
return text.strip()
|
186 |
+
|
187 |
+
|
188 |
+
def normalize_urls(text: str) -> str:
|
189 |
+
"""
|
190 |
+
Replace URLs with a placeholder to reduce noise in training data.
|
191 |
+
|
192 |
+
Args:
|
193 |
+
text: Input text
|
194 |
+
|
195 |
+
Returns:
|
196 |
+
Text with URLs replaced by a placeholder
|
197 |
+
"""
|
198 |
+
if not text:
|
199 |
+
return ""
|
200 |
+
|
201 |
+
# URL regex pattern
|
202 |
+
url_pattern = r'(https?://[^\s]+)|(www\.[^\s]+\.[^\s]+)'
|
203 |
+
|
204 |
+
# Replace URLs with placeholder
|
205 |
+
return re.sub(url_pattern, '[URL]', text)
|
206 |
+
|
207 |
+
|
208 |
+
def clean_text_for_model(text: str, max_length: Optional[int] = None) -> str:
|
209 |
+
"""
|
210 |
+
Clean and normalize text for model input.
|
211 |
+
|
212 |
+
Args:
|
213 |
+
text: Input text (can be HTML or plain text)
|
214 |
+
max_length: Optional maximum length to truncate to
|
215 |
+
|
216 |
+
Returns:
|
217 |
+
Cleaned text ready for model input
|
218 |
+
"""
|
219 |
+
if not text:
|
220 |
+
return ""
|
221 |
+
|
222 |
+
# Check if input is likely HTML
|
223 |
+
if re.search(r'<\w+[^>]*>.*?</\w+>', text, re.DOTALL):
|
224 |
+
text = clean_html_text(text)
|
225 |
+
|
226 |
+
# Normalize whitespace
|
227 |
+
text = normalize_spaces(text)
|
228 |
+
|
229 |
+
# Replace URLs with placeholder
|
230 |
+
text = normalize_urls(text)
|
231 |
+
|
232 |
+
# Truncate if needed
|
233 |
+
if max_length and len(text) > max_length:
|
234 |
+
text = text[:max_length]
|
235 |
+
|
236 |
+
return text
|
237 |
+
|
238 |
+
|
239 |
+
# --- Timestamp and Path Utilities ---
|
240 |
+
|
241 |
+
def get_current_timestamp() -> str:
|
242 |
+
"""Returns ISO format timestamp for current time."""
|
243 |
+
return datetime.datetime.now().isoformat()
|
244 |
+
|
245 |
+
|
246 |
+
def get_current_timestamp_log_prefix() -> str:
|
247 |
+
"""Returns a formatted timestamp string for log entries."""
|
248 |
+
return f"[{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}]"
|
249 |
+
|
250 |
+
|
251 |
+
def ensure_directory_exists(directory_path: str) -> bool:
|
252 |
+
"""
|
253 |
+
Ensure that a directory exists, creating it if necessary.
|
254 |
+
|
255 |
+
Args:
|
256 |
+
directory_path: Path to the directory
|
257 |
+
|
258 |
+
Returns:
|
259 |
+
True if directory exists or was created, False on error
|
260 |
+
"""
|
261 |
+
try:
|
262 |
+
os.makedirs(directory_path, exist_ok=True)
|
263 |
+
return True
|
264 |
+
except Exception:
|
265 |
+
return False
|
266 |
+
|
267 |
+
|
268 |
+
# --- Email Header Analysis ---
|
269 |
+
|
270 |
+
def extract_email_addresses(header_value: str) -> List[str]:
|
271 |
+
"""
|
272 |
+
Extract email addresses from a header value.
|
273 |
+
|
274 |
+
Args:
|
275 |
+
header_value: Raw header value containing email addresses
|
276 |
+
|
277 |
+
Returns:
|
278 |
+
List of extracted email addresses
|
279 |
+
"""
|
280 |
+
if not header_value:
|
281 |
+
return []
|
282 |
+
|
283 |
+
# Basic email regex pattern
|
284 |
+
email_pattern = r'[\w.+-]+@[\w-]+\.[\w.-]+'
|
285 |
+
return re.findall(email_pattern, header_value)
|
286 |
+
|
287 |
+
|
288 |
+
def parse_list_unsubscribe_header(header_value: str) -> Dict[str, Any]:
|
289 |
+
"""
|
290 |
+
Parse the List-Unsubscribe header to extract URLs and email addresses.
|
291 |
+
|
292 |
+
Args:
|
293 |
+
header_value: Raw List-Unsubscribe header value
|
294 |
+
|
295 |
+
Returns:
|
296 |
+
Dict with extracted URLs and email addresses
|
297 |
+
"""
|
298 |
+
if not header_value:
|
299 |
+
return {"urls": [], "emails": []}
|
300 |
+
|
301 |
+
result = {"urls": [], "emails": []}
|
302 |
+
|
303 |
+
# Split by comma and process each value
|
304 |
+
for item in header_value.split(','):
|
305 |
+
item = item.strip()
|
306 |
+
|
307 |
+
# Handle <mailto:...> format
|
308 |
+
if item.startswith('<mailto:') and item.endswith('>'):
|
309 |
+
email = item[8:-1] # Remove <mailto: and >
|
310 |
+
result["emails"].append(email)
|
311 |
+
|
312 |
+
# Handle <http...> format
|
313 |
+
elif item.startswith('<http') and item.endswith('>'):
|
314 |
+
url = item[1:-1] # Remove < and >
|
315 |
+
result["urls"].append(url)
|
316 |
+
|
317 |
+
return result
|
prepare_deployment.sh
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
|
3 |
+
# Script to prepare Hugging Face deployment without modifying the main app
|
4 |
+
|
5 |
+
echo "Preparing Hugging Face deployment..."
|
6 |
+
|
7 |
+
# Create symlinks to required directories (won't copy, just link)
|
8 |
+
ln -sf ../ml_suite ml_suite
|
9 |
+
ln -sf ../final_optimized_model final_optimized_model
|
10 |
+
|
11 |
+
echo "Created symbolic links to ml_suite and model directories"
|
12 |
+
echo ""
|
13 |
+
echo "To deploy to Hugging Face Spaces:"
|
14 |
+
echo "1. Create a new Space at https://huggingface.co/spaces"
|
15 |
+
echo "2. Clone the space repository"
|
16 |
+
echo "3. Copy the contents of this huggingface/ directory to the space repo"
|
17 |
+
echo "4. Copy the ml_suite/ and final_optimized_model/ directories"
|
18 |
+
echo "5. Git add, commit, and push"
|
19 |
+
echo ""
|
20 |
+
echo "Your main app remains completely untouched!"
|
requirements.txt
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Core dependencies for Hugging Face deployment
|
2 |
+
gradio>=4.0.0
|
3 |
+
transformers==4.36.2
|
4 |
+
torch==2.2.2
|
5 |
+
accelerate==0.25.0
|
6 |
+
scikit-learn==1.4.2
|
7 |
+
pandas==2.2.1
|
8 |
+
nltk==3.8.1
|
9 |
+
beautifulsoup4>=4.12.0
|
10 |
+
requests>=2.30.0
|
11 |
+
joblib>=1.3.0
|