Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -1,6 +1,5 @@
|
|
1 |
# app.py
|
2 |
|
3 |
-
from huggingface_hub import snapshot_download
|
4 |
from transformers import AutoTokenizer, AutoModelForMaskedLM
|
5 |
import torch
|
6 |
import gradio as gr
|
@@ -9,9 +8,11 @@ from dataclasses import dataclass
|
|
9 |
from pathlib import Path
|
10 |
import spaces
|
11 |
|
|
|
12 |
@dataclass
|
13 |
class SymbolicConfig:
|
14 |
repo_id: str = "AbstractPhil/bert-beatrix-2048"
|
|
|
15 |
symbolic_roles: list = (
|
16 |
"<subject>", "<subject1>", "<subject2>", "<pose>", "<emotion>",
|
17 |
"<surface>", "<lighting>", "<material>", "<accessory>", "<footwear>",
|
@@ -22,14 +23,15 @@ class SymbolicConfig:
|
|
22 |
)
|
23 |
|
24 |
config = SymbolicConfig()
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
|
|
|
|
29 |
|
30 |
MASK_TOKEN = tokenizer.mask_token or "[MASK]"
|
31 |
|
32 |
-
@spaces.GPU
|
33 |
def mask_and_predict(text: str, selected_roles: list[str]):
|
34 |
results = []
|
35 |
masked_text = text
|
|
|
1 |
# app.py
|
2 |
|
|
|
3 |
from transformers import AutoTokenizer, AutoModelForMaskedLM
|
4 |
import torch
|
5 |
import gradio as gr
|
|
|
8 |
from pathlib import Path
|
9 |
import spaces
|
10 |
|
11 |
+
@spaces.GPU
|
12 |
@dataclass
|
13 |
class SymbolicConfig:
|
14 |
repo_id: str = "AbstractPhil/bert-beatrix-2048"
|
15 |
+
revision: str = "main"
|
16 |
symbolic_roles: list = (
|
17 |
"<subject>", "<subject1>", "<subject2>", "<pose>", "<emotion>",
|
18 |
"<surface>", "<lighting>", "<material>", "<accessory>", "<footwear>",
|
|
|
23 |
)
|
24 |
|
25 |
config = SymbolicConfig()
|
26 |
+
tokenizer = AutoTokenizer.from_pretrained(config.repo_id, revision=config.revision)
|
27 |
+
model = AutoModelForMaskedLM.from_pretrained(
|
28 |
+
config.repo_id,
|
29 |
+
revision=config.revision,
|
30 |
+
trust_remote_code=True
|
31 |
+
).eval().cuda()
|
32 |
|
33 |
MASK_TOKEN = tokenizer.mask_token or "[MASK]"
|
34 |
|
|
|
35 |
def mask_and_predict(text: str, selected_roles: list[str]):
|
36 |
results = []
|
37 |
masked_text = text
|