aeresd committed
Commit 9284803 · verified · 1 Parent(s): 88f8bd0

Update app.py

Files changed (1): app.py +85 -153
app.py CHANGED
@@ -6,7 +6,7 @@ import pytesseract
 import pandas as pd
 import plotly.express as px
 
-# ✅ Step 1: Emoji translation model
+# ✅ Step 1: Emoji translation model (your own fine-tuned model)
 emoji_model_id = "JenniferHJF/qwen1.5-emoji-finetuned"
 emoji_tokenizer = AutoTokenizer.from_pretrained(emoji_model_id, trust_remote_code=True)
 emoji_model = AutoModelForCausalLM.from_pretrained(
@@ -16,71 +16,29 @@ emoji_model = AutoModelForCausalLM.from_pretrained(
 ).to("cuda" if torch.cuda.is_available() else "cpu")
 emoji_model.eval()
 
-# ✅ Step 2: Classification model configuration
+# ✅ Step 2: Selectable offensive-text detection models
 model_options = {
     "Toxic-BERT": "unitary/toxic-bert",
     "Roberta Offensive": "cardiffnlp/twitter-roberta-base-offensive",
     "BERT Emotion": "bhadresh-savani/bert-base-go-emotion"
 }
 
-# Radar-chart category mapping
-category_system = {
-    "Insult": ["侮辱", "贬低", "人身攻击"],
-    "Abuse": ["威胁", "暴力", "骚扰"],
-    "Discrimination": ["种族", "性别", "宗教"],
-    "Hate Speech": ["仇恨", "极端言论"],
-    "Vulgarity": ["脏话", "低俗", "性暗示"]
-}
-
-# Mapping from model labels to the category system
-model_category_map = {
-    "Toxic-BERT": {
-        "toxic": ["Vulgarity"],
-        "severe_toxic": ["Abuse"],
-        "obscene": ["Vulgarity"],
-        "threat": ["Abuse", "Hate Speech"],
-        "insult": ["Insult"],
-        "identity_hate": ["Discrimination", "Hate Speech"]
-    },
-    "Roberta Offensive": {
-        "offensive": ["Insult", "Abuse"]
-    },
-    "BERT Emotion": {
-        "anger": ["Abuse"],
-        "disgust": ["Vulgarity"]
-    }
-}
-
 # ✅ Page configuration
 st.set_page_config(page_title="Emoji Offensive Text Detector", page_icon="🚨", layout="wide")
 
-# ✅ Sidebar configuration
+# ✅ Sidebar: model selection
 with st.sidebar:
     st.header("🧠 Configuration")
     selected_model = st.selectbox("Choose classification model", list(model_options.keys()))
     selected_model_id = model_options[selected_model]
-
-    # Adjust classifier parameters per model
-    classifier_config = {
-        "device": 0 if torch.cuda.is_available() else -1,
-        "top_k": None if selected_model == "Toxic-BERT" else 1
-    }
-    if selected_model == "Toxic-BERT":
-        classifier_config["function_to_apply"] = "sigmoid"
-
-    classifier = pipeline(
-        "text-classification",
-        model=selected_model_id,
-        **classifier_config
-    )
+    classifier = pipeline("text-classification", model=selected_model_id, device=0 if torch.cuda.is_available() else -1)
 
 # Initialize history
 if "history" not in st.session_state:
     st.session_state.history = []
 
-# ✅ Core classification function
+# Classification function
 def classify_emoji_text(text: str):
-    # Emoji translation
     prompt = f"输入:{text}\n输出:"
     input_ids = emoji_tokenizer(prompt, return_tensors="pt").to(emoji_model.device)
     with torch.no_grad():
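Note on the classifier configuration removed in this hunk: unitary/toxic-bert is a multi-label model, so the old sidebar requested every label (`top_k=None`) with sigmoid activation and fanned the per-label scores into `model_category_map`, while the simplified one-line pipeline keeps only the top label, which is all the new `classify_emoji_text()` consumes. A minimal sketch of the difference in output shape, reusing the model ID already in app.py (scores are illustrative):

```python
from transformers import pipeline

# Removed behavior: an independent sigmoid score for every label
multi = pipeline("text-classification", model="unitary/toxic-bert",
                 top_k=None, function_to_apply="sigmoid")
# New behavior: only the single best label
single = pipeline("text-classification", model="unitary/toxic-bert")

print(multi("some insulting sentence"))
# [[{'label': 'toxic', 'score': 0.98}, {'label': 'insult', 'score': 0.91}, ...]]  (illustrative)
print(single("some insulting sentence"))
# [{'label': 'toxic', 'score': 0.98}]  (illustrative)
```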
@@ -88,49 +46,27 @@ def classify_emoji_text(text: str):
     decoded = emoji_tokenizer.decode(output_ids[0], skip_special_tokens=True)
     translated_text = decoded.split("输出:")[-1].strip() if "输出:" in decoded else decoded.strip()
 
-    # Sentence-level classification
-    main_result = classifier(translated_text)[0]
-
-    # Element-level analysis
-    elements = translated_text.split()
-    element_analysis = []
-    radar_scores = {category: 0.0 for category in category_system}
-
-    for elem in elements:
-        try:
-            results = classifier(elem)
-            for res in results:
-                for model_label in model_category_map.get(selected_model, {}):
-                    if res["label"] == model_label:
-                        score = res["score"]
-                        for category in model_category_map[selected_model][model_label]:
-                            if score > radar_scores[category]:
-                                radar_scores[category] = score
-                            element_analysis.append({
-                                "Element": elem,
-                                "Original": text.split()[elements.index(elem)] if len(text.split()) > elements.index(elem) else "",
-                                "Category": category,
-                                "Score": score
-                            })
-        except Exception as e:
-            continue
+    result = classifier(translated_text)[0]
+    label = result["label"]
+    score = result["score"]
+    reasoning = (
+        f"The sentence was flagged as '{label}' due to potentially offensive phrases. "
+        "Consider replacing emotionally charged, ambiguous, or abusive terms."
+    )
 
-    # Record history
     st.session_state.history.append({
         "text": text,
         "translated": translated_text,
-        "label": main_result["label"],
-        "score": main_result["score"],
-        "elements": element_analysis,
-        "radar": radar_scores
+        "label": label,
+        "score": score,
+        "reason": reasoning
     })
-
-    return translated_text, main_result["label"], main_result["score"], radar_scores
+    return translated_text, label, score, reasoning
 
-# ✅ Main interface
+# Main page: input and analysis together
 st.title("🚨 Emoji Offensive Text Detector & Analysis Dashboard")
 
-# Text input module
+# Text input
 st.subheader("1. Input & Classification")
 default_text = "你是🐷"
 text = st.text_area("Enter sentence with emojis:", value=default_text, height=150)
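The hunk above hinges on the prompt format the fine-tuned model was trained on: the decoded generation echoes the prompt, so everything after the last `输出:` ("output:") marker is taken as the emoji-free translation. A self-contained sketch of that parsing step, with an illustrative decoded string:

```python
# Illustrative decoded string; in app.py it comes from emoji_tokenizer.decode(...)
decoded = "输入:你是🐷\n输出:你是猪"
translated = decoded.split("输出:")[-1].strip() if "输出:" in decoded else decoded.strip()
assert translated == "你是猪"  # the 🐷 emoji has been resolved into plain text
```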
@@ -138,99 +74,95 @@ text = st.text_area("Enter sentence with emojis:", value=default_text, height=150)
 if st.button("🚦 Analyze Text"):
     with st.spinner("🔍 Processing..."):
         try:
-            translated, label, score, radar = classify_emoji_text(text)
-
+            translated, label, score, reason = classify_emoji_text(text)
             st.markdown("**Translated sentence:**")
             st.code(translated, language="text")
-
-            col1, col2 = st.columns(2)
-            with col1:
-                st.metric("Prediction", f"{label} 🔴" if score > 0.5 else f"{label} 🟢")
-            with col2:
-                st.metric("Confidence", f"{score:.2%}")
-
+            st.markdown(f"**Prediction:** {label}")
+            st.markdown(f"**Confidence Score:** {score:.2%}")
             st.markdown("**Model Explanation:**")
-            st.info(f"The text was classified as '{label}'; review the following:")
-            for cat, score in radar.items():
-                if score > 0.5:
-                    st.markdown(f"- ❗ **{cat}** risk ({score:.2%})")
+            st.info(reason)
         except Exception as e:
-            st.error(f"❌ Error: {e}")
+            st.error(f"❌ An error occurred:\n{e}")
 
-# Image analysis module
+# Image upload & OCR
 st.markdown("---")
 st.subheader("2. Image OCR & Classification")
 uploaded_file = st.file_uploader("Upload an image (JPG/PNG)", type=["jpg","jpeg","png"])
 if uploaded_file:
     image = Image.open(uploaded_file)
     st.image(image, caption="Uploaded Screenshot", use_column_width=True)
-
     with st.spinner("🧠 Extracting text via OCR..."):
         ocr_text = pytesseract.image_to_string(image, lang="chi_sim+eng").strip()
-
     if ocr_text:
         st.markdown("**Extracted Text:**")
         st.code(ocr_text)
-
-        try:
-            translated, label, score, radar = classify_emoji_text(ocr_text)
-            st.markdown(f"**Prediction:** {label} ({score:.2%})")
-        except Exception as e:
-            st.error(f"OCR analysis error: {e}")
+        translated, label, score, reason = classify_emoji_text(ocr_text)
+        st.markdown("**Translated sentence:**")
+        st.code(translated, language="text")
+        st.markdown(f"**Prediction:** {label}")
+        st.markdown(f"**Confidence Score:** {score:.2%}")
+        st.markdown("**Model Explanation:**")
+        st.info(reason)
     else:
-        st.info("⚠️ No text detected")
+        st.info("⚠️ No text detected in the image.")
 
-# Data-analysis dashboard
+# Analysis dashboard
 st.markdown("---")
-st.subheader("3. Risk Analysis Dashboard")
+st.subheader("3. Violation Analysis Dashboard")
 if st.session_state.history:
-    latest = st.session_state.history[-1]
-
+    # Show the analysis history
+    df = pd.DataFrame(st.session_state.history)
+    st.markdown("### 🧾 Offensive Terms & Suggestions")
+    for item in st.session_state.history:
+        st.markdown(f"- 🔹 **Input:** {item['text']}")
+        st.markdown(f"  - ✨ **Translated:** {item['translated']}")
+        st.markdown(f"  - ❗ **Label:** {item['label']} with **{item['score']:.2%}** confidence")
+        st.markdown(f"  - 🔧 **Suggestion:** {item['reason']}")
+
     # Radar chart
-    st.markdown("### ⚠️ Risk Radar Chart")
     radar_df = pd.DataFrame({
-        "Category": latest["radar"].keys(),
-        "Score": latest["radar"].values()
+        "Category": ["Insult","Abuse","Discrimination","Hate Speech","Vulgarity"],
+        "Score": [0.7,0.4,0.3,0.5,0.6]
     })
-    fig = px.line_polar(
-        radar_df,
-        r="Score",
-        theta="Category",
-        line_close=True,
-        range_r=[0,1],
-        template="plotly_dark"
-    )
-    fig.update_traces(fill="toself", line_color="red")
-    st.plotly_chart(fig, use_container_width=True)
-
-    # Element contribution analysis
-    st.markdown("### 🧩 Risk Element Breakdown")
-    if latest["elements"]:
-        element_df = pd.DataFrame(latest["elements"])
-        element_df = element_df.sort_values(by=["Score", "Category"], ascending=False)
-
-        # Grouped display
-        for category in category_system:
-            cat_df = element_df[element_df["Category"] == category]
-            if not cat_df.empty:
-                with st.expander(f"{category} risk elements ({len(cat_df)} items)"):
-                    st.dataframe(
-                        cat_df[["Element", "Original", "Score"]]
-                        .style.highlight_between(subset="Score", color="#ffcccc"),
-                        use_container_width=True,
-                        hide_index=True
-                    )
+    radar_fig = px.line_polar(radar_df, r='Score', theta='Category', line_close=True, title="⚠️ Risk Radar by Category")
+    radar_fig.update_traces(line_color='black')
+    st.plotly_chart(radar_fig)
+
+    # -- New: word-level offensiveness correlation analysis -- #
+    st.markdown("### 🧬 Word-level Offensive Correlation")
+
+    # Take the most recent translated text and split it into words on whitespace
+    last_translated_text = st.session_state.history[-1]["translated"]
+    words = last_translated_text.split()
+
+    # Classify each word and collect its score
+    word_scores = []
+    for word in words:
+        try:
+            res = classifier(word)[0]
+            word_scores.append({
+                "Word": word,
+                "Label": res["label"],
+                "Score": res["score"]
+            })
+        except Exception:
+            continue
+
+    if word_scores:
+        word_df = pd.DataFrame(word_scores)
+        word_df = word_df.sort_values(by="Score", ascending=False).reset_index(drop=True)
+
+        max_display = 5
+        # st.toggle requires Streamlit 1.22+; use st.checkbox on older versions
+        show_more = st.toggle("Show more words", value=False)
+
+        display_df = word_df if show_more else word_df.head(max_display)
+        # Render as a borderless HTML table
+        st.markdown(
+            display_df.to_html(index=False, border=0),
+            unsafe_allow_html=True
+        )
     else:
-        st.info("No high-risk elements detected")
-
-    # History
-    st.markdown("### 📜 Analysis History")
-    history_df = pd.DataFrame(st.session_state.history)
-    st.dataframe(
-        history_df[["text", "label", "score"]]
-        .style.applymap(lambda x: "color: red" if x == "OFFENSIVE" else ""),
-        use_container_width=True,
-        hide_index=True
-    )
+        st.info("No word-level analysis available.")
 else:
-    st.info("🕑 Waiting for the first analysis...")
+    st.info("⚠️ No classification data available yet.")
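One consequence of this hunk: the radar chart now plots fixed placeholder scores (0.7, 0.4, ...) rather than values derived from classifier output, as the removed `category_system`/`model_category_map` version did. If live values are wanted again, one option is to fold the stored history back into the five categories. A sketch under that assumption, intended to run inside app.py where `st` and `pd` are already imported; `label_to_category` is a hypothetical mapping modeled on the removed `model_category_map`:

```python
# Hypothetical label -> category mapping (not part of this commit)
label_to_category = {
    "toxic": "Vulgarity",
    "obscene": "Vulgarity",
    "insult": "Insult",
    "threat": "Abuse",
    "identity_hate": "Discrimination",
    "offensive": "Insult",
}
categories = ["Insult", "Abuse", "Discrimination", "Hate Speech", "Vulgarity"]
scores = {c: 0.0 for c in categories}
for item in st.session_state.history:
    cat = label_to_category.get(item["label"].lower())
    if cat:
        scores[cat] = max(scores[cat], item["score"])  # keep the worst score seen
radar_df = pd.DataFrame({"Category": categories, "Score": [scores[c] for c in categories]})
```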
 
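A caveat on the word-level loop added in the last hunk: `last_translated_text.split()` splits on whitespace, so Chinese output (which contains no spaces) arrives as a single "word" and the table degenerates to one row. A character-level fallback is a cheap fix; a sketch, assuming no extra segmentation dependency is wanted:

```python
def split_elements(text: str) -> list[str]:
    """Whitespace tokens when present, otherwise individual characters."""
    tokens = text.split()
    if len(tokens) <= 1:  # likely unsegmented CJK text
        tokens = [ch for ch in text if not ch.isspace()]
    return tokens

# e.g. split_elements("你是猪") -> ["你", "是", "猪"]
# drop-in replacement for: words = last_translated_text.split()
```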