"""Streamlit app: translate emoji-laden text with a fine-tuned LLM, then classify it.

Flow: the user types a sentence or uploads a screenshot (text extracted via
Tesseract OCR). The sentence is rewritten to plain text by a fine-tuned causal
LM, and the result is scored by a user-selected Hugging Face text-classification
model. Every result is appended to ``st.session_state.history`` and listed in a
dashboard at the bottom of the page.
"""
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
import torch
import streamlit as st
from PIL import Image
import pytesseract
import pandas as pd
import plotly.express as px

# Page configuration must be the first Streamlit command on the page; the
# cached model loaders below display a spinner, so configure the page first.
st.set_page_config(page_title="Emoji Offensive Text Detector", page_icon="🚨", layout="wide")

_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Step 1: emoji translation model (the author's own fine-tuned model).
emoji_model_id = "JenniferHJF/qwen1.5-emoji-finetuned"


@st.cache_resource(show_spinner="Loading emoji translation model...")
def _load_emoji_model(model_id: str):
    """Load the emoji-translation tokenizer and causal LM once per server process.

    Streamlit re-executes this script on every widget interaction; caching
    prevents the model from being reloaded on each rerun.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        trust_remote_code=True,
        # Half precision only when a GPU is available.
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    ).to(_DEVICE)
    model.eval()
    return tokenizer, model


@st.cache_resource(show_spinner="Loading classification model...")
def _load_classifier(model_id: str):
    """Build a text-classification pipeline, cached separately per model id."""
    return pipeline(
        "text-classification",
        model=model_id,
        device=0 if torch.cuda.is_available() else -1,
    )


emoji_tokenizer, emoji_model = _load_emoji_model(emoji_model_id)

# Step 2: selectable offensive-text classification models (display name -> HF id).
model_options = {
    "Toxic-BERT": "unitary/toxic-bert",
    "Roberta Offensive": "cardiffnlp/twitter-roberta-base-offensive",
    "BERT Emotion": "bhadresh-savani/bert-base-go-emotion",
}

# Sidebar: model selection.
with st.sidebar:
    st.header("🧠 Settings")
    selected_model = st.selectbox("Choose classification model", list(model_options.keys()))

selected_model_id = model_options[selected_model]
classifier = _load_classifier(selected_model_id)

# Per-session history of classification results (survives reruns).
if "history" not in st.session_state:
    st.session_state.history = []


def classify_emoji_text(text: str):
    """Translate emoji text to plain text, classify it, and record the result.

    Args:
        text: User-supplied sentence, possibly containing emojis.

    Returns:
        Tuple ``(translated_text, label, score, reasoning)`` where ``label`` and
        ``score`` are the top prediction of the currently selected classifier.

    Raises:
        ValueError: if ``text`` is empty or whitespace-only.

    Side effect:
        Appends a dict describing the result to ``st.session_state.history``.
    """
    if not text or not text.strip():
        raise ValueError("Input text is empty.")

    # Prompt layout the fine-tuned model is fed; presumably it matches the
    # fine-tuning data format — keep in sync with the training script.
    prompt = f"输入:{text}\n输出:"
    inputs = emoji_tokenizer(prompt, return_tensors="pt").to(emoji_model.device)
    with torch.no_grad():
        output_ids = emoji_model.generate(**inputs, max_new_tokens=64, do_sample=False)

    # generate() returns the prompt followed by the continuation; decode only
    # the newly generated tokens rather than re-splitting the decoded prompt.
    prompt_len = inputs["input_ids"].shape[1]
    translated_text = emoji_tokenizer.decode(
        output_ids[0][prompt_len:], skip_special_tokens=True
    ).strip()

    # truncation=True keeps long translations within the classifier's max length.
    result = classifier(translated_text, truncation=True)[0]
    label = result["label"]
    score = result["score"]

    # NOTE: fixed template text, not an explanation produced by the model.
    reasoning = (
        f"The sentence was flagged as '{label}' due to potentially offensive phrases. "
        "Consider replacing emotionally charged, ambiguous, or abusive terms."
    )

    st.session_state.history.append({
        "text": text,
        "translated": translated_text,
        "label": label,
        "score": score,
        "reason": reasoning,
    })
    return translated_text, label, score, reasoning


def _render_result(translated, label, score, reason):
    """Render one classification result into the current Streamlit container."""
    st.markdown("#### 🔄 Translated sentence:")
    st.code(translated, language="text")
    st.markdown(f"#### 🎯 Prediction: {label}")
    st.markdown(f"#### 📊 Confidence Score: {score:.2%}")
    st.markdown("#### 🧠 Model Explanation:")
    st.info(reason)


# Main page.
st.title("🚨 Emoji Offensive Text Detector & Analysis")

# Input area: free text on the left, screenshot upload on the right.
st.markdown("### ✍️ Input your sentence or upload screenshot:")
col1, col2 = st.columns(2)

with col1:
    default_text = "你是🐷"
    text = st.text_area("Enter sentence with emojis:", value=default_text, height=150)
    if st.button("🚦 Analyze Text"):
        with st.spinner("🔍 Processing..."):
            try:
                translated, label, score, reason = classify_emoji_text(text)
                _render_result(translated, label, score, reason)
            except Exception as e:
                # UI boundary: surface any failure to the user instead of crashing the page.
                st.error(f"❌ An error occurred during processing:\n\n{e}")

with col2:
    uploaded_file = st.file_uploader("Upload an image (JPG/PNG)", type=["jpg", "jpeg", "png"])
    if uploaded_file is not None:
        image = Image.open(uploaded_file)
        st.image(image, caption="Uploaded Screenshot", use_column_width=True)
        if st.button("🛠️ OCR & Analyze Image"):
            with st.spinner("🧠 Extracting text via OCR..."):
                try:
                    # Requires the Tesseract binary plus chi_sim and eng language data.
                    ocr_text = pytesseract.image_to_string(image, lang="chi_sim+eng").strip()
                    st.markdown("#### 📋 Extracted Text:")
                    st.code(ocr_text)
                    if ocr_text:
                        translated, label, score, reason = classify_emoji_text(ocr_text)
                        _render_result(translated, label, score, reason)
                    else:
                        st.warning("⚠️ No text could be extracted from the image.")
                except Exception as e:
                    st.error(f"❌ An error occurred during processing:\n\n{e}")

# Analysis dashboard.
st.markdown("---")
st.title("📊 Violation Analysis Dashboard")

if st.session_state.history:
    st.markdown("### 🧾 Offensive Terms & Suggestions")
    for item in st.session_state.history:
        st.markdown(f"- 🔹 **Input:** {item['text']}")
        st.markdown(f"  - ✨ **Translated:** {item['translated']}")
        st.markdown(f"  - ❗ **Label:** {item['label']} with **{item['score']:.2%}** confidence")
        st.markdown(f"  - 🔧 **Suggestion:** {item['reason']}")

    # Radar chart demo: static placeholder scores, NOT derived from the history
    # above. Replace with real per-category scores when available.
    radar_df = pd.DataFrame({
        "Category": ["Insult", "Abuse", "Discrimination", "Hate Speech", "Vulgarity"],
        "Score": [0.7, 0.4, 0.3, 0.5, 0.6],
    })
    radar_fig = px.line_polar(
        radar_df,
        r='Score',
        theta='Category',
        line_close=True,
        title="⚠️ Risk Radar by Category",
        color_discrete_sequence=['black'],
        template='simple_white',
    )
    radar_fig.update_layout(
        polar=dict(
            gridshape='circular',
            bgcolor='white',
            radialaxis=dict(
                showticklabels=False,
                ticks='',
                showgrid=True,
                gridcolor='lightgrey',
                gridwidth=1,
                linecolor='black',
                linewidth=2,
            ),
            angularaxis=dict(
                showticklabels=False,
                ticks='',
                showline=True,
                linecolor='black',
                linewidth=2,
            ),
        ),
        paper_bgcolor='white',
        plot_bgcolor='white',
    )
    st.plotly_chart(radar_fig)
else:
    st.info("⚠️ No classification data available yet.")