# test_1 / app.py
# aeresd — Update app.py (826a9bc, verified)
# (Hugging Face Space file header; 4.55 kB)
# Predefined offense-category mapping used by the radar chart.
# Keys are the English category names shown in the UI; values are Chinese
# keyword lists matched (substring) against classifier labels.
# NOTE(review): the keywords are Chinese, but `classify_text_with_categories`
# matches them against `res["label"].lower()` — confirm the classifier's label
# vocabulary actually contains these Chinese strings, otherwise no category
# ever accumulates a score.
OFFENSE_CATEGORIES = {
    "Insult": ["侮辱", "贬低", "讽刺"],
    "Abuse": ["辱骂", "攻击性语言", "脏话"],
    "Discrimination": ["歧视性言论", "种族歧视", "性别歧视"],
    "Hate Speech": ["仇恨言论", "暴力煽动"],
    "Vulgarity": ["低俗用语", "色情暗示"]
}
# Classification with multi-dimensional (per-category) analysis.
def classify_text_with_categories(text: str) -> dict:
    """Classify *text* and break the result down by offense category.

    Runs the module-level ``classifier`` pipeline on the whole text, then on
    each whitespace-separated word, and buckets label scores into the
    predefined ``OFFENSE_CATEGORIES``.

    Args:
        text: Input text to analyze.

    Returns:
        dict with keys:
            ``translations``  – the input text, unchanged.
            ``overall``       – first classifier result for the whole text
                                (``{"label": "unknown", "score": 0.0}`` if the
                                classifier returned nothing).
            ``categories``    – {category name: accumulated score}.
            ``word_analysis`` – per-word dicts (word, main label/score, and
                                the matched offense category or "Other").
    """
    results = classifier(text)
    category_scores = {category: 0.0 for category in OFFENSE_CATEGORIES}
    # Multi-dimensional scoring: add each result's score to every category
    # whose keywords appear in the (lowercased) label.
    for res in results:
        label = res["label"].lower()  # hoisted: lowercase once per result
        score = res["score"]
        for cat, keywords in OFFENSE_CATEGORIES.items():
            # NOTE(review): keywords are Chinese while classifier labels may
            # be English — confirm this substring match can ever fire.
            if any(kw in label for kw in keywords):
                category_scores[cat] += score
    # Word-level analysis: classify each whitespace token independently.
    word_analysis = []
    for word in text.split():
        try:
            res = classifier(word)[0]
        except Exception:
            # Was a bare `except:` — narrow it so SystemExit/KeyboardInterrupt
            # are not swallowed; a word the pipeline cannot score is skipped.
            continue
        word_label = res["label"].lower()
        word_analysis.append({
            "word": word,
            "main_label": res["label"],
            "main_score": res["score"],
            "offense_category": next(
                (cat for cat, keywords in OFFENSE_CATEGORIES.items()
                 if any(kw in word_label for kw in keywords)),
                "Other"
            )
        })
    return {
        "translations": text,
        "overall": results[0] if results else {"label": "unknown", "score": 0.0},
        "categories": category_scores,
        "word_analysis": word_analysis
    }
# Classification handler: runs when the user clicks "Analyze Text".
# Renders the main prediction and a per-category bar chart, and appends the
# analysis to st.session_state.history for the charts further below.
if st.button("🚦 Analyze Text"):
    with st.spinner("🔍 Processing..."):
        try:
            # Choose the text input: typed text takes priority over OCR output.
            # NOTE(review): `text` and `ocr_text` are defined earlier in the
            # file — presumably a text box and an OCR result; if both are
            # empty, an empty string is classified. Confirm upstream widgets.
            text_input = text if text else ocr_text
            analysis = classify_text_with_categories(text_input)
            # Append this run to the session history (consumed by the radar
            # chart and word-level analysis sections below).
            st.session_state.history.append({
                "original": text_input,
                "translated": analysis["translations"],
                "overall": analysis["overall"],
                "categories": analysis["categories"],
                "word_analysis": analysis["word_analysis"]
            })
            # Show the headline prediction (label + confidence).
            st.markdown("**Main Prediction:**")
            st.metric("Label", analysis["overall"]["label"],
                      delta=f"{analysis['overall']['score']:.2%} Confidence")
            # Category distribution as a bar chart.
            st.markdown("### 📊 Offense Category Breakdown")
            category_data = [{"Category": k, "Score": v} for k, v in analysis["categories"].items()]
            fig = px.bar(category_data, x="Category", y="Score",
                         title="Category Contribution",
                         labels={"Score": "Probability"})
            st.plotly_chart(fig)
        except Exception as e:
            # Surface any failure (model, plotting, missing input) in the UI.
            st.error(f"❌ Error: {str(e)}")
# Radar chart: average offense-category score across the session history.
if st.session_state.history:
    # Collect every entry's per-category score into per-category lists.
    radar_data = {cat: [] for cat in OFFENSE_CATEGORIES}
    for entry in st.session_state.history:
        for cat, score in entry["categories"].items():
            radar_data[cat].append(score)
    # Average each category over the history (0 when never observed).
    avg_scores = {cat: sum(scores) / len(scores) if scores else 0
                  for cat, scores in radar_data.items()}
    # BUG FIX: the original passed r='index' and theta='OFFENSE_CATEGORIES'
    # to px.line_polar, but neither is a column of the DataFrame it built
    # (DataFrame(avg_scores, index=OFFENSE_CATEGORIES) has the category names
    # as COLUMNS), so plotly raised. Build an explicit two-column frame and
    # map the score to the radius and the category to the angle.
    radar_df = pd.DataFrame({
        "Category": list(avg_scores.keys()),
        "Score": list(avg_scores.values()),
    })
    fig = px.line_polar(
        radar_df, r="Score", theta="Category",
        line_close=True, title="📉 Offense Risk Radar Chart"
    )
    fig.update_traces(line_color='#FF4B4B')
    st.plotly_chart(fig)
# Word-level analysis: aggregate per-word scores across the session history.
if st.session_state.history:
    # Flatten every entry's word_analysis records into one list.
    all_words = []
    for entry in st.session_state.history:
        all_words.extend(entry["word_analysis"])
    st.markdown("### 🧩 Offensive Word Analysis")
    # BUG FIX: the original ran pd.DataFrame(all_words).groupby('word') before
    # checking for data; with an empty list the frame has no 'word' column and
    # groupby raises KeyError. Guard on the raw list first.
    if all_words:
        # Mean score per word plus its most frequent offense category,
        # sorted with the highest-scoring words first.
        word_stats = (pd.DataFrame(all_words)
                      .groupby('word')
                      .agg({'main_score': 'mean',
                            'offense_category': lambda x: x.mode()[0]})
                      .reset_index()
                      .sort_values('main_score', ascending=False))
        top_words = word_stats.head(10)
        fig = px.bar(top_words, x="word", y="main_score",
                     color="offense_category",
                     title="Top Offensive Words by Score")
        st.plotly_chart(fig)
    else:
        st.info("No offensive words detected")