LilithHu commited on
Commit
fde1d98
·
verified ·
1 Parent(s): 1e3779f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -5
app.py CHANGED
@@ -21,11 +21,15 @@ labels = ["Non-manipulative / 非操纵性", "Manipulative / 操纵性"]
21
  def classify(text):
22
  inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True).to(device)
23
  with torch.no_grad():
24
- outputs = model(**inputs)
25
- probs = torch.softmax(outputs.logits, dim=1)
26
- pred = torch.argmax(probs, dim=1).item()
27
- confidence = probs[0][pred].item()
28
- return f"🧠 预测 / Prediction: {labels[pred]}\n🔢 置信度 / Confidence: {confidence*100:.2f}%"
 
 
 
 
29
 
30
  # Gradio 界面
31
  interface = gr.Interface(
 
21
  def classify(text):
22
  inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True).to(device)
23
  with torch.no_grad():
24
+ outputs = model(**inputs)
25
+ probs = torch.softmax(outputs.logits, dim=1)[0] # 取第一个样本的概率向量
26
+ probs = torch.clamp(probs, max=0.95) # 限制最大置信度为 95%
27
+ result = "🧠 预测 / Prediction:\n"
28
+ for i, label in enumerate(labels):
29
+ percent = round(probs[i].item() * 100, 2)
30
+ result += f"{label}: {percent}%\n"
31
+ return result
32
+
33
 
34
  # Gradio 界面
35
  interface = gr.Interface(