LilithHu commited on
Commit
2cb5c95
·
verified ·
1 Parent(s): 3662b93

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -6
app.py CHANGED
@@ -22,15 +22,15 @@ def classify(text):
22
  inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True).to(device)
23
  with torch.no_grad():
24
  outputs = model(**inputs)
25
- probs = torch.softmax(outputs.logits, dim=1)[0] # 取第一个样本的概率向量
26
- probs = torch.clamp(probs, max=0.95) # 限制最大置信度为 95%
27
- result = "🧠 预测 / Prediction:\n"
28
- for i, label in enumerate(labels):
29
- percent = round(probs[i].item() * 100, 2)
30
- result += f"{label}: {percent}%\n"
31
  return result
32
 
33
 
 
34
  # Gradio 界面
35
  interface = gr.Interface(
36
  fn=classify,
 
22
  inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True).to(device)
23
  with torch.no_grad():
24
  outputs = model(**inputs)
25
+ probs = torch.softmax(outputs.logits, dim=1)[0]
26
+ pred = torch.argmax(probs).item()
27
+ confidence = min(probs[pred].item(), 0.95) # 限制置信度最大为95%
28
+ percent = round(confidence * 100, 2)
29
+ result = f"🧠 预测 / Prediction:\n{labels[pred]}: {percent}%"
 
30
  return result
31
 
32
 
33
+
34
  # Gradio 界面
35
  interface = gr.Interface(
36
  fn=classify,