Sleepyriizi's picture
Update app.py
decced4 verified
raw
history blame
7.59 kB
"""
Orify Text Detector – Space edition (Zero-GPU ready)
• Three ModernBERT-base checkpoints (soft-vote)
• Per-paragraph colour coding, probability tool-tips, top-3 AI model hints
• Everything fetched automatically from the weight repo and cached
"""
# ── Imports ──────────────────────────────────────────────────────────────
import html
import os
import re
import types
from pathlib import Path
from typing import Any

import gradio as gr
import spaces
import torch
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# ────────────────── robust torch.compile shim ─────────────────────────
# Replace torch.compile with a no-op so nothing is ever compiled. The original
# author notes that ModernBERT applies it as ``@torch.compile(...)``.
if hasattr(torch, "compile"):

    # ``Any`` (from typing) replaces the former ``types.Any`` annotation: the
    # ``types`` module has no ``Any``, so evaluating that annotation raised
    # AttributeError as soon as this ``def`` executed.
    def _no_compile(model: Any = None, *args, **kwargs):
        """Stand-in for ``torch.compile`` that never compiles anything.

        1. ``torch.compile(model, …)`` or bare ``@torch.compile`` → return
           *model* unchanged.
        2. ``torch.compile(**kw)`` used as a decorator factory → return a
           decorator that hands back the decorated class/function unchanged.

        Any extra positional or keyword arguments are accepted and ignored.
        """
        if callable(model):  # pattern 1
            return model

        def decorator(fn):  # pattern 2
            return fn

        return decorator

    torch.compile = _no_compile  # monkey-patch for every later caller

# NOTE(review): this is set after ``import torch``; confirm the installed torch
# version still reads the variable at this point.
os.environ["TORCHINDUCTOR_DISABLED"] = "1"
# Run on the GPU when one is visible to torch, otherwise on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Hub repository that hosts the fine-tuned checkpoint files.
WEIGHT_REPO = "Sleepyriizi/Orify-Text-Detection-Weights"
# alias -> filename inside WEIGHT_REPO (names are used verbatim, spelling included).
FILE_MAP = {name: name for name in ("ensamble_1", "ensamble_2.bin", "ensamble_3")}

# Architecture/tokenizer every checkpoint is loaded on top of.
BASE_MODEL_NAME = "answerdotai/ModernBERT-base"

# Classifier output labels, in output-index order; "human" (index 24) is the
# only class that is not treated as an AI source.
_LABEL_NAMES = (
    "13B", "30B", "65B", "7B", "GLM130B",
    "bloom_7b", "bloomz", "cohere", "davinci",
    "dolly", "dolly-v2-12b", "flan_t5_base",
    "flan_t5_large", "flan_t5_small", "flan_t5_xl",
    "flan_t5_xxl", "gemma-7b-it", "gemma2-9b-it",
    "gpt-3.5-turbo", "gpt-35", "gpt-4",
    "gpt-4o", "gpt-j", "gpt-neox", "human",
    "llama3-70b", "llama3-8b", "mixtral-8x7b",
    "opt-1.3b", "opt-125m", "opt-13b",
    "opt-2.7b", "opt-30b", "opt-350m",
    "opt-6.7b", "opt-iml-30b", "opt-iml-max-1.3b",
    "t0-11b", "t0-3b", "text-davinci-002", "text-davinci-003",
)
NUM_LABELS = 41
LABELS = dict(enumerate(_LABEL_NAMES))  # id -> friendly label
# ── CSS (kept identical) ────────────────────────────────────────────────
# Prefer a style.css shipped next to this file; otherwise use the built-in
# stylesheet below. The file is decoded as UTF-8 explicitly rather than with
# the platform's locale-dependent default encoding.
_CSS_PATH = Path(__file__).with_name("style.css")
CSS = _CSS_PATH.read_text(encoding="utf-8") if _CSS_PATH.exists() else """
:root{--clr-ai:#ff4d4f;--clr-human:#52c41a;--border:2px solid var(--clr-ai);--radius:10px}
body{font-family:'Roboto Mono',monospace;margin:0 auto;max-width:900px;padding:32px}
textarea,.output-box{width:100%;box-sizing:border-box;padding:16px;font-size:1rem;border:var(--border);border-radius:var(--radius)}
.output-box{min-height:160px}.ai-line{background:rgba(255,77,79,.12);padding:2px 4px;border-radius:4px}
.human-line{background:rgba(82,196,26,.12);padding:2px 4px;border-radius:4px}
.prob-tooltip{cursor:help;border-bottom:1px dotted currentColor}
"""
# ── Model loading (download once, then cached) ───────────────────────────
print("🔄 Downloading weights …")
# alias -> local path of each checkpoint file fetched from WEIGHT_REPO.
local_paths = {alias: hf_hub_download(WEIGHT_REPO, fname, resume_download=True)
               for alias, fname in FILE_MAP.items()}

print("🧩 Loading tokenizer & models …")
tokeniser = AutoTokenizer.from_pretrained(BASE_MODEL_NAME)

# One classifier per checkpoint: build the ModernBERT classification head,
# overwrite it with the downloaded state dict, move it to DEVICE in eval mode.
models = []
for path in local_paths.values():
    net = AutoModelForSequenceClassification.from_pretrained(
        BASE_MODEL_NAME, num_labels=NUM_LABELS)
    # weights_only=True: these are downloaded files, so refuse to unpickle
    # anything other than tensors/containers (a plain state dict needs no more).
    state = torch.load(path, map_location=DEVICE, weights_only=True)
    net.load_state_dict(state)
    net.to(DEVICE).eval()
    models.append(net)
