File size: 1,959 Bytes
cabc13e
 
 
5223af2
 
cabc13e
3b9f116
cabc13e
3b9f116
 
cabc13e
 
3b9f116
cabc13e
 
3b9f116
28f1a2b
 
a12dbb7
cabc13e
3b9f116
cabc13e
3b9f116
 
 
 
 
 
 
 
 
 
9646a7d
 
 
3b9f116
9646a7d
3b9f116
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fa4d0f7
 
 
 
 
 
b16ee30
fa4d0f7
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import logging

import gradio as gr
import numpy as np
import pandas as pd
import tensorflow as tf
from keras.models import load_model

# Sketch categories, listed in the order of the model's output units:
# predict_fn maps the arg-max output index straight back into this list,
# so reordering it would mislabel every prediction.
classes = [
    "car",
    "house",
    "wine bottle",
    "chair",
    "table",
    "tree",
    "camera",
    "fish",
    "rain",
    "clock",
    "hat",
]

# Reverse lookup: category name -> output index.
labels = dict(zip(classes, range(len(classes))))
num_classes = len(classes)

# Load the trained Keras model once at startup so every request reuses it.
# The path is relative to the current working directory, and the file name
# (including its "recogination" spelling) must match the file shipped with
# the app.
model = load_model(filepath='sketch_recogination_model_cnn.h5')

# Predict function for the interface
def predict_fn(image):
    """
    Predict the class of a drawn image.

    Args:
        image: The value sent by the Gradio input component. Either a pixel
            array shaped (H, W), (H, W, 1), (H, W, 3) or (H, W, 4) with values
            in 0-255, or a dict wrapping such an array: Gradio 3 sketch tools
            use the ``"image"`` key, Gradio 4 image editors use
            ``"composite"``. ``None`` means nothing was submitted.

    Returns:
        The predicted class name from ``classes``. If there is no image, or
        the canvas is a single uniform colour, a prompt asking the user to
        draw is returned instead. If preprocessing or inference fails, an
        ``"Error in prediction: ..."`` message is returned.
    """
    try:
        pixels = _extract_pixels(image)
        if pixels is None:
            return "Please draw something first."

        image_array = _to_model_channels(pixels)
        # A uniform canvas has no strokes; predicting on it would be noise.
        if image_array.size == 0 or image_array.min() == image_array.max():
            return "Please draw something first."

        # Preprocessing: resize to 28x28, collapse RGB to one grey channel,
        # then scale to [0, 1].
        resized_image = tf.image.resize(image_array, (28, 28))
        if resized_image.shape[-1] == 3:
            resized_image = tf.image.rgb_to_grayscale(resized_image)
        # NOTE(review): canvas colours are fed to the model as-is (no
        # inversion) - confirm this matches the stroke/background polarity
        # the model was trained on.
        model_input = np.asarray(resized_image).reshape(1, 28, 28, 1) / 255.0

        # verbose=0 stops Keras printing a progress bar on every request.
        predictions = model.predict(model_input, verbose=0).reshape(num_classes)
        return classes[int(np.argmax(predictions))]
    except Exception as e:
        # UI boundary: report the failure in the label rather than crash the
        # app, but keep the full traceback in the server log for debugging.
        logging.exception("Prediction failed")
        return f"Error in prediction: {str(e)}"


def _extract_pixels(image):
    """Return the pixel array carried by a Gradio input value, or None."""
    if isinstance(image, dict):
        # Gradio 3 sketch tools send {"image", "mask"}; Gradio 4 image
        # editors send {"background", "layers", "composite"}.
        for key in ("image", "composite"):
            if image.get(key) is not None:
                return image[key]
        return None
    return image


def _to_model_channels(pixels):
    """Return *pixels* as a float32 array shaped (H, W, 1) or (H, W, 3).

    A 2-D greyscale image gains a trailing channel axis. RGBA input is
    alpha-composited onto white (255), so that fully transparent areas read
    as blank canvas instead of whatever RGB they happen to hold.

    Raises:
        ValueError: if the array is not 2-D/3-D or has an unsupported number
            of channels.
    """
    array = np.asarray(pixels, dtype=np.float32)
    if array.ndim == 2:
        array = array[..., np.newaxis]
    if array.ndim != 3 or array.shape[-1] not in (1, 3, 4):
        raise ValueError(f"unsupported image shape {array.shape}")
    if array.shape[-1] == 4:
        alpha = array[..., 3:] / 255.0
        array = array[..., :3] * alpha + 255.0 * (1.0 - alpha)
    return array

# Gradio application interface. The category list in the description is
# built from `classes`, so the UI text always matches the labels predict_fn
# can return ("Car, House, Wine bottle, ...").
demo = gr.Interface(
    fn=predict_fn,
    inputs="paint",
    outputs="label",
    title="DoodleDecoder",
    description="Draw something from: "
    + ", ".join(name.capitalize() for name in classes),
    # NOTE(review): `interpretation` is a Gradio 3.x feature that was removed
    # in Gradio 4 - drop this argument when upgrading.
    interpretation='default',
    article="Draw large with thick stroke.",
)

# Launch only when run as a script, so importing this module (tests, Gradio
# reload mode, which looks for `demo`) does not start a server.
if __name__ == "__main__":
    demo.launch()