Update app.py
app.py
CHANGED
@@ -3,53 +3,51 @@
 
 • Three ModernBERT-base checkpoints (soft-vote)
 • Per-line colour coding, probability tool-tips, top-3 AI model hints
-• …
+• Everything fetched automatically from the weight repo and cached
 """
 
 # ── Imports ──────────────────────────────────────────────────────────────
 from pathlib import Path
-import re, os, torch, gradio as gr
+import re, os, html, torch, gradio as gr                    # → add html
 from transformers import AutoTokenizer, AutoModelForSequenceClassification
 from huggingface_hub import hf_hub_download
 import spaces
-import …
+import typing                                               # → add typing
 
-# ────────────────── robust torch.compile shim …
+# ────────────────── robust torch.compile shim ───────────────────────────
 if hasattr(torch, "compile"):
-    def _no_compile(model: …
+    def _no_compile(model: typing.Any = None, *args, **kwargs):
         """
-        1. torch.compile(model, …)
-        2. torch.compile(**kw)
-           immediately gives back the class/…
+        1. If called as torch.compile(model, …) → just return the model.
+        2. If called as torch.compile(**kw)     → return a decorator that
+           immediately gives back the class / fn it decorates.
         """
-        if callable(model):
+        if callable(model):                                 # pattern 1
            return model
-
-        def decorator(fn):
+        # pattern 2 (used by ModernBERT via @torch.compile(...))
+        def decorator(fn):
            return fn
        return decorator
 
-    torch.compile = _no_compile
+    torch.compile = _no_compile                             # monkey-patch
    os.environ["TORCHINDUCTOR_DISABLED"] = "1"
 
-# …
-DEVICE …
-WEIGHT_REPO …
-FILE_MAP …
-    …
-    …
-    "ensamble_3": "ensamble_3",
-}
-
+# (everything below is unchanged)
+DEVICE      = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+WEIGHT_REPO = "Sleepyriizi/Orify-Text-Detection-Weights"
+FILE_MAP    = {"ensamble_1":"ensamble_1",
+               "ensamble_2.bin":"ensamble_2.bin",
+               "ensamble_3":"ensamble_3"}
 BASE_MODEL_NAME = "answerdotai/ModernBERT-base"
 NUM_LABELS = 41
 
-LABELS = { …
-    0: "13B", 1: "30B", …
-    6: "bloomz", 7: "cohere", 8: "davinci", …
-    …
-    …
-    …
+LABELS = {   # id → friendly label (unchanged)
+    0: "13B", 1: "30B", 2: "65B", 3: "7B", 4: "GLM130B",
+    5: "bloom_7b", 6: "bloomz", 7: "cohere", 8: "davinci",
+    9: "dolly", 10: "dolly-v2-12b", 11: "flan_t5_base",
+    12: "flan_t5_large", 13: "flan_t5_small", 14: "flan_t5_xl",
+    15: "flan_t5_xxl", 16: "gemma-7b-it", 17: "gemma2-9b-it",
+    18: "gpt-3.5-turbo", 19: "gpt-35", 20: "gpt-4",
     21: "gpt-4o", 22: "gpt-j", 23: "gpt-neox", 24: "human",
     25: "llama3-70b", 26: "llama3-8b", 27: "mixtral-8x7b",
     28: "opt-1.3b", 29: "opt-125m", 30: "opt-13b",
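A note on the shim hunk: once torch.compile is monkey-patched, both call patterns become no-ops. Below is a minimal, self-contained sanity check of that behaviour; the stub function names are invented for illustration and are not part of app.py.

import typing, torch

def _no_compile(model: typing.Any = None, *args, **kwargs):
    # Same two-pattern shim as in the diff above.
    if callable(model):
        return model
    def decorator(fn):
        return fn
    return decorator

torch.compile = _no_compile              # the monkey-patch under test

def fake_model(x):                       # hypothetical stand-in for an nn.Module
    return x

assert torch.compile(fake_model) is fake_model   # pattern 1: model passes through

@torch.compile(dynamic=True)                     # pattern 2: kwargs-only call
def add_one(x):
    return x + 1

assert add_one(1) == 2                   # the decorator returned the function unchanged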
@@ -58,7 +56,7 @@ LABELS = {   # id → friendly label
     37: "t0-11b", 38: "t0-3b", 39: "text-davinci-002", 40: "text-davinci-003"
 }
 
-# ── CSS (…
+# ── CSS (kept identical) ─────────────────────────────────────────────────
 CSS = Path(__file__).with_name("style.css").read_text() if Path(__file__).with_name("style.css").exists() else """
 :root{--clr-ai:#ff4d4f;--clr-human:#52c41a;--border:2px solid var(--clr-ai);--radius:10px}
 body{font-family:'Roboto Mono',monospace;margin:0 auto;max-width:900px;padding:32px}
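Minor readability point on the CSS line just above: it resolves the same sibling path twice in one expression. An equivalent single-lookup form, sketched here with an invented DEFAULT_CSS placeholder standing in for the long inline fallback string:

from pathlib import Path

DEFAULT_CSS = ":root{--clr-ai:#ff4d4f;--clr-human:#52c41a}"  # stand-in for the inline fallback
_css_path = Path(__file__).with_name("style.css")            # resolve the sibling path once
CSS = _css_path.read_text() if _css_path.exists() else DEFAULT_CSS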
@@ -77,7 +75,7 @@ print("🧩 Loading tokenizer & models …")
 tokeniser = AutoTokenizer.from_pretrained(BASE_MODEL_NAME)
 
 models = []
-for …
+for alias, path in local_paths.items():
     net = AutoModelForSequenceClassification.from_pretrained(
         BASE_MODEL_NAME, num_labels=NUM_LABELS)
     net.load_state_dict(torch.load(path, map_location=DEVICE))
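local_paths is defined outside this hunk; given WEIGHT_REPO and FILE_MAP above, it is presumably built roughly like this (a sketch, not the commit's exact code):

from huggingface_hub import hf_hub_download

WEIGHT_REPO = "Sleepyriizi/Orify-Text-Detection-Weights"
FILE_MAP = {"ensamble_1": "ensamble_1",
            "ensamble_2.bin": "ensamble_2.bin",
            "ensamble_3": "ensamble_3"}

# hf_hub_download fetches each checkpoint once, caches it locally,
# and returns the cached file path.
local_paths = {alias: hf_hub_download(repo_id=WEIGHT_REPO, filename=fname)
               for alias, fname in FILE_MAP.items()}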
@@ -94,19 +92,20 @@ def tidy(txt: str) -> str:
     return txt.strip()
 
 def infer(segment: str):
-    """Return (human%, ai%, …
-    inputs = tokeniser(segment, return_tensors="pt", …
-    …
+    """Return (human%, ai%, [top-3 ai model names])."""
+    inputs = tokeniser(segment, return_tensors="pt", truncation=True,
+                       padding=True).to(DEVICE)
     with torch.no_grad():
         probs = torch.stack([
             torch.softmax(m(**inputs).logits, dim=1) for m in models
         ]).mean(dim=0)[0]
 
-    ai_probs = probs.clone(); ai_probs[24] = 0
+    ai_probs = probs.clone(); ai_probs[24] = 0   # null out human idx
     ai_score = ai_probs.sum().item() * 100
     human_score = 100 - ai_score
     top3 = torch.topk(ai_probs, 3).indices.tolist()
-    …
+    top3_names = [LABELS[i] for i in top3]
+    return human_score, ai_score, top3_names
 
 # ── Inference + explanation ──────────────────────────────────────────────
 @spaces.GPU
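The ensemble arithmetic in infer() deserves a line of explanation: each per-model softmax row sums to 1, so the averaged row does too, which means zeroing index 24 (human) gives ai_score = (1 - P(human)) * 100 and human_score = P(human) * 100 exactly. A toy illustration with invented numbers and three classes, class 2 playing the human role:

import torch

p1 = torch.tensor([[0.5, 0.3, 0.2]])   # softmax row from model 1 (made up)
p2 = torch.tensor([[0.6, 0.2, 0.2]])   # model 2
p3 = torch.tensor([[0.4, 0.3, 0.3]])   # model 3

probs = torch.stack([p1, p2, p3]).mean(dim=0)[0]   # tensor([0.5000, 0.2667, 0.2333])

ai_probs = probs.clone(); ai_probs[2] = 0          # zero out the "human" class
ai_score = ai_probs.sum().item() * 100             # 76.67 == (1 - 0.2333) * 100
human_score = 100 - ai_score                       # 23.33 == probs[2] * 100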
@@ -126,9 +125,9 @@ def analyse(text: str):
         h_tot += h; ai_tot += ai
         tooltip = (f"AI {ai:.2f}% • Top-3: {', '.join(top3)}"
                    if ai > h else f"Human {h:.2f}%")
-        cls …
+        cls = "ai-line" if ai > h else "human-line"
         span = (f"<span class='{cls} prob-tooltip' title='{tooltip}'>"
-                f"{…
+                f"{html.escape(ln)}</span>")     # → use html.escape
         highlighted.append(span)
 
     verdict = (f"<p><strong>Overall verdict:</strong> "
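The switch to html.escape is the right call here: each analysed line is interpolated into a span, so any markup in the input would otherwise be rendered (or executed) by the browser. For example:

import html

print(html.escape("<script>alert('x')</script> & co."))
# &lt;script&gt;alert(&#x27;x&#x27;)&lt;/script&gt; &amp; co.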
@@ -140,7 +139,7 @@ def analyse(text: str):
                f"AI-generated {ai_tot/n:.2f}%</span>")
     return verdict + "<hr>" + "<br>".join(highlighted)
 
-# ── …
+# ── Interface ────────────────────────────────────────────────────────────
 with gr.Blocks(css=CSS, title="Orify Text Detector") as demo:
     gr.Markdown("""
     ### Orify Text Detector
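The diff ends before the UI wiring; a minimal sketch of how analyse() presumably plugs into the Blocks layout follows (component names and the button event are assumptions, not the commit's code):

import gradio as gr

with gr.Blocks(css=CSS, title="Orify Text Detector") as demo:
    gr.Markdown("### Orify Text Detector")
    inp = gr.Textbox(lines=8, label="Paste text to analyse")  # hypothetical input box
    out = gr.HTML()                                           # verdict + highlighted lines
    gr.Button("Analyse").click(analyse, inputs=inp, outputs=out)

demo.launch()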