SamanthaStorm committed on
Commit 025be57 · verified · 1 Parent(s): 6d45d2a

Update app.py

Files changed (1)
  1. app.py +11 -90
app.py CHANGED
@@ -12,7 +12,7 @@ from torch.nn.functional import sigmoid
 from collections import Counter
 import logging
 import traceback
-import atexit
+
 
 # Set up logging
 logging.basicConfig(level=logging.DEBUG)
@@ -23,18 +23,6 @@ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 logger.info(f"Using device: {device}")
 # Set up custom logging
 # Set up custom logging
-# Add cleanup function here
-def cleanup():
-    """Cleanup function to free memory"""
-    try:
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-        logger.info("Cleanup completed")
-    except Exception as e:
-        logger.error(f"Error during cleanup: {e}")
-
-atexit.register(cleanup)
-
 class CustomFormatter(logging.Formatter):
     """Custom formatter with colors and better formatting"""
     grey = "\x1b[38;21m"
@@ -122,7 +110,7 @@ THRESHOLDS = {
 
 
 PATTERN_WEIGHTS = {
-    "recovery phase": 0.7,
+    "recovery": 0.7,
     "control": 1.4,
     "gaslighting": 1.3,
     "guilt tripping": 1.2,
@@ -341,32 +329,6 @@ def get_risk_stage(patterns, sentiment):
         logger.error(f"Error determining risk stage: {e}")
         return 1
 
-def detect_compound_threat(patterns):
-    """Helper function to standardize threat detection logic"""
-    pattern_set = set(p.lower() for p in patterns)
-
-    # Expand primary threats to include more patterns
-    primary_threats = {"control", "insults", "blame shifting"}
-
-    # Expand supporting risks to include more patterns
-    supporting_risks = {
-        "gaslighting",
-        "dismissiveness",
-        "blame shifting",
-        "guilt tripping",
-        "contradictory statements"
-    }
-
-    has_primary = bool(primary_threats & pattern_set)
-    has_supporting = len(supporting_risks & pattern_set) >= 1  # Only need one supporting risk
-
-    # Additional check for control + any other pattern
-    if "control" in pattern_set and len(pattern_set) > 1:
-        return True
-
-    return has_primary and has_supporting
-
-
 @spaces.GPU
 def compute_abuse_score(matched_scores, sentiment):
     """Compute abuse score from matched patterns and sentiment"""
@@ -429,7 +391,7 @@ def analyze_single_message(text, thresholds):
            return 0.0, [], [], {"label": "none"}, 1, 0.0, None
 
        # Check for explicit abuse
-        explicit_abuse_words = ['fuck', 'bitch', 'shit', 'ass', 'dick', "make you regret it"]
+        explicit_abuse_words = ['fuck', 'bitch', 'shit', 'ass', 'dick']
        explicit_abuse = any(word in text.lower() for word in explicit_abuse_words)
        logger.debug(f"Explicit abuse detected: {explicit_abuse}")
 
@@ -458,43 +420,26 @@ def analyze_single_message(text, thresholds):
         if explicit_abuse:
             threshold_labels.append("insults")
             logger.debug("\nForced inclusion of 'insults' due to explicit abuse")
-
+
         for label, score in sorted_predictions:
-            if label == "nonabusive":
-                continue  # Skip nonabusive label
             base_threshold = thresholds.get(label, 0.25)
             if explicit_abuse:
                 base_threshold *= 0.5
             if score > base_threshold:
                 if label not in threshold_labels:
                     threshold_labels.append(label)
+
+        logger.debug("\nLabels that passed thresholds:", threshold_labels)
 
-        # Check for control pattern
-        control_score = raw_scores[LABELS.index("control")]
-        if control_score > 0.3 and "control" not in threshold_labels:  # Lower threshold for control
-            threshold_labels.append("control")
-            logger.debug("\nAdded control pattern due to high control score")
-
-        # Calculate matched scores (exclude nonabusive)
+        # Calculate matched scores
         matched_scores = []
         for label in threshold_labels:
-            if label == "nonabusive":
-                continue
             score = raw_scores[LABELS.index(label)]
             weight = PATTERN_WEIGHTS.get(label, 1.0)
             if explicit_abuse and label == "insults":
                 weight *= 1.5
             matched_scores.append((label, score, weight))
 
-        logger.debug("\n🧨 SINGLE-MESSAGE COMPOUND THREAT CHECK")
-        compound_threat_flag = detect_compound_threat([label for label, _, _ in matched_scores])
-        compound_threat_boost = compound_threat_flag
-
-        if compound_threat_flag:
-            logger.debug("⚠️ Compound high-risk patterns detected in single message.")
-        else:
-            logger.debug("✓ No compound threats detected in message.")
-
         # Get sentiment
         sent_inputs = sentiment_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
         sent_inputs = {k: v.to(device) for k, v in sent_inputs.items()}
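
For clarity, the threshold-and-weight step in the hunk above can be traced with a small standalone sketch. The labels, scores, and per-label thresholds below are invented for illustration; the actual code reads raw_scores via LABELS.index and uses the module-level THRESHOLDS and PATTERN_WEIGHTS.

# Simplified sketch of the threshold filter and weighting shown above; values are made up.
PATTERN_WEIGHTS = {"recovery": 0.7, "control": 1.4, "gaslighting": 1.3, "guilt tripping": 1.2}
thresholds = {"control": 0.30, "gaslighting": 0.25}   # hypothetical per-label thresholds
sorted_predictions = [("control", 0.62), ("gaslighting", 0.27), ("recovery", 0.10)]
explicit_abuse = False

threshold_labels = []
for label, score in sorted_predictions:
    base_threshold = thresholds.get(label, 0.25)      # default threshold is 0.25
    if explicit_abuse:
        base_threshold *= 0.5                         # thresholds are halved on explicit abuse
    if score > base_threshold and label not in threshold_labels:
        threshold_labels.append(label)

scores = dict(sorted_predictions)
matched_scores = [(label, scores[label], PATTERN_WEIGHTS.get(label, 1.0)) for label in threshold_labels]
print(matched_scores)   # [('control', 0.62, 1.4), ('gaslighting', 0.27, 1.3)]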
@@ -507,25 +452,20 @@ def analyze_single_message(text, thresholds):
         abuse_score = compute_abuse_score(matched_scores, sentiment)
         if explicit_abuse:
             abuse_score = max(abuse_score, 70.0)
-        if compound_threat_boost:
-            abuse_score = max(abuse_score, 85.0)  # force high score if compound risk detected
-        if "control" in [label for label, _, _ in matched_scores]:
-            abuse_score = max(abuse_score, 70.0)  # Minimum score for control patterns
-
+
         # Get DARVO score
         darvo_score = predict_darvo_score(text)
 
         # Get tone using emotion-based approach
         tone_tag = get_emotional_tone_tag(text, sentiment, threshold_labels, abuse_score)
-
-        # Check for the specific combination
-        highest_pattern = max(matched_scores, key=lambda x: x[1])[0] if matched_scores else None
+        # Check for the specific combination
+        highest_pattern = max(matched_scores, key=lambda x: x[1])[0] if matched_scores else None  # Get highest pattern
         if sentiment == "supportive" and tone_tag == "neutral" and highest_pattern == "obscure language":
             logger.debug("Message classified as likely non-abusive (supportive, neutral, and obscure language). Returning low risk.")
             return 0.0, [], [], {"label": "supportive"}, 1, 0.0, "neutral"  # Return non-abusive values
 
         # Set stage
-        stage = 2 if explicit_abuse or abuse_score > 70 or compound_threat_boost else 1
+        stage = 2 if explicit_abuse or abuse_score > 70 else 1
 
         logger.debug("=== DEBUG END ===\n")
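
As a small illustration of the highest-pattern selection used just above: matched_scores holds (label, score, weight) tuples, and taking the max over the score picks the dominant pattern. The tuple values here are fabricated.

# Picking the top-scoring pattern, as in the hunk above; the tuples are invented.
matched_scores = [("obscure language", 0.41, 1.0), ("dismissiveness", 0.28, 0.8)]
highest_pattern = max(matched_scores, key=lambda x: x[1])[0] if matched_scores else None
print(highest_pattern)   # 'obscure language'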
 
@@ -534,7 +474,6 @@ def analyze_single_message(text, thresholds):
     except Exception as e:
         logger.error(f"Error in analyze_single_message: {e}")
         return 0.0, [], [], {"label": "error"}, 1, 0.0, None
-
 def generate_abuse_score_chart(dates, scores, patterns):
     """Generate a timeline chart of abuse scores"""
     try:
@@ -725,22 +664,7 @@ def analyze_composite(msg1, msg2, msg3, *answers_and_none):
             logger.debug(f"  • {severity} | {label}: {score:.3f} (weight: {weight})")
         else:
             logger.debug("\n✓ No abuse patterns detected")
-        # Compound Threat Detection (across all active messages)
-        logger.debug("\n🧨 COMPOUND THREAT DETECTION")
-        logger.debug("=" * 50)
-
-        compound_threat_flag = any(
-            detect_compound_threat(result[0][1])  # patterns are in result[0][1]
-            for result in results
-        )
-
-        if compound_threat_flag:
-            logger.debug("⚠️ ALERT: Compound threat patterns detected across messages")
-            logger.debug("  • High-risk combination of primary and supporting patterns found")
-        else:
-            logger.debug("✓ No compound threats detected across messages")
 
-
        # Extract scores and metadata
        abuse_scores = [r[0][0] for r in results]
        stages = [r[0][4] for r in results]
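
For reference, the compound-threat check used in the hunk above reduces to set operations over pattern labels, applied per message and aggregated with any(). A minimal standalone sketch follows; the sample results structure is fabricated and truncated to the fields used here.

# Set-based compound-threat check as it appears in this diff; sample data is made up.
def detect_compound_threat(patterns):
    pattern_set = set(p.lower() for p in patterns)
    primary_threats = {"control", "insults", "blame shifting"}
    supporting_risks = {"gaslighting", "dismissiveness", "blame shifting",
                        "guilt tripping", "contradictory statements"}
    if "control" in pattern_set and len(pattern_set) > 1:   # control plus anything else
        return True
    return bool(primary_threats & pattern_set) and len(supporting_risks & pattern_set) >= 1

# Each result[0][1] holds the matched pattern labels for one message (other fields omitted).
results = [((45.0, ["dismissiveness"]),), ((82.0, ["control", "insults"]),)]
compound_threat_flag = any(detect_compound_threat(r[0][1]) for r in results)
print(compound_threat_flag)   # True, driven by the second message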
@@ -907,9 +831,6 @@ def analyze_composite(msg1, msg2, msg3, *answers_and_none):
             "Low"
         )
         logger.debug(f"\n⚠️ Final Escalation Risk: {escalation_risk}")
-        if compound_threat_flag:
-            escalation_risk = "Critical"
-            logger.debug("❗ Compound threat detected — overriding risk to CRITICAL.")
         # Generate Output Text
         logger.debug("\n📝 GENERATING OUTPUT")
         logger.debug("=" * 50)
 