aeresd committed on
Commit dc1bdc8 · verified · 1 Parent(s): a77ff54

Update app.py

Files changed (1)
  1. app.py +26 -80
app.py CHANGED
@@ -29,52 +29,31 @@ st.set_page_config(page_title="Emoji Offensive Text Detector", page_icon="🚨",
 # ✅ Sidebar: choose model
 with st.sidebar:
     st.header("🧠 Settings")
-    selected_model = st.selectbox(
-        "Choose classification model", list(model_options.keys())
-    )
+    moderation_type = st.selectbox("Select Task Type", ["Normal Text", "Bullet Screen Text"])
+    selected_model = st.selectbox("Choose classification model", list(model_options.keys()))
     selected_model_id = model_options[selected_model]
-    classifier = pipeline(
-        "text-classification",
-        model=selected_model_id,
-        device=0 if torch.cuda.is_available() else -1
-    )
+    classifier = pipeline("text-classification", model=selected_model_id, device=0 if torch.cuda.is_available() else -1)
 
 # Initialize history
 if "history" not in st.session_state:
     st.session_state.history = []
 
 # Core function: translate and classify
+
 def classify_emoji_text(text: str):
     prompt = f"输入:{text}\n输出:"
     input_ids = emoji_tokenizer(prompt, return_tensors="pt").to(emoji_model.device)
     with torch.no_grad():
-        output_ids = emoji_model.generate(
-            **input_ids, max_new_tokens=64, do_sample=False
-        )
-        decoded = emoji_tokenizer.decode(
-            output_ids[0], skip_special_tokens=True
-        )
-        translated_text = (
-            decoded.split("输出:")[-1].strip()
-            if "输出:" in decoded
-            else decoded.strip()
-        )
+        output_ids = emoji_model.generate(**input_ids, max_new_tokens=64, do_sample=False)
+        decoded = emoji_tokenizer.decode(output_ids[0], skip_special_tokens=True)
+        translated_text = decoded.split("输出:")[-1].strip() if "输出:" in decoded else decoded.strip()
 
     result = classifier(translated_text)[0]
     label = result["label"]
     score = result["score"]
-    reasoning = (
-        f"The sentence was flagged as '{label}' due to potentially offensive phrases. "
-        "Consider replacing emotionally charged, ambiguous, or abusive terms."
-    )
-
-    st.session_state.history.append({
-        "text": text,
-        "translated": translated_text,
-        "label": label,
-        "score": score,
-        "reason": reasoning
-    })
+    reasoning = f"The sentence was flagged as '{label}' due to potentially offensive phrases. Consider replacing emotionally charged, ambiguous, or abusive terms."
+
+    st.session_state.history.append({"text": text, "translated": translated_text, "label": label, "score": score, "reason": reasoning})
     return translated_text, label, score, reasoning
 
 # Page body
@@ -85,9 +64,7 @@ st.markdown("### ✍️ Input your sentence or upload screenshot:")
 col1, col2 = st.columns(2)
 with col1:
     default_text = "你是🐷"
-    text = st.text_area(
-        "Enter sentence with emojis:", value=default_text, height=150
-    )
+    text = st.text_area("Enter sentence with emojis:", value=default_text, height=150)
     if st.button("🚦 Analyze Text"):
         with st.spinner("🔍 Processing..."):
             try:
@@ -97,23 +74,19 @@ with col1:
 
                 st.markdown(f"#### 🎯 Prediction: {label}")
                 st.markdown(f"#### 📊 Confidence Score: {score:.2%}")
-                st.markdown("#### 🧠 Model Explanation:")
+                st.markdown(f"#### 🧠 Model Explanation:")
                 st.info(reason)
             except Exception as e:
                 st.error(f"❌ An error occurred during processing:\n\n{e}")
 
 with col2:
-    uploaded_file = st.file_uploader(
-        "Upload an image (JPG/PNG)", type=["jpg", "jpeg", "png"]
-    )
+    uploaded_file = st.file_uploader("Upload an image (JPG/PNG)", type=["jpg", "jpeg", "png"])
     if uploaded_file is not None:
         image = Image.open(uploaded_file)
         st.image(image, caption="Uploaded Screenshot", use_column_width=True)
         if st.button("🛠️ OCR & Analyze Image"):
             with st.spinner("🧠 Extracting text via OCR..."):
-                ocr_text = pytesseract.image_to_string(
-                    image, lang="chi_sim+eng"
-                ).strip()
+                ocr_text = pytesseract.image_to_string(image, lang="chi_sim+eng").strip()
                 st.markdown("#### 📋 Extracted Text:")
                 st.code(ocr_text)
                 classify_emoji_text(ocr_text)
@@ -122,53 +95,26 @@ with col2:
 st.markdown("---")
 st.title("📊 Violation Analysis Dashboard")
 if st.session_state.history:
+    df = pd.DataFrame(st.session_state.history)
+    # Pie chart
+    label_counts = df["label"].value_counts().reset_index()
+    label_counts.columns = ["Category", "Count"]
+    fig = px.pie(label_counts, names="Category", values="Count", title="Offensive Category Distribution")
+    st.plotly_chart(fig)
+
     st.markdown("### 🧾 Offensive Terms & Suggestions")
     for item in st.session_state.history:
         st.markdown(f"- 🔹 **Input:** {item['text']}")
         st.markdown(f" - ✨ **Translated:** {item['translated']}")
-        st.markdown(
-            f" - **Label:** {item['label']} with **{item['score']:.2%}** confidence"
-        )
-        st.markdown(f" - 🔧 **Suggestion:** {item['reason']} ")
+        st.markdown(f" - ❗ **Label:** {item['label']} with **{item['score']:.2%}** confidence")
+        st.markdown(f" - 🔧 **Suggestion:** {item['reason']}")
 
-    # Radar chart
+    # Radar chart demo example (can be replaced with dynamic data)
     radar_df = pd.DataFrame({
         "Category": ["Insult", "Abuse", "Discrimination", "Hate Speech", "Vulgarity"],
         "Score": [0.7, 0.4, 0.3, 0.5, 0.6]
     })
-    radar_fig = px.line_polar(
-        radar_df,
-        r='Score',
-        theta='Category',
-        line_close=True,
-        title="⚠️ Risk Radar by Category",
-        color_discrete_sequence=['black'],
-        template='simple_white'
-    )
-    radar_fig.update_layout(
-        polar=dict(
-            gridshape='circular',
-            bgcolor='white',
-            radialaxis=dict(
-                showticklabels=False,
-                ticks='',
-                showgrid=True,
-                gridcolor='lightgrey',
-                gridwidth=1,
-                linecolor='black',
-                linewidth=2
-            ),
-            angularaxis=dict(
-                showticklabels=False,
-                ticks='',
-                showline=True,
-                linecolor='black',
-                linewidth=2
-            )
-        ),
-        paper_bgcolor='white',
-        plot_bgcolor='white'
-    )
+    radar_fig = px.line_polar(radar_df, r='Score', theta='Category', line_close=True, title="⚠️ Risk Radar by Category")
     st.plotly_chart(radar_fig)
 else:
-    st.info("⚠️ No classification data available yet.")
+    st.info("⚠️ No classification data available yet.")
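
For reference, a minimal sketch of the pie-chart aggregation this commit adds to the dashboard, runnable outside Streamlit. The two history entries below are hypothetical examples shaped like the dictionaries classify_emoji_text appends to st.session_state.history; their texts, labels, and scores are made up for illustration.

import pandas as pd
import plotly.express as px

# Hypothetical history entries (same keys as those written by classify_emoji_text)
history = [
    {"text": "你是🐷", "translated": "你是猪", "label": "offensive", "score": 0.97,
     "reason": "The sentence was flagged as 'offensive' due to potentially offensive phrases."},
    {"text": "hello there 🙂", "translated": "hello there", "label": "non-offensive", "score": 0.91,
     "reason": "No offensive phrases detected."},
]

df = pd.DataFrame(history)

# Same aggregation as the added dashboard code: one pie slice per predicted label
label_counts = df["label"].value_counts().reset_index()
label_counts.columns = ["Category", "Count"]

fig = px.pie(label_counts, names="Category", values="Count",
             title="Offensive Category Distribution")
fig.show()  # the app renders this with st.plotly_chart(fig) instead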