test_1 / app.py
aeresd's picture
Update app.py
a77ff54 verified
raw
history blame
5.94 kB
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
import torch
import streamlit as st
from PIL import Image
import pytesseract
import pandas as pd
import plotly.express as px
# Step 1: emoji translation model (a custom fine-tuned checkpoint).
emoji_model_id = "JenniferHJF/qwen1.5-emoji-finetuned"


@st.cache_resource(show_spinner=False)
def _load_emoji_model(model_id):
    """Load the emoji-translation tokenizer and causal LM once per process.

    Streamlit re-executes the whole script on every widget interaction, so an
    uncached ``from_pretrained`` call would reload the weights on each rerun.
    ``show_spinner=False`` keeps this call from rendering a spinner element
    ahead of the ``st.set_page_config`` call that appears later in the script.

    Args:
        model_id: Hugging Face Hub id of the model to load.

    Returns:
        A ``(tokenizer, model)`` tuple. The model is switched to eval mode and
        moved to CUDA in float16 when CUDA is available, otherwise it stays on
        the CPU in float32.
    """
    use_cuda = torch.cuda.is_available()
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        trust_remote_code=True,
        torch_dtype=torch.float16 if use_cuda else torch.float32,
    ).to("cuda" if use_cuda else "cpu")
    model.eval()
    return tokenizer, model


emoji_tokenizer, emoji_model = _load_emoji_model(emoji_model_id)
# Step 2: selectable offensive-text classification models.
# Maps the display name offered in the sidebar to its Hugging Face model id;
# insertion order is the order in which the options are listed.
model_options = dict(
    [
        ("Toxic-BERT", "unitary/toxic-bert"),
        ("Roberta Offensive", "cardiffnlp/twitter-roberta-base-offensive"),
        ("BERT Emotion", "bhadresh-savani/bert-base-go-emotion"),
    ]
)
# Page configuration.
st.set_page_config(page_title="Emoji Offensive Text Detector", page_icon="🚨", layout="wide")


@st.cache_resource
def _load_classifier(model_id):
    """Build the text-classification pipeline for ``model_id`` once per process.

    Streamlit re-runs this script on every interaction; caching by model id
    avoids re-creating the pipeline (and reloading its weights) on each rerun
    while still building a new one when a different model is selected.

    Args:
        model_id: Hugging Face Hub id of the classification model.

    Returns:
        A ``text-classification`` pipeline placed on GPU 0 when CUDA is
        available, otherwise on the CPU.
    """
    return pipeline(
        "text-classification",
        model=model_id,
        device=0 if torch.cuda.is_available() else -1,
    )


# Sidebar: choose the classification model.
with st.sidebar:
    st.header("🧠 Settings")
    selected_model = st.selectbox(
        "Choose classification model", list(model_options.keys())
    )
    selected_model_id = model_options[selected_model]
    classifier = _load_classifier(selected_model_id)
# Initialise the per-session history list the first time the script runs
# for this session; later reruns keep the existing entries.
if "history" not in st.session_state:
    st.session_state["history"] = []
# Core function: translate and classify.
def classify_emoji_text(text: str):
    """Translate emoji-laden text, classify the translation, and record it.

    The text is wrapped in the ``输入:...\\n输出:`` prompt, completed greedily by
    the emoji model (at most 64 new tokens), and the part after the last
    ``输出:`` marker is passed to the selected classifier. A record of the call
    is appended to ``st.session_state.history``.

    Args:
        text: The raw user sentence, possibly containing emojis.

    Returns:
        A ``(translated_text, label, score, reasoning)`` tuple, where ``label``
        and ``score`` come from the classifier's first result and ``reasoning``
        is a fixed explanation template that only interpolates ``label``.
    """
    prompt = f"输入:{text}\n输出:"
    input_ids = emoji_tokenizer(prompt, return_tensors="pt").to(emoji_model.device)
    with torch.no_grad():
        output_ids = emoji_model.generate(
            **input_ids, max_new_tokens=64, do_sample=False
        )
    decoded = emoji_tokenizer.decode(output_ids[0], skip_special_tokens=True)

    # Keep only what follows the last output marker; when the marker is absent
    # rpartition returns the whole decoded string as its final element.
    translated_text = decoded.rpartition("输出:")[2].strip()

    # truncation=True clips over-long input to the classifier's maximum
    # sequence length instead of letting the model raise on it.
    result = classifier(translated_text, truncation=True)[0]
    label = result["label"]
    score = result["score"]
    reasoning = (
        f"The sentence was flagged as '{label}' due to potentially offensive phrases. "
        "Consider replacing emotionally charged, ambiguous, or abusive terms."
    )

    # Persist the outcome so the dashboard can list it on later reruns.
    st.session_state.history.append({
        "text": text,
        "translated": translated_text,
        "label": label,
        "score": score,
        "reason": reasoning,
    })
    return translated_text, label, score, reasoning
