test_1 / app.py
aeresd's picture
Update app.py
a8b7aaa verified
raw
history blame
5.67 kB
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
import torch
import streamlit as st
from PIL import Image
import pytesseract
import pandas as pd
import plotly.express as px
# Step 1: emoji translation model (a custom fine-tuned checkpoint).
emoji_model_id = "JenniferHJF/qwen1.5-emoji-finetuned"


# Streamlit re-executes this script on every user interaction; caching keeps a
# single loaded tokenizer/model per server process instead of reloading the
# weights on each rerun. show_spinner=False so that no Streamlit element is
# emitted here, ahead of the st.set_page_config() call later in the script.
@st.cache_resource(show_spinner=False)
def _load_emoji_model(model_id: str):
    """Load the emoji-translation tokenizer and causal LM for ``model_id``.

    The model is placed on CUDA in float16 when a GPU is available, otherwise
    on CPU in float32, and is switched to eval mode.

    Args:
        model_id: Hugging Face Hub id of the fine-tuned checkpoint.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    use_cuda = torch.cuda.is_available()
    # NOTE(review): trust_remote_code=True executes Python code shipped with
    # the checkpoint repository; keep it only for repositories you control.
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        trust_remote_code=True,
        torch_dtype=torch.float16 if use_cuda else torch.float32,
    ).to("cuda" if use_cuda else "cpu")
    model.eval()
    return tokenizer, model


emoji_tokenizer, emoji_model = _load_emoji_model(emoji_model_id)
# Step 2: selectable offensive-text classification models.
# Each pair is (display name offered in the UI, Hugging Face Hub model id);
# the order here is the order of the resulting mapping's keys.
_classifier_choices = (
    ("Toxic-BERT", "unitary/toxic-bert"),
    ("Roberta Offensive", "cardiffnlp/twitter-roberta-base-offensive"),
    ("BERT Emotion", "bhadresh-savani/bert-base-go-emotion"),
)
model_options = dict(_classifier_choices)
# โœ… ้กต้ข้…็ฝฎ
st.set_page_config(page_title="Emoji Offensive Text Detector", page_icon="๐Ÿšจ", layout="wide")
# โœ… ้กต้ขๅธƒๅฑ€
with st.sidebar:
st.header("๐Ÿง  Navigation")
section = st.radio("Select Mode:", ["๐Ÿ“ Text Moderation", "๐Ÿ“Š Text Analysis"])
if section == "๐Ÿ“ Text Moderation":
selected_model = st.selectbox("Choose classification model", list(model_options.keys()))
selected_model_id = model_options[selected_model]
classifier = pipeline("text-classification", model=selected_model_id, device=0 if torch.cuda.is_available() else -1)
elif section == "๐Ÿ“Š Text Analysis":
st.markdown("You can view the violation distribution chart and editing suggestions.")
if "history" not in st.session_state:
st.session_state.history = []
def classify_emoji_text(text: str):
    """Translate emoji text with the fine-tuned LM, classify it, and record the result.

    The input is wrapped in the prompt template, the emoji model generates
    greedily (at most 64 new tokens), and the text after the last output
    marker in the decoded sequence is passed to the module-level
    ``classifier`` pipeline. Only the first classifier result is used.

    Args:
        text: Raw user text, possibly containing emojis.

    Returns:
        A ``(translated_text, label, score, reasoning)`` tuple. ``label`` and
        ``score`` come from the classifier; ``reasoning`` is a fixed-template
        message that mentions the label.

    Side effects:
        Appends a dict with keys ``text``, ``translated``, ``label``,
        ``score`` and ``reason`` to ``st.session_state.history``.

    Note:
        ``classifier`` is a global that is only assigned while the moderation
        mode is selected in the sidebar; calling this function in any other
        state raises ``NameError``.
    """
    # Single definition of the marker so the prompt and the post-processing
    # below cannot drift apart.
    output_marker = "่พ“ๅ‡บ๏ผš"
    prompt = f"่พ“ๅ…ฅ๏ผš{text}\n{output_marker}"

    # The tokenizer result is a mapping of tensors, unpacked into generate().
    encoded = emoji_tokenizer(prompt, return_tensors="pt").to(emoji_model.device)
    with torch.no_grad():
        output_ids = emoji_model.generate(**encoded, max_new_tokens=64, do_sample=False)
    decoded = emoji_tokenizer.decode(output_ids[0], skip_special_tokens=True)

    # Keep what follows the last marker; when the marker is absent, rsplit
    # returns the whole string, so the full decoded text is used.
    translated_text = decoded.rsplit(output_marker, 1)[-1].strip()

    # truncation=True is forwarded to the classifier's tokenizer so that an
    # over-long input (e.g. OCR of a dense screenshot) is cut to the model's
    # maximum length instead of failing inside the model.
    result = classifier(translated_text, truncation=True)[0]
    label = result["label"]
    score = result["score"]
    reasoning = f"The sentence was flagged as '{label}' due to potentially offensive phrases. Consider replacing emotionally charged, ambiguous, or abusive terms."
    st.session_state.history.append({"text": text, "translated": translated_text, "label": label, "score": score, "reason": reasoning})
    return translated_text, label, score, reasoning
def _render_result(translated, label, score, reason):
    """Display one classification result: translation, label, confidence, explanation."""
    st.markdown("### ๐Ÿ”„ Translated sentence:")
    st.code(translated, language="text")
    st.markdown(f"### ๐ŸŽฏ Prediction: {label}")
    st.markdown(f"### ๐Ÿ“Š Confidence Score: {score:.2%}")
    st.markdown("### ๐Ÿง  Model Explanation:")
    st.info(reason)


# Section logic: render the page for the mode chosen in the sidebar.
if section == "๐Ÿ“ Text Moderation":
    st.title("๐Ÿ“ Offensive Text Classification")
    st.markdown("### โœ๏ธ Input your sentence:")
    default_text = "ไฝ ๆ˜ฏ๐Ÿท"
    text = st.text_area("Enter sentence with emojis:", value=default_text, height=150)
    if st.button("๐Ÿšฆ Analyze"):
        if not text.strip():
            # Nothing to translate or classify; avoid logging an empty entry.
            st.warning("Please enter some text to analyze.")
        else:
            with st.spinner("๐Ÿ” Processing..."):
                try:
                    outcome = classify_emoji_text(text)
                except Exception as e:
                    # Top-level UI boundary: report the failure instead of crashing the page.
                    st.error(f"โŒ An error occurred during processing:\n\n{e}")
                else:
                    _render_result(*outcome)

    st.markdown("---")
    st.markdown("### ๐Ÿ–ผ๏ธ Or upload a screenshot of bullet comments:")
    uploaded_file = st.file_uploader("Upload an image (JPG/PNG)", type=["jpg", "jpeg", "png"])
    # NOTE(review): this branch runs on every script rerun while a file stays
    # uploaded, so each rerun repeats OCR + classification and appends another
    # history entry. Consider gating it behind a button.
    if uploaded_file is not None:
        # Same error handling as the text path: image decoding, OCR and
        # classification failures are reported in the page.
        try:
            image = Image.open(uploaded_file)
            st.image(image, caption="Uploaded Screenshot", use_column_width=True)
            with st.spinner("๐Ÿง  Extracting text via OCR..."):
                ocr_text = pytesseract.image_to_string(image, lang="chi_sim+eng").strip()
            st.markdown("#### ๐Ÿ“‹ Extracted Text:")
            st.code(ocr_text)
            if ocr_text:
                _render_result(*classify_emoji_text(ocr_text))
            else:
                # OCR found nothing; do not classify or log an empty string.
                st.warning("No text could be extracted from the uploaded image.")
        except Exception as e:
            st.error(f"โŒ An error occurred during processing:\n\n{e}")
elif section == "๐Ÿ“Š Text Analysis":
    st.title("๐Ÿ“Š Violation Analysis Dashboard")
    if st.session_state.history:
        # The "Offensive Category Distribution" pie chart has been removed.
        st.markdown("### ๐Ÿงพ Offensive Terms & Suggestions")
        for item in st.session_state.history:
            st.markdown(f"- ๐Ÿ”น **Input:** {item['text']}")
            st.markdown(f" - โœจ **Translated:** {item['translated']}")
            st.markdown(f" - โ— **Label:** {item['label']} with **{item['score']:.2%}** confidence")
            st.markdown(f" - ๐Ÿ”ง **Suggestion:** {item['reason']}")
        # NOTE: the radar scores below are hard-coded placeholder values; they
        # are not computed from the classification history.
        radar_df = pd.DataFrame({
            "Category": ["Insult", "Abuse", "Discrimination", "Hate Speech", "Vulgarity"],
            "Score": [0.7, 0.4, 0.3, 0.5, 0.6],
        })
        radar_fig = px.line_polar(radar_df, r='Score', theta='Category', line_close=True, title="โš ๏ธ Risk Radar by Category")
        radar_fig.update_traces(line_color='black')  # draw the radar outline in black
        st.plotly_chart(radar_fig)
    else:
        st.info("โš ๏ธ No classification data available yet.")