LilithHu commited on
Commit
1e3779f
·
verified ·
1 Parent(s): 7635550

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -4
app.py CHANGED
@@ -2,22 +2,30 @@ import gradio as gr
2
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
3
  import torch
4
 
5
- # 加载模型
6
  model_name = "LilithHu/mbert-manipulative-detector"
7
  tokenizer = AutoTokenizer.from_pretrained(model_name)
8
  model = AutoModelForSequenceClassification.from_pretrained(model_name)
9
 
10
- # 二分类标签(非操纵性是0,操纵性是1)
 
 
 
 
 
 
 
11
  labels = ["Non-manipulative / 非操纵性", "Manipulative / 操纵性"]
12
 
 
13
  def classify(text):
14
- inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
15
  with torch.no_grad():
16
  outputs = model(**inputs)
17
  probs = torch.softmax(outputs.logits, dim=1)
18
  pred = torch.argmax(probs, dim=1).item()
19
  confidence = probs[0][pred].item()
20
- return f"🧠 预测 / Prediction: {labels[pred]}\n"
21
 
22
  # Gradio 界面
23
  interface = gr.Interface(
 
2
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
3
  import torch
4
 
5
+ # 加载模型和 tokenizer
6
  model_name = "LilithHu/mbert-manipulative-detector"
7
  tokenizer = AutoTokenizer.from_pretrained(model_name)
8
  model = AutoModelForSequenceClassification.from_pretrained(model_name)
9
 
10
+ # 设置为评估模式
11
+ model.eval()
12
+
13
+ # 设置运行设备
14
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
+ model.to(device)
16
+
17
+ # 标签名
18
  labels = ["Non-manipulative / 非操纵性", "Manipulative / 操纵性"]
19
 
20
+ # 推理函数
21
  def classify(text):
22
+ inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True).to(device)
23
  with torch.no_grad():
24
  outputs = model(**inputs)
25
  probs = torch.softmax(outputs.logits, dim=1)
26
  pred = torch.argmax(probs, dim=1).item()
27
  confidence = probs[0][pred].item()
28
+ return f"🧠 预测 / Prediction: {labels[pred]}\n🔢 置信度 / Confidence: {confidence*100:.2f}%"
29
 
30
  # Gradio 界面
31
  interface = gr.Interface(