aeresd commited on
Commit
5a8b969
·
verified ·
1 Parent(s): 88261ff

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -17
app.py CHANGED
@@ -1,5 +1,6 @@
1
  from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
2
  import torch
 
3
 
4
  # ✅ Step 1: Emoji 翻译模型(你自己训练的模型)
5
  emoji_model_id = "JenniferHJF/qwen1.5-emoji-finetuned"
@@ -7,33 +8,66 @@ emoji_tokenizer = AutoTokenizer.from_pretrained(emoji_model_id, trust_remote_cod
7
  emoji_model = AutoModelForCausalLM.from_pretrained(
8
  emoji_model_id,
9
  trust_remote_code=True,
10
- torch_dtype=torch.float16
11
  ).to("cuda" if torch.cuda.is_available() else "cpu")
12
  emoji_model.eval()
13
 
14
- # ✅ Step 2: 冒犯性文本识别模型
15
- classifier = pipeline("text-classification", model="unitary/toxic-bert", device=0 if torch.cuda.is_available() else -1)
 
 
 
 
16
 
17
- def classify_emoji_text(text: str):
18
- """
19
- Step 1: 翻译文本中的 emoji
20
- Step 2: 使用分类器判断是否冒犯
21
- """
22
- prompt = f"""请判断下面的文本是否具有冒犯性。
23
- 这里的“冒犯性”主要指包含人身攻击、侮辱、歧视、仇恨言论或极端粗俗的内容。
24
- 如果文本具有冒犯性,请仅回复冒犯;如果不具有冒犯性,请仅回复不冒犯。
25
- 文本如下:
26
- {text}
27
- """
 
 
 
 
 
 
 
 
 
 
28
 
 
 
 
29
  input_ids = emoji_tokenizer(prompt, return_tensors="pt").to(emoji_model.device)
30
  with torch.no_grad():
31
- output_ids = emoji_model.generate(**input_ids, max_new_tokens=50, do_sample=False)
32
  decoded = emoji_tokenizer.decode(output_ids[0], skip_special_tokens=True)
33
- translated_text = decoded.strip().split("文本如下:")[-1].strip()
34
 
35
  result = classifier(translated_text)[0]
36
  label = result["label"]
37
  score = result["score"]
38
 
39
- return translated_text, label, score
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
2
  import torch
3
+ import streamlit as st
4
 
5
  # ✅ Step 1: Emoji 翻译模型(你自己训练的模型)
6
  emoji_model_id = "JenniferHJF/qwen1.5-emoji-finetuned"
 
8
  emoji_model = AutoModelForCausalLM.from_pretrained(
9
  emoji_model_id,
10
  trust_remote_code=True,
11
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
12
  ).to("cuda" if torch.cuda.is_available() else "cpu")
13
  emoji_model.eval()
14
 
15
+ # ✅ Step 2: 可选择的冒犯性文本识别模型
16
+ model_options = {
17
+ "Toxic-BERT": "unitary/toxic-bert",
18
+ "Roberta Offensive": "cardiffnlp/twitter-roberta-base-offensive",
19
+ "BERT Emotion": "bhadresh-savani/bert-base-go-emotion"
20
+ }
21
 
22
+ # 页面配置
23
+ st.set_page_config(page_title="Emoji Offensive Text Detector", page_icon="🚨", layout="wide")
24
+
25
+ # 页面标题
26
+ st.title("🧠 Emoji-based Offensive Language Classifier")
27
+
28
+ st.markdown("""
29
+ This application translates emojis in a sentence and classifies whether the final sentence is offensive or not using two AI models.
30
+ - The **first model** translates emoji or symbolic phrases into standard Chinese text.
31
+ - The **second model** performs offensive language detection.
32
+ """)
33
+
34
+ # Streamlit 侧边栏模型选择
35
+ selected_model = st.sidebar.selectbox("Choose classification model", list(model_options.keys()))
36
+ selected_model_id = model_options[selected_model]
37
+ classifier = pipeline("text-classification", model=selected_model_id, device=0 if torch.cuda.is_available() else -1)
38
+
39
+ # ✅ 输入区域
40
+ st.markdown("### ✍️ Input your sentence:")
41
+ default_text = "你是🐷"
42
+ text = st.text_area("Enter sentence with emojis:", value=default_text, height=150)
43
 
44
+ # ✅ 主逻辑封装函数
45
+ def classify_emoji_text(text: str):
46
+ prompt = f"输入:{text}\n输出:"
47
  input_ids = emoji_tokenizer(prompt, return_tensors="pt").to(emoji_model.device)
48
  with torch.no_grad():
49
+ output_ids = emoji_model.generate(**input_ids, max_new_tokens=64, do_sample=False)
50
  decoded = emoji_tokenizer.decode(output_ids[0], skip_special_tokens=True)
51
+ translated_text = decoded.split("输出:")[-1].strip() if "输出:" in decoded else decoded.strip()
52
 
53
  result = classifier(translated_text)[0]
54
  label = result["label"]
55
  score = result["score"]
56
 
57
+ return translated_text, label, score
58
+
59
+ # ✅ 触发按钮
60
+ if st.button("🚦 Analyze"):
61
+ with st.spinner("🔍 Processing..."):
62
+ try:
63
+ translated, label, score = classify_emoji_text(text)
64
+ st.markdown("### 🔄 Translated sentence:")
65
+ st.code(translated, language="text")
66
+
67
+ st.markdown(f"### 🎯 Prediction: `{label}`")
68
+ st.markdown(f"### 📊 Confidence Score: `{score:.2%}`")
69
+
70
+ except Exception as e:
71
+ st.error(f"❌ An error occurred during processing:\n\n{e}")
72
+ else:
73
+ st.info("👈 Please input text and click the button to classify.")