Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -2,22 +2,30 @@ import gradio as gr
|
|
2 |
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
3 |
import torch
|
4 |
|
5 |
-
#
|
6 |
model_name = "LilithHu/mbert-manipulative-detector"
|
7 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
8 |
model = AutoModelForSequenceClassification.from_pretrained(model_name)
|
9 |
|
10 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
labels = ["Non-manipulative / 非操纵性", "Manipulative / 操纵性"]
|
12 |
|
|
|
13 |
def classify(text):
|
14 |
-
inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
|
15 |
with torch.no_grad():
|
16 |
outputs = model(**inputs)
|
17 |
probs = torch.softmax(outputs.logits, dim=1)
|
18 |
pred = torch.argmax(probs, dim=1).item()
|
19 |
confidence = probs[0][pred].item()
|
20 |
-
return f"🧠 预测 / Prediction: {labels[pred]}\n"
|
21 |
|
22 |
# Gradio 界面
|
23 |
interface = gr.Interface(
|
|
|
2 |
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
3 |
import torch
|
4 |
|
5 |
+
# 加载模型和 tokenizer
|
6 |
model_name = "LilithHu/mbert-manipulative-detector"
|
7 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
8 |
model = AutoModelForSequenceClassification.from_pretrained(model_name)
|
9 |
|
10 |
+
# 设置为评估模式
|
11 |
+
model.eval()
|
12 |
+
|
13 |
+
# 设置运行设备
|
14 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
15 |
+
model.to(device)
|
16 |
+
|
17 |
+
# 标签名
|
18 |
labels = ["Non-manipulative / 非操纵性", "Manipulative / 操纵性"]
|
19 |
|
20 |
+
# 推理函数
|
21 |
def classify(text):
|
22 |
+
inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True).to(device)
|
23 |
with torch.no_grad():
|
24 |
outputs = model(**inputs)
|
25 |
probs = torch.softmax(outputs.logits, dim=1)
|
26 |
pred = torch.argmax(probs, dim=1).item()
|
27 |
confidence = probs[0][pred].item()
|
28 |
+
return f"🧠 预测 / Prediction: {labels[pred]}\n🔢 置信度 / Confidence: {confidence*100:.2f}%"
|
29 |
|
30 |
# Gradio 界面
|
31 |
interface = gr.Interface(
|