aeresd committed
Commit 88f8bd0 · verified · 1 Parent(s): bd9feeb

Update app.py

Files changed (1):
  1. app.py +164 -147
app.py CHANGED
@@ -6,16 +6,7 @@ import pytesseract
 import pandas as pd
 import plotly.express as px
 
-# ✅ New offensive-category definitions
-OFFENSIVE_CATEGORIES = {
-    "Insult": ["蠢货", "白痴", "废物"],
-    "Abuse": ["去死", "打死", "宰了你"],
-    "Discrimination": ["女司机", "娘娘腔", "黑鬼"],
-    "HateSpeech": ["灭族", "屠杀", "灭绝"],
-    "Vulgarity": ["艹", "sb", "尼玛"]
-}
-
-# ✅ Model initialization (original structure unchanged)
 emoji_model_id = "JenniferHJF/qwen1.5-emoji-finetuned"
 emoji_tokenizer = AutoTokenizer.from_pretrained(emoji_model_id, trust_remote_code=True)
 emoji_model = AutoModelForCausalLM.from_pretrained(
@@ -25,34 +16,71 @@ emoji_model = AutoModelForCausalLM.from_pretrained(
 ).to("cuda" if torch.cuda.is_available() else "cpu")
 emoji_model.eval()
 
 model_options = {
     "Toxic-BERT": "unitary/toxic-bert",
     "Roberta Offensive": "cardiffnlp/twitter-roberta-base-offensive",
     "BERT Emotion": "bhadresh-savani/bert-base-go-emotion"
 }
 
-# ✅ Dynamic scoring algorithm
-def dynamic_scoring(text: str, classifier):
-    scores = {k: 0.0 for k in OFFENSIVE_CATEGORIES}
-
-    for category, keywords in OFFENSIVE_CATEGORIES.items():
-        for kw in keywords:
-            if kw in text:
-                scores[category] += 0.3
-
-    words = text.split()
-    for word in words:
-        try:
-            res = classifier(word)[0]
-            if res["label"] in scores:
-                scores[res["label"]] += res["score"] * 0.7
-        except: pass
-
-    max_score = max(scores.values()) or 1
-    return {k: round(v / max_score, 2) for k, v in scores.items()}
-
-# ✅ Refactored classification function
 def classify_emoji_text(text: str):
     prompt = f"输入:{text}\n输出:"
     input_ids = emoji_tokenizer(prompt, return_tensors="pt").to(emoji_model.device)
     with torch.no_grad():
@@ -60,67 +88,49 @@ def classify_emoji_text(text: str):
     decoded = emoji_tokenizer.decode(output_ids[0], skip_special_tokens=True)
     translated_text = decoded.split("输出:")[-1].strip() if "输出:" in decoded else decoded.strip()
 
-    result = classifier(translated_text)[0]
-    label = result["label"]
-    score = result["score"]
-    reasoning = f"The sentence was flagged as '{label}' due to potentially offensive phrases."
 
-    # New category-level analysis
-    category_scores = dynamic_scoring(translated_text, classifier)
 
     st.session_state.history.append({
         "text": text,
         "translated": translated_text,
-        "label": label,
-        "score": score,
-        "reason": reasoning,
-        "scores": category_scores
-    })
-    return translated_text, label, score, reasoning, category_scores
-
-# ✅ Radar-chart generator
-def generate_radar_chart(scores_dict: dict):
-    radar_df = pd.DataFrame({
-        "Category": list(scores_dict.keys()),
-        "Score": list(scores_dict.values())
     })
 
-    fig = px.line_polar(
-        radar_df,
-        r='Score',
-        theta='Category',
-        line_close=True,
-        color_discrete_sequence=['#FF6B6B'],
-        title="🛡️ Multi-Dimensional Offensive Analysis"
-    )
-    fig.update_layout(
-        polar=dict(
-            radialaxis=dict(
-                visible=True,
-                range=[0, 1],
-                tickvals=[0, 0.3, 0.7, 1],
-                ticktext=["Safe", "Caution", "Risk", "Danger"]
-            )),
-        showlegend=False
-    )
-    return fig
-
-# ✅ Page configuration (original structure unchanged)
-st.set_page_config(page_title="Emoji Offensive Text Detector", page_icon="🚨", layout="wide")
-
-with st.sidebar:
-    st.header("🧠 Configuration")
-    selected_model = st.selectbox("Choose classification model", list(model_options.keys()))
-    selected_model_id = model_options[selected_model]
-    classifier = pipeline("text-classification", model=selected_model_id, device=0 if torch.cuda.is_available() else -1)
-
-if "history" not in st.session_state:
-    st.session_state.history = []
 
-# Main page logic
 st.title("🚨 Emoji Offensive Text Detector & Analysis Dashboard")
 
-# Text input
 st.subheader("1. 输入与分类")
 default_text = "你是🐷"
 text = st.text_area("Enter sentence with emojis:", value=default_text, height=150)
@@ -128,92 +138,99 @@ text = st.text_area("Enter sentence with emojis:", value=default_text, height=150)
 if st.button("🚦 Analyze Text"):
     with st.spinner("🔍 Processing..."):
         try:
-            translated, label, score, reason, category_scores = classify_emoji_text(text)
-            # Show the basic result
             st.markdown("**Translated sentence:**")
             st.code(translated, language="text")
-            # Show the radar chart
-            st.plotly_chart(generate_radar_chart(category_scores))
-
-# Image upload & OCR
 st.markdown("---")
 st.subheader("2. 图片 OCR & 分类")
 uploaded_file = st.file_uploader("Upload an image (JPG/PNG)", type=["jpg","jpeg","png"])
 if uploaded_file:
     image = Image.open(uploaded_file)
     st.image(image, caption="Uploaded Screenshot", use_column_width=True)
     with st.spinner("🧠 Extracting text via OCR..."):
         ocr_text = pytesseract.image_to_string(image, lang="chi_sim+eng").strip()
     if ocr_text:
         st.markdown("**Extracted Text:**")
         st.code(ocr_text)
-        translated, label, score, reason = classify_emoji_text(ocr_text)
-        st.markdown("**Translated sentence:**")
-        st.code(translated, language="text")
-        st.markdown(f"**Prediction:** {label}")
-        st.markdown(f"**Confidence Score:** {score:.2%}")
-        st.markdown("**Model Explanation:**")
-        st.info(reason)
     else:
-        st.info("⚠️ No text detected in the image.")
 
-# Analysis dashboard
 st.markdown("---")
-st.subheader("3. Violation Analysis Dashboard")
 if st.session_state.history:
-    # Show history
-    df = pd.DataFrame(st.session_state.history)
-    st.markdown("### 🧾 Offensive Terms & Suggestions")
-    for item in st.session_state.history:
-        st.markdown(f"- 🔹 **Input:** {item['text']}")
-        st.markdown(f"  - ✨ **Translated:** {item['translated']}")
-        st.markdown(f"  - ❗ **Label:** {item['label']} with **{item['score']:.2%}** confidence")
-        st.markdown(f"  - 🔧 **Suggestion:** {item['reason']}")
-
     # Radar chart
     radar_df = pd.DataFrame({
-        "Category": ["Insult","Abuse","Discrimination","Hate Speech","Vulgarity"],
-        "Score": [0.7,0.4,0.3,0.5,0.6]
     })
-    radar_fig = px.line_polar(radar_df, r='Score', theta='Category', line_close=True, title="⚠️ Risk Radar by Category")
-    radar_fig.update_traces(line_color='black')
-    st.plotly_chart(radar_fig)
-
-    # —— New: word-level offensive correlation analysis —— #
-    st.markdown("### 🧬 Word-level Offensive Correlation")
-
-    # Take the most recent translated text and split it on whitespace
-    last_translated_text = st.session_state.history[-1]["translated"]
-    words = last_translated_text.split()
-
-    # Classify each word and collect its score
-    word_scores = []
-    for word in words:
-        try:
-            res = classifier(word)[0]
-            word_scores.append({
-                "Word": word,
-                "Label": res["label"],
-                "Score": res["score"]
-            })
-        except Exception:
-            continue
-
-    if word_scores:
-        word_df = pd.DataFrame(word_scores)
-        word_df = word_df.sort_values(by="Score", ascending=False).reset_index(drop=True)
-
-        max_display = 5
-        # st.toggle requires Streamlit 1.22+; use st.checkbox on older versions
-        show_more = st.toggle("Show more words", value=False)
-
-        display_df = word_df if show_more else word_df.head(max_display)
-        # Render as a borderless HTML table
-        st.markdown(
-            display_df.to_html(index=False, border=0),
-            unsafe_allow_html=True
-        )
     else:
-        st.info("No word-level analysis available.")
 else:
-    st.info("⚠️ No classification data available yet.")
 import pandas as pd
 import plotly.express as px
 
+# ✅ Step 1: Emoji translation model
 emoji_model_id = "JenniferHJF/qwen1.5-emoji-finetuned"
 emoji_tokenizer = AutoTokenizer.from_pretrained(emoji_model_id, trust_remote_code=True)
 emoji_model = AutoModelForCausalLM.from_pretrained(
 
 ).to("cuda" if torch.cuda.is_available() else "cpu")
 emoji_model.eval()
 
+# ✅ Step 2: Classification model configuration
 model_options = {
     "Toxic-BERT": "unitary/toxic-bert",
     "Roberta Offensive": "cardiffnlp/twitter-roberta-base-offensive",
     "BERT Emotion": "bhadresh-savani/bert-base-go-emotion"
 }
 
+# Radar-chart category mapping system
+category_system = {
+    "Insult": ["侮辱", "贬低", "人身攻击"],
+    "Abuse": ["威胁", "暴力", "骚扰"],
+    "Discrimination": ["种族", "性别", "宗教"],
+    "Hate Speech": ["仇恨", "极端言论"],
+    "Vulgarity": ["脏话", "低俗", "性暗示"]
+}
+
+# Mapping from native model labels to the category system
+model_category_map = {
+    "Toxic-BERT": {
+        "toxic": ["Vulgarity"],
+        "severe_toxic": ["Abuse"],
+        "obscene": ["Vulgarity"],
+        "threat": ["Abuse", "Hate Speech"],
+        "insult": ["Insult"],
+        "identity_hate": ["Discrimination", "Hate Speech"]
+    },
+    "Roberta Offensive": {
+        "offensive": ["Insult", "Abuse"]
+    },
+    "BERT Emotion": {
+        "anger": ["Abuse"],
+        "disgust": ["Vulgarity"]
+    }
+}
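+# Several native labels can feed the same radar axis (e.g. "threat" counts
+# toward both Abuse and Hate Speech); classify_emoji_text below keeps the
+# per-axis maximum rather than a sum, so every axis stays within [0, 1].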
+
+# ✅ Page configuration
+st.set_page_config(page_title="Emoji Offensive Text Detector", page_icon="🚨", layout="wide")
+
+# ✅ Sidebar configuration
+with st.sidebar:
+    st.header("🧠 Configuration")
+    selected_model = st.selectbox("Choose classification model", list(model_options.keys()))
+    selected_model_id = model_options[selected_model]
 
+    # Dynamically adjust classifier parameters
+    classifier_config = {
+        "device": 0 if torch.cuda.is_available() else -1,
+        "top_k": None if selected_model == "Toxic-BERT" else 1
+    }
+    if selected_model == "Toxic-BERT":
+        classifier_config["function_to_apply"] = "sigmoid"
+
+    classifier = pipeline(
+        "text-classification",
+        model=selected_model_id,
+        **classifier_config
+    )
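+    # Note: top_k=None makes the pipeline above return a score for every label
+    # rather than only the best one, and function_to_apply="sigmoid" matches
+    # Toxic-BERT's multi-label head: per-label scores are independent
+    # probabilities, not a softmax distribution.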
 
+# Initialize history
+if "history" not in st.session_state:
+    st.session_state.history = []
+
+# ✅ Core classification function
 def classify_emoji_text(text: str):
+    # Emoji translation
     prompt = f"输入:{text}\n输出:"
     input_ids = emoji_tokenizer(prompt, return_tensors="pt").to(emoji_model.device)
     with torch.no_grad():
 
     decoded = emoji_tokenizer.decode(output_ids[0], skip_special_tokens=True)
     translated_text = decoded.split("输出:")[-1].strip() if "输出:" in decoded else decoded.strip()
 
+    # Overall classification
+    main_result = classifier(translated_text)[0]
 
+    # Element-level analysis
+    elements = translated_text.split()
+    element_analysis = []
+    radar_scores = {category: 0.0 for category in category_system}
 
+    for elem in elements:
+        try:
+            results = classifier(elem)
+            for res in results:
+                for model_label in model_category_map.get(selected_model, {}):
+                    if res["label"] == model_label:
+                        score = res["score"]
+                        for category in model_category_map[selected_model][model_label]:
+                            if score > radar_scores[category]:
+                                radar_scores[category] = score
+                            element_analysis.append({
+                                "Element": elem,
+                                "Original": text.split()[elements.index(elem)] if len(text.split()) > elements.index(elem) else "",
+                                "Category": category,
+                                "Score": score
+                            })
+        except Exception:
+            continue
+
+    # Record history
     st.session_state.history.append({
         "text": text,
         "translated": translated_text,
+        "label": main_result["label"],
+        "score": main_result["score"],
+        "elements": element_analysis,
+        "radar": radar_scores
     })
 
+    return translated_text, main_result["label"], main_result["score"], radar_scores
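+# Illustrative sketch only: with the models above, a call such as
+#   classify_emoji_text("你是🐷")
+# is expected to return something shaped like
+#   ("你是猪", "toxic", 0.97, {"Insult": 0.91, "Abuse": 0.0, ...})
+# where the actual translation, label, and scores depend on the loaded models.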
 
+# ✅ Main interface
 st.title("🚨 Emoji Offensive Text Detector & Analysis Dashboard")
 
+# Text input module
 st.subheader("1. 输入与分类")
 default_text = "你是🐷"
 text = st.text_area("Enter sentence with emojis:", value=default_text, height=150)
 
 if st.button("🚦 Analyze Text"):
     with st.spinner("🔍 Processing..."):
         try:
+            translated, label, score, radar = classify_emoji_text(text)
+
             st.markdown("**Translated sentence:**")
             st.code(translated, language="text")
+
+            col1, col2 = st.columns(2)
+            with col1:
+                st.metric("Prediction", f"{label} 🔴" if score > 0.5 else f"{label} 🟢")
+            with col2:
+                st.metric("Confidence", f"{score:.2%}")
+
+            st.markdown("**Model Explanation:**")
+            st.info(f"文本被识别为「{label}」,建议检查以下内容:")
+            # cat_score, not `score`, so the overall confidence above is not overwritten
+            for cat, cat_score in radar.items():
+                if cat_score > 0.5:
+                    st.markdown(f"- ❗ **{cat}** 风险 ({cat_score:.2%})")
+        except Exception as e:
+            st.error(f"❌ Error: {e}")
+
+# Image analysis module
 st.markdown("---")
 st.subheader("2. 图片 OCR & 分类")
 uploaded_file = st.file_uploader("Upload an image (JPG/PNG)", type=["jpg","jpeg","png"])
 if uploaded_file:
     image = Image.open(uploaded_file)
     st.image(image, caption="Uploaded Screenshot", use_column_width=True)
+
     with st.spinner("🧠 Extracting text via OCR..."):
         ocr_text = pytesseract.image_to_string(image, lang="chi_sim+eng").strip()
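+        # lang="chi_sim+eng" assumes the Simplified Chinese and English
+        # Tesseract language packs are installed next to the tesseract binary.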
+
     if ocr_text:
         st.markdown("**Extracted Text:**")
         st.code(ocr_text)
+
+        try:
+            translated, label, score, radar = classify_emoji_text(ocr_text)
+            st.markdown(f"**Prediction:** {label} ({score:.2%})")
+        except Exception as e:
+            st.error(f"OCR分析错误: {e}")
     else:
+        st.info("⚠️ 未检测到文字内容")
 
+# Analysis dashboard
 st.markdown("---")
+st.subheader("3. 风险分析仪表盘")
 if st.session_state.history:
+    latest = st.session_state.history[-1]
+
     # Radar chart
+    st.markdown("### ⚠️ 风险雷达图")
     radar_df = pd.DataFrame({
+        "Category": list(latest["radar"].keys()),
+        "Score": list(latest["radar"].values())
     })
+    fig = px.line_polar(
+        radar_df,
+        r="Score",
+        theta="Category",
+        line_close=True,
+        range_r=[0, 1],
+        template="plotly_dark"
+    )
+    fig.update_traces(fill="toself", line_color="red")
+    st.plotly_chart(fig, use_container_width=True)
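+    # range_r pins the radial axis to [0, 1], so radar shapes stay comparable
+    # across runs instead of rescaling to the current maximum score.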
+
+    # Element contribution analysis
+    st.markdown("### 🧩 风险元素分解表")
+    if latest["elements"]:
+        element_df = pd.DataFrame(latest["elements"])
+        element_df = element_df.sort_values(by=["Score", "Category"], ascending=False)
+
+        # Grouped display
+        for category in category_system:
+            cat_df = element_df[element_df["Category"] == category]
+            if not cat_df.empty:
+                with st.expander(f"{category} 风险元素 ({len(cat_df)}项)"):
+                    st.dataframe(
+                        cat_df[["Element", "Original", "Score"]]
+                        .style.highlight_between(subset="Score", color="#ffcccc"),
+                        use_container_width=True,
+                        hide_index=True
+                    )
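+                    # highlight_between has no left/right bounds here, so every
+                    # Score cell is tinted; passing e.g. left=0.5 would highlight
+                    # only high-risk rows.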
     else:
+        st.info("未检测到高风险元素")
+
+    # History
+    st.markdown("### 📜 分析历史")
+    history_df = pd.DataFrame(st.session_state.history)
+    st.dataframe(
+        history_df[["text", "label", "score"]]
+        .style.applymap(lambda x: "color: red" if x == "OFFENSIVE" else ""),
+        use_container_width=True,
+        hide_index=True
+    )
 else:
+    st.info("🕑 等待首次分析结果...")