File size: 4,548 Bytes
826a9bc
 
 
 
 
 
 
5a8b969
444b661
826a9bc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a8b7aaa
826a9bc
 
 
 
 
 
a8b7aaa
826a9bc
11355eb
 
 
826a9bc
 
 
 
 
 
 
 
 
 
 
857cce7
826a9bc
 
 
 
 
 
 
 
 
 
 
 
 
857cce7
826a9bc
 
857cce7
826a9bc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
857cce7
826a9bc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
857cce7
826a9bc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
# Predefined offense categories, added for the radar chart. Each English
# category name maps to the keywords searched for inside classifier labels.
# NOTE(review): matching is a substring test against the label text, so these
# Chinese keywords only match if the classifier emits Chinese labels — confirm
# against the model's actual label set.
OFFENSE_CATEGORIES = {
    "Insult": ["侮辱", "贬低", "讽刺"],  # insult, belittling, sarcasm
    "Abuse": ["辱骂", "攻击性语言", "脏话"],  # verbal abuse, aggressive language, profanity
    "Discrimination": ["歧视性言论", "种族歧视", "性别歧视"],  # discriminatory remarks, racism, sexism
    "Hate Speech": ["仇恨言论", "暴力煽动"],  # hate speech, incitement to violence
    "Vulgarity": ["低俗用语", "色情暗示"],  # vulgar language, sexual innuendo
}

def _matching_categories(label):
    """Yield, in OFFENSE_CATEGORIES order, each category whose keywords occur in *label*."""
    label_lower = label.lower()
    for cat, keywords in OFFENSE_CATEGORIES.items():
        if any(kw.lower() in label_lower for kw in keywords):
            yield cat


# Classification extended to support multi-dimensional (per-category) analysis.
def classify_text_with_categories(text: str):
    """Classify *text* overall, per offense category, and word by word.

    Runs the module-level ``classifier`` on the whole text and adds the score
    of every returned label to each ``OFFENSE_CATEGORIES`` bucket whose
    keywords occur (as substrings) in that label. Then classifies each
    whitespace-separated word on its own.

    Args:
        text: The input text to analyze.

    Returns:
        A dict with keys:
            ``"translations"``: the input text, unchanged (no translation is
                performed here).
            ``"overall"``: the first element ``classifier`` returned for the
                whole text (read elsewhere via its ``"label"``/``"score"`` keys).
            ``"categories"``: mapping of every category name to its summed score
                (0 when no label matched).
            ``"word_analysis"``: one dict per successfully classified word with
                ``"word"``, ``"main_label"``, ``"main_score"`` and
                ``"offense_category"`` (the first matching category, or
                ``"Other"``).

    Raises:
        ValueError: if ``classifier`` returns no results for *text*.
    """
    results = classifier(text)
    if not results:
        # Otherwise ``results[0]`` below fails with an uninformative IndexError.
        raise ValueError("Classifier returned no results for the input text")

    # Multi-dimensional scoring: a label may contribute to several categories.
    category_scores = {category: 0 for category in OFFENSE_CATEGORIES}
    for res in results:
        score = res["score"]
        for cat in _matching_categories(res["label"]):
            category_scores[cat] += score

    # Word-level analysis.
    word_analysis = []
    for word in text.split():
        try:
            res = classifier(word)[0]
            entry = {
                "word": word,
                "main_label": res["label"],
                "main_score": res["score"],
                "offense_category": next(_matching_categories(res["label"]), "Other"),
            }
        except Exception:
            # Best effort: skip a word the classifier cannot handle instead of
            # aborting the whole analysis (but never swallow KeyboardInterrupt
            # / SystemExit, which the former bare ``except:`` did).
            continue
        word_analysis.append(entry)

    return {
        "translations": text,
        "overall": results[0],
        "categories": category_scores,
        "word_analysis": word_analysis
    }

# Analyze button: classify the current input, record it, and chart the result.
if st.button("🚦 Analyze Text"):
    with st.spinner("🔍 Processing..."):
        try:
            # Typed text takes precedence; otherwise fall back to the OCR text.
            text_input = text if text else ocr_text
            if not text_input or not text_input.strip():
                # Classifying an empty input would still append a history
                # entry and skew the aggregated charts, so stop here instead.
                st.warning("⚠️ No text to analyze.")
            else:
                analysis = classify_text_with_categories(text_input)

                # Record the analysis so the history-based charts include it.
                st.session_state.history.append({
                    "original": text_input,
                    "translated": analysis["translations"],
                    "overall": analysis["overall"],
                    "categories": analysis["categories"],
                    "word_analysis": analysis["word_analysis"]
                })

                # Headline result.
                st.markdown("**Main Prediction:**")
                st.metric("Label", analysis["overall"]["label"],
                          delta=f"{analysis['overall']['score']:.2%} Confidence")

                # Per-category breakdown as an explicit long-form DataFrame.
                st.markdown("### 📊 Offense Category Breakdown")
                category_df = pd.DataFrame(
                    list(analysis["categories"].items()),
                    columns=["Category", "Score"],
                )
                fig = px.bar(category_df, x="Category", y="Score",
                             title="Category Contribution",
                             labels={"Score": "Probability"})
                st.plotly_chart(fig)

        except Exception as e:
            # Top-level UI boundary: surface any failure to the user.
            st.error(f"❌ Error: {e}")

# Radar chart of the average per-category score across the analysis history.
if st.session_state.history:
    # Collect every recorded score per category. Entries without a
    # "categories" key contribute nothing instead of raising KeyError, and
    # unexpected category names get their own bucket.
    radar_data = {cat: [] for cat in OFFENSE_CATEGORIES}
    for entry in st.session_state.history:
        for cat, score in entry.get("categories", {}).items():
            radar_data.setdefault(cat, []).append(score)

    # Mean score per category (0 when a category has no samples).
    avg_scores = {
        cat: sum(scores) / len(scores) if scores else 0
        for cat, scores in radar_data.items()
    }

    # Long-form frame: one row per category, so the category name drives the
    # angular axis (theta) and its mean score drives the radius (r).
    radar_df = pd.DataFrame({
        "Category": list(avg_scores),
        "Score": list(avg_scores.values()),
    })
    fig = px.line_polar(
        radar_df,
        r="Score", theta="Category",
        line_close=True, title="📉 Offense Risk Radar Chart"
    )
    fig.update_traces(line_color='#FF4B4B')
    st.plotly_chart(fig)

# Word-level view: the highest-scoring words across the analysis history.
if st.session_state.history:
    # Gather per-word records from every entry (older entries may lack them).
    all_words = []
    for entry in st.session_state.history:
        all_words.extend(entry.get("word_analysis", []))

    st.markdown("### 🧩 Offensive Word Analysis")
    # Check the raw list first: grouping an empty DataFrame by "word" raises
    # KeyError because the column does not exist.
    if all_words:
        # Per distinct word: mean score and its most frequent category.
        word_counts = (
            pd.DataFrame(all_words)
            .groupby("word")
            .agg({
                "main_score": "mean",
                "offense_category": lambda s: s.mode().iloc[0],
            })
            .reset_index()
            .sort_values("main_score", ascending=False)
        )
        top_words = word_counts.head(10)
        fig = px.bar(top_words, x="word", y="main_score",
                     color="offense_category",
                     title="Top Offensive Words by Score")
        st.plotly_chart(fig)
    else:
        st.info("No offensive words detected")