im2 commited on
Commit
d1888a8
·
1 Parent(s): d1d4583
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -33,7 +33,7 @@ model.eval()
33
  # Gradio preprocessing and prediction pipeline
34
  def predict_digit(image):
35
  # Preprocess the image: resize to 28x28, convert to grayscale, and normalize
36
- image = image.convert('L') # Convert to grayscale
37
  transform = transforms.Compose([
38
  transforms.Resize((28, 28)),
39
  transforms.ToTensor(),
@@ -52,7 +52,7 @@ def predict_digit(image):
52
  # Create Gradio Interface
53
  interface = gr.Interface(
54
  fn=predict_digit,
55
- inputs=gr.Image(source="canvas", tool="editor", type="pil"), # User can draw on a canvas
56
  outputs="text",
57
  title="Digit Recognizer",
58
  description="Draw a digit (0-9) and the model will predict the number!"
 
33
  # Gradio preprocessing and prediction pipeline
34
  def predict_digit(image):
35
  # Preprocess the image: resize to 28x28, convert to grayscale, and normalize
36
+ image = Image.fromarray(image).convert('L') # Convert to grayscale
37
  transform = transforms.Compose([
38
  transforms.Resize((28, 28)),
39
  transforms.ToTensor(),
 
52
  # Create Gradio Interface
53
  interface = gr.Interface(
54
  fn=predict_digit,
55
+ inputs=gr.Sketchpad(shape=(28, 28)), # Sketchpad for users to draw
56
  outputs="text",
57
  title="Digit Recognizer",
58
  description="Draw a digit (0-9) and the model will predict the number!"