Update app.py
app.py CHANGED
@@ -23,6 +23,18 @@ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 logger.info(f"Using device: {device}")
 # Set up custom logging
 # Set up custom logging
+# Add cleanup function here
+def cleanup():
+    """Cleanup function to free memory"""
+    try:
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+        logger.info("Cleanup completed")
+    except Exception as e:
+        logger.error(f"Error during cleanup: {e}")
+
+atexit.register(cleanup)
+
 class CustomFormatter(logging.Formatter):
     """Custom formatter with colors and better formatting"""
     grey = "\x1b[38;21m"
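The new cleanup hook relies on the standard-library atexit module being importable in app.py (the import itself is not part of this hunk). A minimal self-contained sketch of the same pattern, with the imports and logger setup shown explicitly for illustration:

import atexit
import logging

import torch

logger = logging.getLogger(__name__)

def cleanup():
    """Free cached GPU memory when the process exits."""
    try:
        if torch.cuda.is_available():
            torch.cuda.empty_cache()  # release cached CUDA allocations
        logger.info("Cleanup completed")
    except Exception as e:
        logger.error(f"Error during cleanup: {e}")

# Register the handler so cleanup runs automatically at interpreter shutdown.
atexit.register(cleanup)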
@@ -110,7 +122,7 @@ THRESHOLDS = {


 PATTERN_WEIGHTS = {
-    "recovery": 0.7,
+    "recovery phase": 0.7,
     "control": 1.4,
     "gaslighting": 1.3,
     "guilt tripping": 1.2,
@@ -329,6 +341,18 @@ def get_risk_stage(patterns, sentiment):
         logger.error(f"Error determining risk stage: {e}")
         return 1

+def detect_compound_threat(patterns):
+    """Helper function to standardize threat detection logic"""
+    pattern_set = set(p.lower() for p in patterns)
+
+    primary_threats = {"insults", "control"}
+    supporting_risks = {"gaslighting", "dismissiveness", "blame shifting"}
+
+    has_primary = bool(primary_threats & pattern_set)
+    has_supporting = bool(supporting_risks & pattern_set)
+
+    return has_primary and has_supporting
+
 @spaces.GPU
 def compute_abuse_score(matched_scores, sentiment):
     """Compute abuse score from matched patterns and sentiment"""
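For reference, detect_compound_threat returns True only when at least one primary threat ("insults" or "control") co-occurs with at least one supporting risk ("gaslighting", "dismissiveness", or "blame shifting"); matching is case-insensitive because the labels are lowercased first. Illustrative calls, assuming the helper exactly as defined above:

detect_compound_threat(["Control", "gaslighting"])             # True: primary + supporting
detect_compound_threat(["control"])                            # False: primary threat alone
detect_compound_threat(["dismissiveness", "blame shifting"])   # False: supporting risks alone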
@@ -440,6 +464,15 @@ def analyze_single_message(text, thresholds):
                 weight *= 1.5
             matched_scores.append((label, score, weight))

+    logger.debug("\n🧨 SINGLE-MESSAGE COMPOUND THREAT CHECK")
+    compound_threat_flag = detect_compound_threat([label for label, _, _ in matched_scores])
+    compound_threat_boost = compound_threat_flag
+
+    if compound_threat_flag:
+        logger.debug("⚠️ Compound high-risk patterns detected in single message.")
+    else:
+        logger.debug("✓ No compound threats detected in message.")
+
     # Get sentiment
     sent_inputs = sentiment_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
     sent_inputs = {k: v.to(device) for k, v in sent_inputs.items()}
@@ -452,7 +485,8 @@ def analyze_single_message(text, thresholds):
     abuse_score = compute_abuse_score(matched_scores, sentiment)
     if explicit_abuse:
         abuse_score = max(abuse_score, 70.0)
-
+    if compound_threat_boost:
+        abuse_score = max(abuse_score, 85.0)  # force high score if compound risk detected
     # Get DARVO score
     darvo_score = predict_darvo_score(text)

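Like the explicit-abuse clamp to 70.0 just above it, the compound-threat adjustment is a floor rather than an additive bonus. Plain arithmetic, not code from the diff:

max(40.0, 85.0)   # a low-scoring message is raised to the 85.0 floor
max(92.0, 85.0)   # a message already above the floor keeps its own score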
@@ -464,8 +498,9 @@ def analyze_single_message(text, thresholds):
         logger.debug("Message classified as likely non-abusive (supportive, neutral, and obscure language). Returning low risk.")
         return 0.0, [], [], {"label": "supportive"}, 1, 0.0, "neutral"  # Return non-abusive values

+
     # Set stage
-    stage = 2 if explicit_abuse or abuse_score > 70 else 1
+    stage = 2 if explicit_abuse or abuse_score > 70 or compound_threat_boost else 1

     logger.debug("=== DEBUG END ===\n")

@@ -664,7 +699,24 @@ def analyze_composite(msg1, msg2, msg3, *answers_and_none):
                 logger.debug(f" • {severity} | {label}: {score:.3f} (weight: {weight})")
     else:
         logger.debug("\n✓ No abuse patterns detected")
+    # Compound Threat Detection (across all active messages)
+    logger.debug("\n🧨 COMPOUND THREAT DETECTION")
+    logger.debug("=" * 50)
+
+    compound_threat_flag = any(
+        detect_compound_threat(result[0][1])  # patterns are in result[0][1]
+        for result in results
+    )
+
+    if compound_threat_flag:
+        logger.debug("⚠️ ALERT: Compound threat patterns detected across messages")
+        logger.debug(" • High-risk combination of primary and supporting patterns found")
+    else:
+        logger.debug("✓ No compound threats detected across messages")

+    logger.debug(f" • Threat messages: {threat_count}")
+    logger.debug(f" • Supporting risk messages: {supporting_risk_count}")
+    logger.debug(f" • Compound threat triggered: {'YES' if compound_threat_flag else 'NO'}")
     # Extract scores and metadata
     abuse_scores = [r[0][0] for r in results]
     stages = [r[0][4] for r in results]
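The cross-message check asks whether any one message on its own satisfies detect_compound_threat; patterns are not pooled across messages. A small sketch with mocked per-message results (the tuple layout here is simplified and only result[0][1], that message's pattern list, matters; threat_count and supporting_risk_count are assumed to be computed elsewhere in analyze_composite):

# Mocked results: result[0][1] holds the pattern labels detected in each message.
results = [
    ((12.0, ["dismissiveness"], [], {"label": "neutral"}, 1, 0.05, "neutral"), None),
    ((58.0, ["control", "blame shifting"], [], {"label": "undermining"}, 2, 0.30, "angry"), None),
]

compound_threat_flag = any(
    detect_compound_threat(result[0][1])  # each message's own pattern list
    for result in results
)
# True: the second message alone pairs "control" (primary) with "blame shifting" (supporting).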
@@ -831,6 +883,9 @@ def analyze_composite(msg1, msg2, msg3, *answers_and_none):
         "Low"
     )
     logger.debug(f"\n⚠️ Final Escalation Risk: {escalation_risk}")
+    if compound_threat_flag:
+        escalation_risk = "Critical"
+        logger.debug("❗ Compound threat detected — overriding risk to CRITICAL.")
     # Generate Output Text
     logger.debug("\n📝 GENERATING OUTPUT")
     logger.debug("=" * 50)