# test_1 / app.py
# Last commit: "Update app.py" by aeresd (851f89d, verified) — 5.22 kB
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
import torch
import streamlit as st
from PIL import Image
import pytesseract
import pandas as pd
import plotly.express as px
# Step 1: Emoji translation model (the project's own fine-tuned model).
emoji_model_id = "JenniferHJF/qwen1.5-emoji-finetuned"


@st.cache_resource(show_spinner=False)
def _load_emoji_model(model_id: str):
    """Load the emoji-translation tokenizer and causal LM once and reuse them.

    Streamlit re-executes this whole script on every widget interaction, so
    loading at module level without caching would reload the weights on every
    click. ``show_spinner=False`` keeps this call from rendering any element
    before ``st.set_page_config`` runs further down the script.

    Args:
        model_id: Hugging Face Hub id of the fine-tuned emoji model.

    Returns:
        ``(tokenizer, model)``. The model is on CUDA when available, otherwise
        on CPU, and is in eval mode.
    """
    use_cuda = torch.cuda.is_available()
    # NOTE(review): trust_remote_code=True executes Python shipped in the model
    # repo; consider pinning a ``revision=`` if the repo is not under your control.
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        trust_remote_code=True,
        # Half precision on GPU, full precision on CPU.
        torch_dtype=torch.float16 if use_cuda else torch.float32,
    ).to("cuda" if use_cuda else "cpu")
    model.eval()
    return tokenizer, model


emoji_tokenizer, emoji_model = _load_emoji_model(emoji_model_id)
# Step 2: Offensive-text classification models the user can pick from.
# Maps the sidebar display name to its Hugging Face Hub model id; insertion
# order is the order shown in the selectbox.
model_options = {
    "Toxic-BERT": "unitary/toxic-bert",
    "Roberta Offensive": "cardiffnlp/twitter-roberta-base-offensive",
    "BERT Emotion": "bhadresh-savani/bert-base-go-emotion",
}
# Page configuration. This is the first st.* call in the script; older
# Streamlit releases require set_page_config to precede other st commands.
st.set_page_config(
    page_title="Emoji Offensive Text Detector",
    page_icon="🚨",
    layout="wide",
)

# Analysis history for this session: only create it when absent so entries
# logged on earlier script runs are kept.
if "history" not in st.session_state:
    st.session_state["history"] = []
# Emoji text translation + classification
def classify_emoji_text(text: str, *, max_new_tokens: int = 64, text_classifier=None):
    """Translate emoji-laden *text* to plain text, classify it, and log the result.

    The fine-tuned emoji model is prompted with ``输入:{text}``, a newline, and
    ``输出:``. Its greedy continuation is taken as the translation, which is
    then passed to the text-classification pipeline.

    Args:
        text: Raw user text (may contain emojis).
        max_new_tokens: Upper bound on the number of generated translation tokens.
        text_classifier: Optional text-classification pipeline. Defaults to the
            module-level ``classifier`` chosen in the sidebar.

    Returns:
        ``(translated_text, label, score, reasoning)`` where ``label``/``score``
        are the top prediction returned by the classifier.

    Side effects:
        Appends one record dict to ``st.session_state.history``.
    """
    if text_classifier is None:
        # Looked up at call time: ``classifier`` is created later in the script.
        text_classifier = classifier

    prompt = f"输入:{text}\n输出:"
    inputs = emoji_tokenizer(prompt, return_tensors="pt").to(emoji_model.device)
    with torch.no_grad():
        output_ids = emoji_model.generate(
            **inputs, max_new_tokens=max_new_tokens, do_sample=False
        )

    # The generated sequence starts with the prompt tokens. Decode only the
    # continuation instead of decoding everything and splitting on "输出:":
    # the split relies on the prompt decoding back byte-for-byte, and it keeps
    # the wrong (last) segment if the model emits the marker again.
    prompt_len = inputs["input_ids"].shape[-1]
    translated_text = emoji_tokenizer.decode(
        output_ids[0][prompt_len:], skip_special_tokens=True
    ).strip()

    # truncation=True keeps long inputs (e.g. OCR'd screenshots) within the
    # classifier's maximum sequence length instead of raising.
    result = text_classifier(translated_text, truncation=True)[0]
    label = result["label"]
    score = result["score"]
    # NOTE(review): this explanation is the same generic text for every label,
    # including non-offensive ones.
    reasoning = f"The sentence was flagged as '{label}' due to potentially offensive phrases. Consider replacing emotionally charged, ambiguous, or abusive terms."
    st.session_state.history.append({"text": text, "translated": translated_text, "label": label, "score": score, "reason": reasoning})
    return translated_text, label, score, reasoning
# Page layout: sidebar model picker
st.sidebar.header("🧠 Settings")
selected_model = st.sidebar.selectbox("Choose classification model", list(model_options.keys()))
selected_model_id = model_options[selected_model]


