aeresd committed on
Commit
6e7a57d
·
verified ·
1 Parent(s): cd7f587

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -13
app.py CHANGED
@@ -29,10 +29,15 @@ st.set_page_config(page_title="Emoji Offensive Text Detector", page_icon="🚨",
29
  # ✅ 侧边栏: 选择模型
30
  with st.sidebar:
31
  st.header("🧠 Settings")
32
- selected_model = st.selectbox("Choose classification model", list(model_options.keys()))
 
 
33
  selected_model_id = model_options[selected_model]
34
- classifier = pipeline("text-classification", model=selected_model_id,
35
- device=0 if torch.cuda.is_available() else -1)
 
 
 
36
 
37
  # 初始化历史记录
38
  if "history" not in st.session_state:
@@ -43,9 +48,17 @@ def classify_emoji_text(text: str):
43
  prompt = f"输入:{text}\n输出:"
44
  input_ids = emoji_tokenizer(prompt, return_tensors="pt").to(emoji_model.device)
45
  with torch.no_grad():
46
- output_ids = emoji_model.generate(**input_ids, max_new_tokens=64, do_sample=False)
47
- decoded = emoji_tokenizer.decode(output_ids[0], skip_special_tokens=True)
48
- translated_text = decoded.split("输出:")[-1].strip() if "输出:" in decoded else decoded.strip()
 
 
 
 
 
 
 
 
49
 
50
  result = classifier(translated_text)[0]
51
  label = result["label"]
@@ -72,7 +85,9 @@ st.markdown("### ✍️ Input your sentence or upload screenshot:")
72
  col1, col2 = st.columns(2)
73
  with col1:
74
  default_text = "你是🐷"
75
- text = st.text_area("Enter sentence with emojis:", value=default_text, height=150)
 
 
76
  if st.button("🚦 Analyze Text"):
77
  with st.spinner("🔍 Processing..."):
78
  try:
@@ -88,13 +103,17 @@ with col1:
88
  st.error(f"❌ An error occurred during processing:\n\n{e}")
89
 
90
  with col2:
91
- uploaded_file = st.file_uploader("Upload an image (JPG/PNG)", type=["jpg", "jpeg", "png"])
 
 
92
  if uploaded_file is not None:
93
  image = Image.open(uploaded_file)
94
  st.image(image, caption="Uploaded Screenshot", use_column_width=True)
95
  if st.button("🛠️ OCR & Analyze Image"):
96
  with st.spinner("🧠 Extracting text via OCR..."):
97
- ocr_text = pytesseract.image_to_string(image, lang="chi_sim+eng").strip()
 
 
98
  st.markdown("#### 📋 Extracted Text:")
99
  st.code(ocr_text)
100
  classify_emoji_text(ocr_text)
@@ -102,16 +121,17 @@ with col2:
102
  # 分析仪表盘
103
  st.markdown("---")
104
  st.title("📊 Violation Analysis Dashboard")
105
-
106
  if st.session_state.history:
107
  st.markdown("### 🧾 Offensive Terms & Suggestions")
108
  for item in st.session_state.history:
109
  st.markdown(f"- 🔹 **Input:** {item['text']}")
110
  st.markdown(f" - ✨ **Translated:** {item['translated']}")
111
- st.markdown(f" - ❗ **Label:** {item['label']} with **{item['score']:.2%}** confidence")
112
- st.markdown(f" - 🔧 **Suggestion:** {item['reason']}")
 
 
113
 
114
- # 雷达图演示示例(可替换为动态数据)
115
  radar_df = pd.DataFrame({
116
  "Category": ["Insult", "Abuse", "Discrimination", "Hate Speech", "Vulgarity"],
117
  "Score": [0.7, 0.4, 0.3, 0.5, 0.6]
 
29
  # ✅ 侧边栏: 选择模型
30
  with st.sidebar:
31
  st.header("🧠 Settings")
32
+ selected_model = st.selectbox(
33
+ "Choose classification model", list(model_options.keys())
34
+ )
35
  selected_model_id = model_options[selected_model]
36
+ classifier = pipeline(
37
+ "text-classification",
38
+ model=selected_model_id,
39
+ device=0 if torch.cuda.is_available() else -1
40
+ )
41
 
42
  # 初始化历史记录
43
  if "history" not in st.session_state:
 
48
  prompt = f"输入:{text}\n输出:"
49
  input_ids = emoji_tokenizer(prompt, return_tensors="pt").to(emoji_model.device)
50
  with torch.no_grad():
51
+ output_ids = emoji_model.generate(
52
+ **input_ids, max_new_tokens=64, do_sample=False
53
+ )
54
+ decoded = emoji_tokenizer.decode(
55
+ output_ids[0], skip_special_tokens=True
56
+ )
57
+ translated_text = (
58
+ decoded.split("输出:")[-1].strip()
59
+ if "输出:" in decoded
60
+ else decoded.strip()
61
+ )
62
 
63
  result = classifier(translated_text)[0]
64
  label = result["label"]
 
85
  col1, col2 = st.columns(2)
86
  with col1:
87
  default_text = "你是🐷"
88
+ text = st.text_area(
89
+ "Enter sentence with emojis:", value=default_text, height=150
90
+ )
91
  if st.button("🚦 Analyze Text"):
92
  with st.spinner("🔍 Processing..."):
93
  try:
 
103
  st.error(f"❌ An error occurred during processing:\n\n{e}")
104
 
105
  with col2:
106
+ uploaded_file = st.file_uploader(
107
+ "Upload an image (JPG/PNG)", type=["jpg", "jpeg", "png"]
108
+ )
109
  if uploaded_file is not None:
110
  image = Image.open(uploaded_file)
111
  st.image(image, caption="Uploaded Screenshot", use_column_width=True)
112
  if st.button("🛠️ OCR & Analyze Image"):
113
  with st.spinner("🧠 Extracting text via OCR..."):
114
+ ocr_text = pytesseract.image_to_string(
115
+ image, lang="chi_sim+eng"
116
+ ).strip()
117
  st.markdown("#### 📋 Extracted Text:")
118
  st.code(ocr_text)
119
  classify_emoji_text(ocr_text)
 
121
  # 分析仪表盘
122
  st.markdown("---")
123
  st.title("📊 Violation Analysis Dashboard")
 
124
  if st.session_state.history:
125
  st.markdown("### 🧾 Offensive Terms & Suggestions")
126
  for item in st.session_state.history:
127
  st.markdown(f"- 🔹 **Input:** {item['text']}")
128
  st.markdown(f" - ✨ **Translated:** {item['translated']}")
129
+ st.markdown(
130
+ f" - **Label:** {item['label']} with **{item['score']:.2%}** confidence"
131
+ )
132
+ st.markdown(f" - 🔧 **Suggestion:** {item['reason']} ")
133
 
134
+ # 雷达图
135
  radar_df = pd.DataFrame({
136
  "Category": ["Insult", "Abuse", "Discrimination", "Hate Speech", "Vulgarity"],
137
  "Score": [0.7, 0.4, 0.3, 0.5, 0.6]