Update app.py

app.py CHANGED
@@ -12,7 +12,7 @@ from torch.nn.functional import sigmoid
 from collections import Counter
 import logging
 import traceback
-
+
 
 # Set up logging
 logging.basicConfig(level=logging.DEBUG)
@@ -23,18 +23,6 @@ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 logger.info(f"Using device: {device}")
 # Set up custom logging
 # Set up custom logging
-# Add cleanup function here
-def cleanup():
-    """Cleanup function to free memory"""
-    try:
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-        logger.info("Cleanup completed")
-    except Exception as e:
-        logger.error(f"Error during cleanup: {e}")
-
-atexit.register(cleanup)
-
 class CustomFormatter(logging.Formatter):
     """Custom formatter with colors and better formatting"""
     grey = "\x1b[38;21m"
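Review note: the deleted exit hook freed cached CUDA memory at interpreter shutdown; under ZeroGPU the @spaces.GPU decorator manages GPU attachment per call, which presumably makes the hook redundant. For reference, the removed pattern as a self-contained sketch:

import atexit
import logging

import torch

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

def cleanup():
    """Free cached GPU memory when the process exits."""
    try:
        if torch.cuda.is_available():
            torch.cuda.empty_cache()  # return cached blocks to the CUDA driver
        logger.info("Cleanup completed")
    except Exception as e:
        logger.error(f"Error during cleanup: {e}")

atexit.register(cleanup)  # runs once, at normal interpreter shutdown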
@@ -122,7 +110,7 @@ THRESHOLDS = {
 
 
 PATTERN_WEIGHTS = {
-    "recovery
+    "recovery": 0.7,
     "control": 1.4,
     "gaslighting": 1.3,
     "guilt tripping": 1.2,
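Review note: this repairs the previously truncated "recovery" entry, giving it a sub-1.0 weight that discounts recovery language while control and gaslighting are amplified. compute_abuse_score's body is outside this diff, so the combiner below is only a hypothetical sketch of how (label, score, weight) triples are commonly merged (a weighted mean scaled to 0-100):

PATTERN_WEIGHTS = {"recovery": 0.7, "control": 1.4, "gaslighting": 1.3, "guilt tripping": 1.2}

def weighted_abuse_score(matched_scores):
    """Hypothetical combiner: weighted mean of matched pattern scores, scaled to 0-100."""
    if not matched_scores:
        return 0.0
    total = sum(score * weight for _, score, weight in matched_scores)
    norm = sum(weight for _, _, weight in matched_scores)
    return 100.0 * total / norm

# A strong "control" hit dominates a weak, down-weighted "recovery" hit:
matched = [("control", 0.80, 1.4), ("recovery", 0.30, 0.7)]
print(round(weighted_abuse_score(matched), 1))  # 63.3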
@@ -341,32 +329,6 @@ def get_risk_stage(patterns, sentiment):
         logger.error(f"Error determining risk stage: {e}")
         return 1
 
-def detect_compound_threat(patterns):
-    """Helper function to standardize threat detection logic"""
-    pattern_set = set(p.lower() for p in patterns)
-
-    # Expand primary threats to include more patterns
-    primary_threats = {"control", "insults", "blame shifting"}
-
-    # Expand supporting risks to include more patterns
-    supporting_risks = {
-        "gaslighting",
-        "dismissiveness",
-        "blame shifting",
-        "guilt tripping",
-        "contradictory statements"
-    }
-
-    has_primary = bool(primary_threats & pattern_set)
-    has_supporting = len(supporting_risks & pattern_set) >= 1  # Only need one supporting risk
-
-    # Additional check for control + any other pattern
-    if "control" in pattern_set and len(pattern_set) > 1:
-        return True
-
-    return has_primary and has_supporting
-
-
 @spaces.GPU
 def compute_abuse_score(matched_scores, sentiment):
     """Compute abuse score from matched patterns and sentiment"""
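Review note: this removal deletes the compound-threat helper entirely (its call sites go later in this diff). Its logic, reproduced verbatim from the removed lines and exercised on a few pattern lists:

def detect_compound_threat(patterns):
    pattern_set = set(p.lower() for p in patterns)
    primary_threats = {"control", "insults", "blame shifting"}
    supporting_risks = {"gaslighting", "dismissiveness", "blame shifting",
                        "guilt tripping", "contradictory statements"}
    has_primary = bool(primary_threats & pattern_set)
    has_supporting = len(supporting_risks & pattern_set) >= 1
    if "control" in pattern_set and len(pattern_set) > 1:
        return True
    return has_primary and has_supporting

print(detect_compound_threat(["control", "recovery"]))     # True  (control plus anything)
print(detect_compound_threat(["insults", "gaslighting"]))  # True  (primary plus supporting)
print(detect_compound_threat(["blame shifting"]))          # True  (sits in both sets by itself)
print(detect_compound_threat(["dismissiveness"]))          # False (supporting only)

Since "blame shifting" appears in both sets, the helper tripped on that single pattern alone; that over-eagerness may be part of why it was dropped.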
@@ -429,7 +391,7 @@ def analyze_single_message(text, thresholds):
             return 0.0, [], [], {"label": "none"}, 1, 0.0, None
 
         # Check for explicit abuse
-        explicit_abuse_words = ['fuck', 'bitch', 'shit', 'ass', 'dick'
+        explicit_abuse_words = ['fuck', 'bitch', 'shit', 'ass', 'dick']
         explicit_abuse = any(word in text.lower() for word in explicit_abuse_words)
         logger.debug(f"Explicit abuse detected: {explicit_abuse}")
 
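Review note: the fix terminates the list that was left unclosed, a SyntaxError in the old file. One caveat worth recording: `word in text.lower()` is a substring test, so harmless words match. A word-boundary variant (illustrative only, not what app.py does):

import re

explicit_abuse_words = ['fuck', 'bitch', 'shit', 'ass', 'dick']

text = "The class assignment is due"
print(any(w in text.lower() for w in explicit_abuse_words))  # True: "class" contains "ass"

# Word-boundary alternative, if stricter matching were wanted:
pattern = re.compile(r"\b(" + "|".join(map(re.escape, explicit_abuse_words)) + r")\b")
print(bool(pattern.search(text.lower())))  # False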
@@ -458,43 +420,26 @@ def analyze_single_message(text, thresholds):
         if explicit_abuse:
             threshold_labels.append("insults")
             logger.debug("\nForced inclusion of 'insults' due to explicit abuse")
-
+
         for label, score in sorted_predictions:
-            if label == "nonabusive":
-                continue  # Skip nonabusive label
             base_threshold = thresholds.get(label, 0.25)
             if explicit_abuse:
                 base_threshold *= 0.5
             if score > base_threshold:
                 if label not in threshold_labels:
                     threshold_labels.append(label)
+
+        logger.debug(f"\nLabels that passed thresholds: {threshold_labels}")
 
-        #
-        control_score = raw_scores[LABELS.index("control")]
-        if control_score > 0.3 and "control" not in threshold_labels:  # Lower threshold for control
-            threshold_labels.append("control")
-            logger.debug("\nAdded control pattern due to high control score")
-
-        # Calculate matched scores (exclude nonabusive)
+        # Calculate matched scores
         matched_scores = []
         for label in threshold_labels:
-            if label == "nonabusive":
-                continue
             score = raw_scores[LABELS.index(label)]
             weight = PATTERN_WEIGHTS.get(label, 1.0)
             if explicit_abuse and label == "insults":
                 weight *= 1.5
             matched_scores.append((label, score, weight))
 
-        logger.debug("\n🧨 SINGLE-MESSAGE COMPOUND THREAT CHECK")
-        compound_threat_flag = detect_compound_threat([label for label, _, _ in matched_scores])
-        compound_threat_boost = compound_threat_flag
-
-        if compound_threat_flag:
-            logger.debug("⚠️ Compound high-risk patterns detected in single message.")
-        else:
-            logger.debug("✓ No compound threats detected in message.")
-
         # Get sentiment
         sent_inputs = sentiment_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
         sent_inputs = {k: v.to(device) for k, v in sent_inputs.items()}
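Review note: with the nonabusive skip and the special-case control boost removed, the gate is just score versus per-label threshold, halved when explicit abuse was already detected. A standalone trace with invented scores:

thresholds = {"insults": 0.30, "control": 0.25}
sorted_predictions = [("insults", 0.18), ("control", 0.40), ("dismissiveness", 0.10)]
explicit_abuse = True

threshold_labels = []
for label, score in sorted_predictions:
    base_threshold = thresholds.get(label, 0.25)  # 0.25 default for unlisted labels
    if explicit_abuse:
        base_threshold *= 0.5                     # halve the bar when explicit abuse was found
    if score > base_threshold and label not in threshold_labels:
        threshold_labels.append(label)

print(threshold_labels)  # ['insults', 'control']; 0.18 passes only because the 0.30 bar was halved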
@@ -507,25 +452,20 @@ def analyze_single_message(text, thresholds):
         abuse_score = compute_abuse_score(matched_scores, sentiment)
         if explicit_abuse:
             abuse_score = max(abuse_score, 70.0)
-
-        abuse_score = max(abuse_score, 85.0)  # force high score if compound risk detected
-        if "control" in [label for label, _, _ in matched_scores]:
-            abuse_score = max(abuse_score, 70.0)  # Minimum score for control patterns
-
+
         # Get DARVO score
         darvo_score = predict_darvo_score(text)
 
         # Get tone using emotion-based approach
         tone_tag = get_emotional_tone_tag(text, sentiment, threshold_labels, abuse_score)
-
-
-        highest_pattern = max(matched_scores, key=lambda x: x[1])[0] if matched_scores else None
+        # Check for the specific combination
+        highest_pattern = max(matched_scores, key=lambda x: x[1])[0] if matched_scores else None  # Get highest pattern
         if sentiment == "supportive" and tone_tag == "neutral" and highest_pattern == "obscure language":
             logger.debug("Message classified as likely non-abusive (supportive, neutral, and obscure language). Returning low risk.")
             return 0.0, [], [], {"label": "supportive"}, 1, 0.0, "neutral"  # Return non-abusive values
 
         # Set stage
-        stage = 2 if explicit_abuse or abuse_score > 70
+        stage = 2 if explicit_abuse or abuse_score > 70 else 1
 
         logger.debug("=== DEBUG END ===\n")
 
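Review note: the old `stage = 2 if explicit_abuse or abuse_score > 70` was a SyntaxError, since Python's conditional expression requires an else arm. The repaired line behaves like this:

explicit_abuse, abuse_score = False, 72.5
stage = 2 if explicit_abuse or abuse_score > 70 else 1  # 2 = escalated stage, 1 = baseline
print(stage)  # 2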
@@ -534,7 +474,6 @@ def analyze_single_message(text, thresholds):
     except Exception as e:
         logger.error(f"Error in analyze_single_message: {e}")
         return 0.0, [], [], {"label": "error"}, 1, 0.0, None
-
 def generate_abuse_score_chart(dates, scores, patterns):
     """Generate a timeline chart of abuse scores"""
     try:
@@ -725,22 +664,7 @@ def analyze_composite(msg1, msg2, msg3, *answers_and_none):
                 logger.debug(f" • {severity} | {label}: {score:.3f} (weight: {weight})")
         else:
             logger.debug("\n✓ No abuse patterns detected")
-        # Compound Threat Detection (across all active messages)
-        logger.debug("\n🧨 COMPOUND THREAT DETECTION")
-        logger.debug("=" * 50)
-
-        compound_threat_flag = any(
-            detect_compound_threat(result[0][1])  # patterns are in result[0][1]
-            for result in results
-        )
-
-        if compound_threat_flag:
-            logger.debug("⚠️ ALERT: Compound threat patterns detected across messages")
-            logger.debug("   • High-risk combination of primary and supporting patterns found")
-        else:
-            logger.debug("✓ No compound threats detected across messages")
 
-
     # Extract scores and metadata
     abuse_scores = [r[0][0] for r in results]
     stages = [r[0][4] for r in results]
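Review note: the composite-level check folded each message's matched patterns (`result[0][1]`) through the same deleted helper. A stubbed re-creation of what analyze_composite loses, with the results shape reduced to the fields used here:

# Stub: each result is ((abuse_score, patterns, ...), ...); only patterns matter here.
results = [
    ((12.0, ["dismissiveness"]),),
    ((68.0, ["control", "guilt tripping"]),),
]

def detect_compound_threat(patterns):  # removed helper, condensed from the diff above
    pattern_set = set(p.lower() for p in patterns)
    if "control" in pattern_set and len(pattern_set) > 1:
        return True
    return bool({"control", "insults", "blame shifting"} & pattern_set) and bool(
        {"gaslighting", "dismissiveness", "blame shifting",
         "guilt tripping", "contradictory statements"} & pattern_set)

compound_threat_flag = any(detect_compound_threat(result[0][1]) for result in results)
print(compound_threat_flag)  # True: message 2 pairs "control" with another pattern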
@@ -907,9 +831,6 @@ def analyze_composite(msg1, msg2, msg3, *answers_and_none):
         "Low"
     )
     logger.debug(f"\n⚠️ Final Escalation Risk: {escalation_risk}")
-    if compound_threat_flag:
-        escalation_risk = "Critical"
-        logger.debug("❗ Compound threat detected — overriding risk to CRITICAL.")
     # Generate Output Text
     logger.debug("\n📝 GENERATING OUTPUT")
     logger.debug("=" * 50)
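Review note: paired with the deletion above, the Critical override goes away too; previously the compound flag trumped whatever tier the scoring produced:

escalation_risk = "Moderate"   # tier computed by the scoring above
compound_threat_flag = True    # from the now-removed cross-message check

if compound_threat_flag:
    escalation_risk = "Critical"  # unconditional override of the computed tier

print(escalation_risk)  # Critical

After this commit, escalation_risk simply keeps the computed tier.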