# Main page
st.title("🚨 Emoji Offensive Text Detector & Analysis")


def _show_analysis(raw_text):
    """Classify ``raw_text`` and render the outcome in the active container.

    Shared by the text and OCR paths so both display the same result layout
    and the same error message. Blank input is reported with a warning and is
    neither classified nor added to the history.

    Args:
        raw_text: The sentence to analyse (typed by the user or read via OCR).
    """
    if not raw_text or not raw_text.strip():
        st.warning("⚠️ No text to analyze.")
        return
    try:
        translated, label, score, reason = classify_emoji_text(raw_text)
    except Exception as e:  # UI boundary: report the failure instead of crashing
        st.error(f"❌ An error occurred during processing:\n\n{e}")
        return
    st.markdown("#### 🔄 Translated sentence:")
    st.code(translated, language="text")
    st.markdown(f"#### 🎯 Prediction: {label}")
    st.markdown(f"#### 📊 Confidence Score: {score:.2%}")
    st.markdown("#### 🧠 Model Explanation:")
    st.info(reason)


# Input area
st.markdown("### ✍️ Input your sentence or upload screenshot:")
col1, col2 = st.columns(2)
with col1:
    default_text = "你是🐷"
    text = st.text_area(
        "Enter sentence with emojis:", value=default_text, height=150
    )
    if st.button("🚦 Analyze Text"):
        with st.spinner("🔍 Processing..."):
            _show_analysis(text)
with col2:
    uploaded_file = st.file_uploader(
        "Upload an image (JPG/PNG)", type=["jpg", "jpeg", "png"]
    )
    if uploaded_file is not None:
        image = Image.open(uploaded_file)
        st.image(image, caption="Uploaded Screenshot", use_column_width=True)
        if st.button("🛠️ OCR & Analyze Image"):
            with st.spinner("🧠 Extracting text via OCR..."):
                try:
                    ocr_text = pytesseract.image_to_string(
                        image, lang="chi_sim+eng"
                    ).strip()
                except Exception as e:  # e.g. OCR engine or language data unavailable
                    st.error(f"❌ An error occurred during processing:\n\n{e}")
                else:
                    st.markdown("#### 📋 Extracted Text:")
                    st.code(ocr_text)
                    # Show the prediction for the OCR text as well, rather
                    # than only recording it in the history.
                    _show_analysis(ocr_text)
# Analysis dashboard
st.markdown("---")
st.title("📊 Violation Analysis Dashboard")
if not st.session_state.history:
    st.info("⚠️ No classification data available yet.")
else:
    # One bullet group per recorded classification, oldest first.
    st.markdown("### 🧾 Offensive Terms & Suggestions")
    for entry in st.session_state.history:
        st.markdown(f"- 🔹 **Input:** {entry['text']}")
        st.markdown(f" - ✨ **Translated:** {entry['translated']}")
        st.markdown(
            f" - ❗ **Label:** {entry['label']} with **{entry['score']:.2%}** confidence"
        )
        st.markdown(f" - 🔧 **Suggestion:** {entry['reason']} ")

    # Radar chart. NOTE: the category scores are hard-coded constants; they
    # are not computed from the history entries above.
    radar_categories = ["Insult", "Abuse", "Discrimination", "Hate Speech", "Vulgarity"]
    radar_scores = [0.7, 0.4, 0.3, 0.5, 0.6]
    radar_df = pd.DataFrame({"Category": radar_categories, "Score": radar_scores})
    radar_fig = px.line_polar(
        radar_df,
        r="Score",
        theta="Category",
        line_close=True,
        title="⚠️ Risk Radar by Category",
        color_discrete_sequence=["black"],
        template="simple_white",
    )

    # Axis styling: hide tick labels and marks, draw black axis lines.
    radial_axis_style = {
        "showticklabels": False,
        "ticks": "",
        "showgrid": True,
        "gridcolor": "lightgrey",
        "gridwidth": 1,
        "linecolor": "black",
        "linewidth": 2,
    }
    angular_axis_style = {
        "showticklabels": False,
        "ticks": "",
        "showline": True,
        "linecolor": "black",
        "linewidth": 2,
    }
    radar_fig.update_layout(
        polar={
            "gridshape": "circular",
            "bgcolor": "white",
            "radialaxis": radial_axis_style,
            "angularaxis": angular_axis_style,
        },
        paper_bgcolor="white",
        plot_bgcolor="white",
    )
    st.plotly_chart(radar_fig)