Sleepyriizi committed
Commit 2a089a8 · verified · 1 Parent(s): decced4

Update app.py

Files changed (1):
  1. app.py +35 -34
app.py CHANGED
@@ -3,51 +3,53 @@
 
 • Three ModernBERT-base checkpoints (soft-vote)
 • Per-line colour coding, probability tool-tips, top-3 AI model hints
-• Everything fetched automatically from the weight repo and cached
+• Weights auto-downloaded once from the model repo and cached
 """
 
 # ── Imports ──────────────────────────────────────────────────────────────
 from pathlib import Path
-import re, torch, gradio as gr
+import re, os, torch, gradio as gr
 from transformers import AutoTokenizer, AutoModelForSequenceClassification
 from huggingface_hub import hf_hub_download
 import spaces
-import os, types  # add `types`
+import typing  # ← fix: use typing.Any
 
-# ────────────────── robust torch.compile shim ─────────────────────────
+# ────────────────── robust torch.compile shim ──────────────────────────
 if hasattr(torch, "compile"):
-    def _no_compile(model: types.Any = None, *args, **kwargs):
+    def _no_compile(model: typing.Any = None, *args, **kwargs):
         """
-        1. If called as torch.compile(model, …) → just return the model.
-        2. If called as torch.compile(**kw) → return a decorator that
-           immediately gives back the class / fn it decorates.
+        1. torch.compile(model, …) → return the model unchanged
+        2. torch.compile(**kw)  (decorator) → return a decorator that
+           immediately gives back the class/function it decorates
         """
-        if callable(model):  # pattern 1
+        if callable(model):  # pattern 1
             return model
-        # pattern 2 (used by ModernBERT via @torch.compile(...))
-        def decorator(fn):
+
+        def decorator(fn):  # pattern 2
             return fn
         return decorator
 
-    torch.compile = _no_compile  # monkey-patch
+    torch.compile = _no_compile
     os.environ["TORCHINDUCTOR_DISABLED"] = "1"
 
-# (everything below is unchanged)
-DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-WEIGHT_REPO = "Sleepyriizi/Orify-Text-Detection-Weights"
-FILE_MAP = {"ensamble_1":"ensamble_1",
-            "ensamble_2.bin":"ensamble_2.bin",
-            "ensamble_3":"ensamble_3"}
+# ── Configuration ────────────────────────────────────────────────────────
+DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+WEIGHT_REPO = "Sleepyriizi/Orify-Text-Detection-Weights"
+FILE_MAP = {
+    "ensamble_1": "ensamble_1",
+    "ensamble_2.bin": "ensamble_2.bin",
+    "ensamble_3": "ensamble_3",
+}
+
 BASE_MODEL_NAME = "answerdotai/ModernBERT-base"
 NUM_LABELS = 41
 
-LABELS = {  # id → friendly label (unchanged)
-    0: "13B", 1: "30B", 2: "65B", 3: "7B", 4: "GLM130B",
-    5: "bloom_7b", 6: "bloomz", 7: "cohere", 8: "davinci",
-    9: "dolly", 10: "dolly-v2-12b", 11: "flan_t5_base",
-    12: "flan_t5_large", 13: "flan_t5_small", 14: "flan_t5_xl",
-    15: "flan_t5_xxl", 16: "gemma-7b-it", 17: "gemma2-9b-it",
-    18: "gpt-3.5-turbo", 19: "gpt-35", 20: "gpt-4",
+LABELS = {  # id → friendly label
+    0: "13B", 1: "30B", 2: "65B", 3: "7B", 4: "GLM130B", 5: "bloom_7b",
+    6: "bloomz", 7: "cohere", 8: "davinci", 9: "dolly", 10: "dolly-v2-12b",
+    11: "flan_t5_base", 12: "flan_t5_large", 13: "flan_t5_small",
+    14: "flan_t5_xl", 15: "flan_t5_xxl", 16: "gemma-7b-it",
+    17: "gemma2-9b-it", 18: "gpt-3.5-turbo", 19: "gpt-35", 20: "gpt-4",
     21: "gpt-4o", 22: "gpt-j", 23: "gpt-neox", 24: "human",
     25: "llama3-70b", 26: "llama3-8b", 27: "mixtral-8x7b",
     28: "opt-1.3b", 29: "opt-125m", 30: "opt-13b",
@@ -56,7 +58,7 @@ LABELS = {  # id → friendly label (unchanged)
     37: "t0-11b", 38: "t0-3b", 39: "text-davinci-002", 40: "text-davinci-003"
 }
 
-# ── CSS (kept identical) ────────────────────────────────────────────────
+# ── CSS (unchanged) ──────────────────────────────────────────────────────
 CSS = Path(__file__).with_name("style.css").read_text() if Path(__file__).with_name("style.css").exists() else """
 :root{--clr-ai:#ff4d4f;--clr-human:#52c41a;--border:2px solid var(--clr-ai);--radius:10px}
 body{font-family:'Roboto Mono',monospace;margin:0 auto;max-width:900px;padding:32px}
@@ -75,7 +77,7 @@ print("🧩 Loading tokenizer & models …")
 tokeniser = AutoTokenizer.from_pretrained(BASE_MODEL_NAME)
 
 models = []
-for alias, path in local_paths.items():
+for _, path in local_paths.items():
     net = AutoModelForSequenceClassification.from_pretrained(
         BASE_MODEL_NAME, num_labels=NUM_LABELS)
     net.load_state_dict(torch.load(path, map_location=DEVICE))
@@ -92,20 +94,19 @@ def tidy(txt: str) -> str:
     return txt.strip()
 
 def infer(segment: str):
-    """Return (human%, ai%, [top-3 ai model names])."""
-    inputs = tokeniser(segment, return_tensors="pt", truncation=True,
-                       padding=True).to(DEVICE)
+    """Return (human%, ai%, list of top-3 AI model names)."""
+    inputs = tokeniser(segment, return_tensors="pt",
+                       truncation=True, padding=True).to(DEVICE)
     with torch.no_grad():
         probs = torch.stack([
             torch.softmax(m(**inputs).logits, dim=1) for m in models
         ]).mean(dim=0)[0]
 
-    ai_probs = probs.clone(); ai_probs[24] = 0  # null out human idx
+    ai_probs = probs.clone(); ai_probs[24] = 0
     ai_score = ai_probs.sum().item() * 100
     human_score = 100 - ai_score
     top3 = torch.topk(ai_probs, 3).indices.tolist()
-    top3_names = [LABELS[i] for i in top3]
-    return human_score, ai_score, top3_names
+    return human_score, ai_score, [LABELS[i] for i in top3]
 
 # ── Inference + explanation ──────────────────────────────────────────────
 @spaces.GPU
@@ -139,7 +140,7 @@ def analyse(text: str):
              f"AI-generated {ai_tot/n:.2f}%</span>")
     return verdict + "<hr>" + "<br>".join(highlighted)
 
-# ── Interface ────────────────────────────────────────────────────────────
+# ── Gradio interface ─────────────────────────────────────────────────────
 with gr.Blocks(css=CSS, title="Orify Text Detector") as demo:
     gr.Markdown("""
     ### Orify Text Detector
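The shim's whole contract is "hand back whatever you were given", so both call patterns it targets can be checked in isolation. In this minimal sketch the `_no_compile` body is copied from the commit; the surrounding harness is purely illustrative:

import typing

def _no_compile(model: typing.Any = None, *args, **kwargs):
    if callable(model):          # pattern 1: _no_compile(fn_or_model, ...)
        return model

    def decorator(fn):           # pattern 2: _no_compile(**kwargs) as decorator
        return fn
    return decorator

def f(x):
    return x + 1

assert _no_compile(f) is f                   # direct-call form returns f itself
assert _no_compile(dynamic=False)(f) is f    # decorator form also returns f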
 
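The loader loop iterates over `local_paths`, but its construction sits outside the changed hunks. A plausible shape, given FILE_MAP and the hf_hub_download import, is the dict comprehension below; this is an assumption for illustration, not code from the commit:

from huggingface_hub import hf_hub_download

WEIGHT_REPO = "Sleepyriizi/Orify-Text-Detection-Weights"
FILE_MAP = {"ensamble_1": "ensamble_1",
            "ensamble_2.bin": "ensamble_2.bin",
            "ensamble_3": "ensamble_3"}

# alias → local cached path; hf_hub_download fetches once, then reuses the cache
local_paths = {alias: hf_hub_download(repo_id=WEIGHT_REPO, filename=fname)
               for alias, fname in FILE_MAP.items()}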
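infer() is a three-way soft vote: each checkpoint's logits are softmaxed, the three distributions are averaged, and index 24 ("human") is zeroed so that the remaining probability mass becomes the AI score. A standalone illustration, with random logits standing in for real model outputs:

import torch

NUM_LABELS, HUMAN_IDX = 41, 24
logits = [torch.randn(1, NUM_LABELS) for _ in range(3)]  # one per checkpoint

# soft vote: average the per-model probability distributions
probs = torch.stack([torch.softmax(l, dim=1) for l in logits]).mean(dim=0)[0]

ai_probs = probs.clone()
ai_probs[HUMAN_IDX] = 0                   # remove the "human" probability mass
ai_score = ai_probs.sum().item() * 100    # total AI probability, as a percent
human_score = 100 - ai_score
top3 = torch.topk(ai_probs, 3).indices.tolist()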
 
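With the tokenizer and the three checkpoints loaded as in the diff, a call to the new infer() looks like this (hypothetical input and output formatting):

human, ai, top3 = infer("Paste any paragraph here to score it.")
print(f"human {human:.1f}% | AI {ai:.1f}% | likely AI models: {', '.join(top3)}")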