aeresd commited on
Commit
68c3cca
·
verified ·
1 Parent(s): dc1bdc8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +63 -58
app.py CHANGED
@@ -6,7 +6,7 @@ import pytesseract
6
  import pandas as pd
7
  import plotly.express as px
8
 
9
- # Step 1: Emoji 翻译模型(你自己训练的模型)
10
  emoji_model_id = "JenniferHJF/qwen1.5-emoji-finetuned"
11
  emoji_tokenizer = AutoTokenizer.from_pretrained(emoji_model_id, trust_remote_code=True)
12
  emoji_model = AutoModelForCausalLM.from_pretrained(
@@ -16,30 +16,34 @@ emoji_model = AutoModelForCausalLM.from_pretrained(
16
  ).to("cuda" if torch.cuda.is_available() else "cpu")
17
  emoji_model.eval()
18
 
19
- # Step 2: 可选择的冒犯性文本识别模型
20
  model_options = {
21
  "Toxic-BERT": "unitary/toxic-bert",
22
  "Roberta Offensive": "cardiffnlp/twitter-roberta-base-offensive",
23
  "BERT Emotion": "bhadresh-savani/bert-base-go-emotion"
24
  }
25
 
26
- # 页面配置
27
  st.set_page_config(page_title="Emoji Offensive Text Detector", page_icon="🚨", layout="wide")
28
 
29
- # ✅ 侧边栏: 选择模型
30
  with st.sidebar:
31
- st.header("🧠 Settings")
32
- moderation_type = st.selectbox("Select Task Type", ["Normal Text", "Bullet Screen Text"])
33
- selected_model = st.selectbox("Choose classification model", list(model_options.keys()))
34
- selected_model_id = model_options[selected_model]
35
- classifier = pipeline("text-classification", model=selected_model_id, device=0 if torch.cuda.is_available() else -1)
 
 
 
 
 
36
 
37
  # 初始化历史记录
38
  if "history" not in st.session_state:
39
  st.session_state.history = []
40
 
41
- # 核心函数: 翻译并分类
42
-
43
  def classify_emoji_text(text: str):
44
  prompt = f"输入:{text}\n输出:"
45
  input_ids = emoji_tokenizer(prompt, return_tensors="pt").to(emoji_model.device)
@@ -56,65 +60,66 @@ def classify_emoji_text(text: str):
56
  st.session_state.history.append({"text": text, "translated": translated_text, "label": label, "score": score, "reason": reasoning})
57
  return translated_text, label, score, reasoning
58
 
59
- # 页面主体
60
- st.title("🚨 Emoji Offensive Text Detector & Analysis")
61
-
62
- # 输入区域
63
- st.markdown("### ✍️ Input your sentence or upload screenshot:")
64
- col1, col2 = st.columns(2)
65
- with col1:
66
  default_text = "你是🐷"
67
  text = st.text_area("Enter sentence with emojis:", value=default_text, height=150)
68
- if st.button("🚦 Analyze Text"):
 
69
  with st.spinner("🔍 Processing..."):
70
  try:
71
  translated, label, score, reason = classify_emoji_text(text)
72
- st.markdown("#### 🔄 Translated sentence:")
73
  st.code(translated, language="text")
74
 
75
- st.markdown(f"#### 🎯 Prediction: {label}")
76
- st.markdown(f"#### 📊 Confidence Score: {score:.2%}")
77
- st.markdown(f"#### 🧠 Model Explanation:")
78
  st.info(reason)
 
79
  except Exception as e:
80
  st.error(f"❌ An error occurred during processing:\n\n{e}")
81
 
82
- with col2:
 
 
83
  uploaded_file = st.file_uploader("Upload an image (JPG/PNG)", type=["jpg", "jpeg", "png"])
 
84
  if uploaded_file is not None:
85
  image = Image.open(uploaded_file)
86
  st.image(image, caption="Uploaded Screenshot", use_column_width=True)
87
- if st.button("🛠️ OCR & Analyze Image"):
88
- with st.spinner("🧠 Extracting text via OCR..."):
89
- ocr_text = pytesseract.image_to_string(image, lang="chi_sim+eng").strip()
90
- st.markdown("#### 📋 Extracted Text:")
91
- st.code(ocr_text)
92
- classify_emoji_text(ocr_text)
93
-
94
- # 分析仪表盘
95
- st.markdown("---")
96
- st.title("📊 Violation Analysis Dashboard")
97
- if st.session_state.history:
98
- df = pd.DataFrame(st.session_state.history)
99
- # 饼图
100
- label_counts = df["label"].value_counts().reset_index()
101
- label_counts.columns = ["Category", "Count"]
102
- fig = px.pie(label_counts, names="Category", values="Count", title="Offensive Category Distribution")
103
- st.plotly_chart(fig)
104
-
105
- st.markdown("### 🧾 Offensive Terms & Suggestions")
106
- for item in st.session_state.history:
107
- st.markdown(f"- 🔹 **Input:** {item['text']}")
108
- st.markdown(f" - ✨ **Translated:** {item['translated']}")
109
- st.markdown(f" - ❗ **Label:** {item['label']} with **{item['score']:.2%}** confidence")
110
- st.markdown(f" - 🔧 **Suggestion:** {item['reason']}")
111
-
112
- # 雷达图演示示例(可替换为动态数据)
113
- radar_df = pd.DataFrame({
114
- "Category": ["Insult", "Abuse", "Discrimination", "Hate Speech", "Vulgarity"],
115
- "Score": [0.7, 0.4, 0.3, 0.5, 0.6]
116
- })
117
- radar_fig = px.line_polar(radar_df, r='Score', theta='Category', line_close=True, title="⚠️ Risk Radar by Category")
118
- st.plotly_chart(radar_fig)
119
- else:
120
- st.info("⚠️ No classification data available yet.")
 
6
  import pandas as pd
7
  import plotly.express as px
8
 
9
+ # Step 1: Emoji 翻译模型(你自己训练的模型)
10
  emoji_model_id = "JenniferHJF/qwen1.5-emoji-finetuned"
11
  emoji_tokenizer = AutoTokenizer.from_pretrained(emoji_model_id, trust_remote_code=True)
12
  emoji_model = AutoModelForCausalLM.from_pretrained(
 
16
  ).to("cuda" if torch.cuda.is_available() else "cpu")
17
  emoji_model.eval()
18
 
19
+ # Step 2: 可选择的冒犯性文本识别模型
20
  model_options = {
21
  "Toxic-BERT": "unitary/toxic-bert",
22
  "Roberta Offensive": "cardiffnlp/twitter-roberta-base-offensive",
23
  "BERT Emotion": "bhadresh-savani/bert-base-go-emotion"
24
  }
25
 
26
+ # 页面配置
27
  st.set_page_config(page_title="Emoji Offensive Text Detector", page_icon="🚨", layout="wide")
28
 
29
+ # 页面布局
30
  with st.sidebar:
31
+ st.header("🧠 Navigation")
32
+ section = st.radio("Select Mode:", ["📍 Text Moderation", "📊 Text Analysis"])
33
+
34
+ if section == "📍 Text Moderation":
35
+ selected_model = st.selectbox("Choose classification model", list(model_options.keys()))
36
+ selected_model_id = model_options[selected_model]
37
+ classifier = pipeline("text-classification", model=selected_model_id, device=0 if torch.cuda.is_available() else -1)
38
+
39
+ elif section == "📊 Text Analysis":
40
+ st.markdown("You can view editing suggestions based on past analyses.")
41
 
42
  # 初始化历史记录
43
  if "history" not in st.session_state:
44
  st.session_state.history = []
45
 
46
+ # Emoji 文本翻译与分类函数
 
47
  def classify_emoji_text(text: str):
48
  prompt = f"输入:{text}\n输出:"
49
  input_ids = emoji_tokenizer(prompt, return_tensors="pt").to(emoji_model.device)
 
60
  st.session_state.history.append({"text": text, "translated": translated_text, "label": label, "score": score, "reason": reasoning})
61
  return translated_text, label, score, reasoning
62
 
63
+ # 功能逻辑
64
+ if section == "📍 Text Moderation":
65
+ st.title("📍 Offensive Text Classification")
66
+ st.markdown("### ✍️ Input your sentence:")
 
 
 
67
  default_text = "你是🐷"
68
  text = st.text_area("Enter sentence with emojis:", value=default_text, height=150)
69
+
70
+ if st.button("🚦 Analyze"):
71
  with st.spinner("🔍 Processing..."):
72
  try:
73
  translated, label, score, reason = classify_emoji_text(text)
74
+ st.markdown("### 🔄 Translated sentence:")
75
  st.code(translated, language="text")
76
 
77
+ st.markdown(f"### 🎯 Prediction: `{label}`")
78
+ st.markdown(f"### 📊 Confidence Score: `{score:.2%}`")
79
+ st.markdown("### 🧠 Model Explanation:")
80
  st.info(reason)
81
+
82
  except Exception as e:
83
  st.error(f"❌ An error occurred during processing:\n\n{e}")
84
 
85
+ st.markdown("---")
86
+ st.markdown("### 🖼️ Or upload a screenshot of bullet comments:")
87
+
88
  uploaded_file = st.file_uploader("Upload an image (JPG/PNG)", type=["jpg", "jpeg", "png"])
89
+
90
  if uploaded_file is not None:
91
  image = Image.open(uploaded_file)
92
  st.image(image, caption="Uploaded Screenshot", use_column_width=True)
93
+
94
+ with st.spinner("🧠 Extracting text via OCR..."):
95
+ ocr_text = pytesseract.image_to_string(image, lang="chi_sim+eng")
96
+ st.markdown("#### 📋 Extracted Text:")
97
+ st.code(ocr_text.strip())
98
+
99
+ translated, label, score, reason = classify_emoji_text(ocr_text.strip())
100
+ st.markdown("### 🔄 Translated sentence:")
101
+ st.code(translated, language="text")
102
+
103
+ st.markdown(f"### 🎯 Prediction: `{label}`")
104
+ st.markdown(f"### 📊 Confidence Score: `{score:.2%}`")
105
+ st.markdown("### 🧠 Model Explanation:")
106
+ st.info(reason)
107
+
108
+ elif section == "📊 Text Analysis":
109
+ st.title("📊 Violation Analysis Dashboard")
110
+ if st.session_state.history:
111
+ st.markdown("### 🧾 Offensive Terms & Suggestions")
112
+ for item in st.session_state.history:
113
+ st.markdown(f"- 🔹 **Input:** `{item['text']}`")
114
+ st.markdown(f" - ✨ **Translated:** `{item['translated']}`")
115
+ st.markdown(f" - ❗ **Label:** `{item['label']}` with **{item['score']:.2%}** confidence")
116
+ st.markdown(f" - 🔧 **Suggestion:** {item['reason']}")
117
+
118
+ radar_df = pd.DataFrame({
119
+ "Category": ["Insult", "Abuse", "Discrimination", "Hate Speech", "Vulgarity"],
120
+ "Score": [0.7, 0.4, 0.3, 0.5, 0.6]
121
+ })
122
+ radar_fig = px.line_polar(radar_df, r='Score', theta='Category', line_close=True, title="⚠️ Risk Radar by Category")
123
+ st.plotly_chart(radar_fig)
124
+ else:
125
+ st.info("⚠️ No classification data available yet.")