titipata commited on
Commit
51fd458
·
verified ·
1 Parent(s): ae94582

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -19
app.py CHANGED
@@ -54,31 +54,25 @@ def predict(img):
54
  {label: confidence, label: confidence, ...}
55
  """
56
 
57
- try:
58
- if img is None or not isinstance(img, dict) or 'image' not in img:
59
- return {"Error": 1.0}
60
-
61
- img_data = img['composite']
62
- img_gray = Image.fromarray(img_data).convert('L').resize((28, 28))
63
- img_tensor = transforms.ToTensor()(img_gray).unsqueeze(0)
64
-
65
- # Make prediction
66
- with torch.no_grad():
67
- probs = model(img_tensor).softmax(dim=1).squeeze()
68
-
69
- probs, indices = torch.topk(probs, 5) # select top 5
70
- probs, indices = probs.tolist(), indices.tolist() # transform to list
71
- return {LABELS[i]: float(v) for i, v in zip(indices, probs)}
72
- except Exception as e:
73
- print(f"Error in prediction: {str(e)}")
74
- return {"Error": 1.0}
75
 
76
 
77
  demo = gr.Interface(
78
  fn=predict,
79
  inputs=gr.Sketchpad(
80
  label="Draw Here",
81
- brush=gr.Brush(default_size=20, default_color="#FFFFFF", colors=["#FFFFFF"]),
82
  image_mode="L",
83
  layers=False
84
  ),
 
54
  {label: confidence, label: confidence, ...}
55
  """
56
 
57
+ img_data = img['composite']
58
+ print(img_data.sum())
59
+ img_gray = Image.fromarray(img_data).convert('L').resize((28, 28))
60
+ img_tensor = transforms.ToTensor()(img_gray).unsqueeze(0)
61
+
62
+ # Make prediction
63
+ with torch.no_grad():
64
+ probs = model(img_tensor).softmax(dim=1).squeeze()
65
+
66
+ probs, indices = torch.topk(probs, 5) # select top 5
67
+ probs, indices = probs.tolist(), indices.tolist() # transform to list
68
+ return {LABELS[i]: float(v) for i, v in zip(indices, probs)}
 
 
 
 
 
 
69
 
70
 
71
  demo = gr.Interface(
72
  fn=predict,
73
  inputs=gr.Sketchpad(
74
  label="Draw Here",
75
+ brush=gr.Brush(default_size=20, default_color="#000000", colors=["#000000"]),
76
  image_mode="L",
77
  layers=False
78
  ),