# ── Helpers ──────────────────────────────────────────────────────────────
def tidy(txt: str) -> str:
    """Normalise pasted text before scoring.

    - Converts CRLF / CR line endings to LF.
    - Collapses runs of spaces/tabs to one space and trims spaces that touch
      a line break (previously "a \\nb" became "a  b" with a double space).
    - Collapses any run of blank lines into a single paragraph break.
    - Re-joins words hyphenated across a line break ("well-\\nknown").
    - Unwraps remaining single line breaks into spaces, so the result is one
      line per paragraph separated by "\\n\\n".
    - Strips leading/trailing whitespace.
    """
    txt = txt.replace("\r\n", "\n").replace("\r", "\n")
    txt = re.sub(r"[ \t]+", " ", txt)
    txt = re.sub(r" ?\n ?", "\n", txt)  # after the collapse, at most one space per side
    txt = re.sub(r"\n\s*\n+", "\n\n", txt)
    txt = re.sub(r"(\w+)-\n(\w+)", r"\1\2", txt)
    txt = re.sub(r"(?<!\n)\n(?!\n)", " ", txt)
    return txt.strip()
def infer(segment: str):
    """Score one text segment with the checkpoint ensemble.

    The segment is tokenised (with truncation), each model's logits are
    softmaxed over the label set, and the per-model distributions are
    averaged (soft vote).

    Returns:
        (human%, ai%, [top-3 AI model names]) where ai% is the averaged
        probability mass on every label except "human", human% = 100 - ai%,
        and the names come from LABELS for the three most probable AI labels.
    """
    # Look up the "human" class in LABELS instead of hard-coding its index,
    # so the two cannot silently drift apart if the label table changes.
    human_idx = next(i for i, name in LABELS.items() if name == "human")

    inputs = tokeniser(segment, return_tensors="pt", truncation=True,
                       padding=True).to(DEVICE)
    with torch.no_grad():
        per_model = [torch.softmax(m(**inputs).logits, dim=1) for m in models]
        # Average across models, then take the row for the single input segment.
        probs = torch.stack(per_model).mean(dim=0)[0]

    ai_probs = probs.clone()
    ai_probs[human_idx] = 0  # null out the human class
    ai_score = ai_probs.sum().item() * 100
    human_score = 100 - ai_score
    top3 = torch.topk(ai_probs, 3).indices.tolist()
    return human_score, ai_score, [LABELS[i] for i in top3]
# ── Inference + explanation ──────────────────────────────────────────────
@spaces.GPU
def analyse(text: str):
    """Score *text* segment by segment and return an HTML report.

    The text is normalised with ``tidy`` and split on newlines. Because
    ``tidy`` unwraps single line breaks, each segment is effectively a
    paragraph, and blank segments become ``<br>`` spacers. Every other segment
    is scored with ``infer`` and wrapped in a coloured ``<span>``. Its
    ``title`` tooltip shows the winning probability, plus the top-3 AI models
    when AI wins.

    The overall verdict compares the summed human and AI scores. It reports
    the winning side's mean percentage over the scored segments.

    Returns:
        An HTML string, or a plain prompt message when *text* is blank.
    """
    empty_msg = "✏️ Please paste or type some text to analyse…"
    if not text.strip():
        return empty_msg

    highlighted, h_tot, ai_tot, n = [], 0.0, 0.0, 0
    for ln in tidy(text).split("\n"):
        if not ln.strip():
            highlighted.append("<br>")
            continue
        n += 1
        h, ai, top3 = infer(ln)
        h_tot += h
        ai_tot += ai
        is_ai = ai > h
        tooltip = (f"AI {ai:.2f}% • Top-3: {', '.join(top3)}"
                   if is_ai else f"Human {h:.2f}%")
        cls = "ai-line" if is_ai else "human-line"
        # html.escape (stdlib) shows user text literally, with no markup
        # injection. The tooltip is escaped too because it sits inside a
        # single-quoted attribute. This replaces gr.utils.sanitize_html, which
        # is not a documented Gradio API.
        highlighted.append(
            f"<span class='{cls} prob-tooltip' title='{html.escape(tooltip)}'>"
            f"{html.escape(ln)}</span>")

    if n == 0:  # defensive: nothing scorable survived normalisation
        return empty_msg

    if h_tot >= ai_tot:
        cls, label, pct = "human-line", "Human-written", h_tot / n
    else:
        cls, label, pct = "ai-line", "AI-generated", ai_tot / n
    verdict = (f"<p><strong>Overall verdict:</strong> "
               f"<span class='{cls}' style='padding:4px 8px;'>"
               f"{label} {pct:.2f}%</span></p>")
    return verdict + "<hr>" + "<br>".join(highlighted)
# ── Interface ────────────────────────────────────────────────────────────
# Single-page UI: a text box, an "Analyse" button wired to ``analyse``, and an
# HTML panel that receives the report.
with gr.Blocks(css=CSS, title="Orify Text Detector") as demo:
    gr.Markdown("""
### Orify Text Detector
Paste any English text and press **Analyse**.
<span class='human-line'>Green</span> = human | <span class='ai-line'>Red</span> = AI.
Hover a line to see confidence and the top-3 AI models it resembles.
""")
    inp = gr.Textbox(lines=8, placeholder="Paste text here …",
                     elem_classes=["input-area"])
    out = gr.HTML("", elem_classes=["output-box"])
    gr.Button("Analyse").click(analyse, inp, out)
    gr.Markdown("<sub>Powered by ModernBERT + Orify Ensemble © 2025</sub>")

# Start the server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()