@st.cache_resource(show_spinner="Loading classification model...")
def _load_classifier(model_id: str):
    """Build the text-classification pipeline for *model_id*, once per model id.

    Streamlit reruns the script on every interaction; without caching, the
    pipeline (and its weights) would be rebuilt on every button click. One
    pipeline is cached per distinct model id selected in the sidebar.

    Args:
        model_id: Hugging Face Hub id taken from ``model_options``.

    Returns:
        A ``text-classification`` pipeline on GPU 0 when CUDA is available,
        otherwise on CPU (device -1).
    """
    return pipeline(
        "text-classification",
        model=model_id,
        device=0 if torch.cuda.is_available() else -1,
    )


classifier = _load_classifier(selected_model_id)
# Main page: text moderation and violation analysis in a single view.
st.title("🚨 Emoji Offensive Text Detector & Violation Analysis")

# Input & classification section. The left column (twice as wide) takes typed
# text; the right column takes a screenshot upload.
st.markdown("## ✍️ 输入或上传文本进行分类")
col1, col2 = st.columns((2, 1))
# Left column: typed text -> translate -> classify -> show result.
with col1:
    text = st.text_area("Enter sentence with emojis:", value="你是🐷", height=150)
    if st.button("🚦 Analyze Text"):
        if not text.strip():
            # Don't run both models on empty input or log an empty history entry.
            st.warning("⚠️ Please enter some text first.")
        else:
            with st.spinner("🔍 Processing..."):
                try:
                    translated, label, score, reason = classify_emoji_text(text)
                except Exception as e:  # UI boundary: show any model/runtime error to the user
                    st.error(f"❌ Error during processing: {e}")
                else:
                    st.markdown("### 🔄 Translated sentence:")
                    st.code(translated, language="text")
                    st.markdown(f"### 🎯 Prediction: {label}")
                    st.markdown(f"### 📊 Confidence Score: {score:.2%}")
                    st.markdown("### 🧠 Model Explanation:")
                    st.info(reason)
# Right column: screenshot upload -> OCR -> translate -> classify -> show result.
with col2:
    st.markdown("### 🖼️ Or upload a screenshot:")
    uploaded_file = st.file_uploader("Image (JPG/PNG)", type=["jpg","png","jpeg"])
    if uploaded_file:
        import hashlib  # local import keeps this section self-contained

        image = Image.open(uploaded_file)
        st.image(image, caption="Uploaded Image", use_column_width=True)

        # The uploader keeps its file across reruns, and Streamlit reruns the
        # whole script on every interaction (e.g. clicking "Analyze Text").
        # Cache results per file content so OCR + classification, and the
        # history entry that classify_emoji_text appends, happen once per upload.
        if "ocr_results" not in st.session_state:
            st.session_state["ocr_results"] = {}
        ocr_cache = st.session_state["ocr_results"]
        file_key = hashlib.sha256(uploaded_file.getvalue()).hexdigest()

        entry = ocr_cache.get(file_key)
        if entry is None:
            try:
                with st.spinner("🧠 Running OCR..."):
                    ocr_text = pytesseract.image_to_string(image, lang="chi_sim+eng").strip()
                # Only classify (and log to history) when OCR actually found text.
                analysis = classify_emoji_text(ocr_text) if ocr_text else None
            except Exception as e:  # UI boundary, e.g. OCR engine or model failure
                st.error(f"❌ Error during processing: {e}")
            else:
                entry = ocr_cache[file_key] = (ocr_text, analysis)

        if entry is not None:
            ocr_text, analysis = entry
            st.markdown("#### 📋 OCR Extracted Text:")
            st.code(ocr_text)
            if analysis is None:
                st.warning("⚠️ No text was detected in the image.")
            else:
                translated, label, score, reason = analysis
                st.markdown("#### 🔄 Translated:")
                st.code(translated)
                st.markdown(f"#### 🎯 Prediction: {label}")
                st.markdown(f"#### 📊 Confidence: {score:.2%}")
                st.markdown("#### 🧠 Explanation:")
                st.info(reason)
st.markdown("---")

# Violation analysis dashboard: per-session history plus a risk radar chart.
st.markdown("## 📊 Violation Analysis Dashboard")
if st.session_state.history:
    st.markdown("### 🧾 历史记录详情")
    # One list entry per analysed input, in the order they were analysed.
    for item in st.session_state.history:
        st.markdown(f"- 🔹 **input:** {item['text']} | **Label:** {item['label']} | **Confidence:** {item['score']:.2%}")
        st.markdown(f" - **Translated:** {item['translated']}")
        st.markdown(f" - **Suggestion:** {item['reason']}")

    # NOTE(review): these category scores are hard-coded constants, not derived
    # from st.session_state.history, so the chart is identical regardless of
    # what was analysed. Replace with real per-category scores when available.
    radar_df = pd.DataFrame({
        "Category": ["Insult", "Abuse", "Discrimination", "Hate Speech", "Vulgarity"],
        "Score": [0.7, 0.4, 0.3, 0.5, 0.6],
    })
    # Radar chart styling: draw the trace in black.
    radar_fig = px.line_polar(
        radar_df,
        r='Score',
        theta='Category',
        line_close=True,
        title="⚠️ Risk Radar by Category",
        color_discrete_sequence=['black'],
    )
    st.plotly_chart(radar_fig)
else:
    st.info("⚠️ No data available. Please analyze some text first.")