Update app.py
Browse files
app.py
CHANGED
@@ -31,14 +31,14 @@ with st.sidebar:
|
|
31 |
st.header("🧠 Settings")
|
32 |
selected_model = st.selectbox("Choose classification model", list(model_options.keys()))
|
33 |
selected_model_id = model_options[selected_model]
|
34 |
-
classifier = pipeline("text-classification", model=selected_model_id,
|
|
|
35 |
|
36 |
# 初始化历史记录
|
37 |
if "history" not in st.session_state:
|
38 |
st.session_state.history = []
|
39 |
|
40 |
# 核心函数: 翻译并分类
|
41 |
-
|
42 |
def classify_emoji_text(text: str):
|
43 |
prompt = f"输入:{text}\n输出:"
|
44 |
input_ids = emoji_tokenizer(prompt, return_tensors="pt").to(emoji_model.device)
|
@@ -50,9 +50,18 @@ def classify_emoji_text(text: str):
|
|
50 |
result = classifier(translated_text)[0]
|
51 |
label = result["label"]
|
52 |
score = result["score"]
|
53 |
-
reasoning =
|
|
|
|
|
|
|
54 |
|
55 |
-
st.session_state.history.append({
|
|
|
|
|
|
|
|
|
|
|
|
|
56 |
return translated_text, label, score, reasoning
|
57 |
|
58 |
# 页面主体
|
@@ -73,7 +82,7 @@ with col1:
|
|
73 |
|
74 |
st.markdown(f"#### 🎯 Prediction: {label}")
|
75 |
st.markdown(f"#### 📊 Confidence Score: {score:.2%}")
|
76 |
-
st.markdown(
|
77 |
st.info(reason)
|
78 |
except Exception as e:
|
79 |
st.error(f"❌ An error occurred during processing:\n\n{e}")
|
@@ -93,13 +102,14 @@ with col2:
|
|
93 |
# 分析仪表盘
|
94 |
st.markdown("---")
|
95 |
st.title("📊 Violation Analysis Dashboard")
|
|
|
96 |
if st.session_state.history:
|
97 |
st.markdown("### 🧾 Offensive Terms & Suggestions")
|
98 |
for item in st.session_state.history:
|
99 |
st.markdown(f"- 🔹 **Input:** {item['text']}")
|
100 |
st.markdown(f" - ✨ **Translated:** {item['translated']}")
|
101 |
st.markdown(f" - ❗ **Label:** {item['label']} with **{item['score']:.2%}** confidence")
|
102 |
-
st.markdown(f" - 🔧 **Suggestion:** {item['reason']}
|
103 |
|
104 |
# 雷达图演示示例(可替换为动态数据)
|
105 |
radar_df = pd.DataFrame({
|
|
|
31 |
st.header("🧠 Settings")
|
32 |
selected_model = st.selectbox("Choose classification model", list(model_options.keys()))
|
33 |
selected_model_id = model_options[selected_model]
|
34 |
+
classifier = pipeline("text-classification", model=selected_model_id,
|
35 |
+
device=0 if torch.cuda.is_available() else -1)
|
36 |
|
37 |
# 初始化历史记录
|
38 |
if "history" not in st.session_state:
|
39 |
st.session_state.history = []
|
40 |
|
41 |
# 核心函数: 翻译并分类
|
|
|
42 |
def classify_emoji_text(text: str):
|
43 |
prompt = f"输入:{text}\n输出:"
|
44 |
input_ids = emoji_tokenizer(prompt, return_tensors="pt").to(emoji_model.device)
|
|
|
50 |
result = classifier(translated_text)[0]
|
51 |
label = result["label"]
|
52 |
score = result["score"]
|
53 |
+
reasoning = (
|
54 |
+
f"The sentence was flagged as '{label}' due to potentially offensive phrases. "
|
55 |
+
"Consider replacing emotionally charged, ambiguous, or abusive terms."
|
56 |
+
)
|
57 |
|
58 |
+
st.session_state.history.append({
|
59 |
+
"text": text,
|
60 |
+
"translated": translated_text,
|
61 |
+
"label": label,
|
62 |
+
"score": score,
|
63 |
+
"reason": reasoning
|
64 |
+
})
|
65 |
return translated_text, label, score, reasoning
|
66 |
|
67 |
# 页面主体
|
|
|
82 |
|
83 |
st.markdown(f"#### 🎯 Prediction: {label}")
|
84 |
st.markdown(f"#### 📊 Confidence Score: {score:.2%}")
|
85 |
+
st.markdown("#### 🧠 Model Explanation:")
|
86 |
st.info(reason)
|
87 |
except Exception as e:
|
88 |
st.error(f"❌ An error occurred during processing:\n\n{e}")
|
|
|
102 |
# 分析仪表盘
|
103 |
st.markdown("---")
|
104 |
st.title("📊 Violation Analysis Dashboard")
|
105 |
+
|
106 |
if st.session_state.history:
|
107 |
st.markdown("### 🧾 Offensive Terms & Suggestions")
|
108 |
for item in st.session_state.history:
|
109 |
st.markdown(f"- 🔹 **Input:** {item['text']}")
|
110 |
st.markdown(f" - ✨ **Translated:** {item['translated']}")
|
111 |
st.markdown(f" - ❗ **Label:** {item['label']} with **{item['score']:.2%}** confidence")
|
112 |
+
st.markdown(f" - 🔧 **Suggestion:** {item['reason']}")
|
113 |
|
114 |
# 雷达图演示示例(可替换为动态数据)
|
115 |
radar_df = pd.DataFrame({
|