# Hugging Face Space (status: Sleeping)
import os
import base64

import cv2
import gradio as gr
import numpy as np
from tensorflow import keras

# Load the model.
# Resolution order for the model file:
#   1. the SKETCH2DRAW_MODEL_PATH environment variable, if set;
#   2. "sketch2draw_model.h5" relative to the current working directory;
#   3. "sketch2draw_model.h5" next to this script, so the app still works when
#      launched from a different working directory.
model_path = os.environ.get("SKETCH2DRAW_MODEL_PATH")
if model_path is None:
    model_path = "sketch2draw_model.h5"
    if not os.path.exists(model_path):
        model_path = os.path.join(
            os.path.dirname(os.path.abspath(__file__)), model_path
        )
model = keras.models.load_model(model_path)

# Class names for predictions. predict() maps the argmax of the model output to
# this list by index, so the order must match the model's output units.
class_names = ['grass', 'dirt', 'wood', 'water', 'sky', 'clouds']
def _to_bgr_image(image):
    """Convert a Sketchpad payload into a 3-channel uint8 image (None if empty).

    Accepted payloads:
      * ``str``  -- base64 image bytes, optionally prefixed with a data-URL
        header such as ``"data:image/png;base64,"``; decoded with OpenCV, which
        yields BGR channel order.
      * ``dict`` -- editor-style values; the ``"composite"`` entry is used, or
        ``"image"`` when there is no composite (key names differ between
        Gradio versions -- NOTE(review): confirm against the installed version).
      * anything ``np.asarray`` accepts (NumPy array, PIL image), assumed to be
        grayscale, RGB or RGBA as Gradio produces.

    Raises:
        ValueError: If the string payload cannot be decoded or the array has an
            unsupported shape.
    """
    if isinstance(image, dict):
        payload = image
        image = payload.get("composite")
        if image is None:
            image = payload.get("image")
    if image is None or (isinstance(image, str) and not image.strip()):
        return None

    if isinstance(image, str):
        # Keep only the part after an optional "data:...;base64," header.
        encoded = image.split(",", 1)[-1]
        buffer = np.frombuffer(base64.b64decode(encoded), np.uint8)
        decoded = cv2.imdecode(buffer, cv2.IMREAD_COLOR)
        if decoded is None:  # imdecode signals failure by returning None
            raise ValueError("Could not decode the submitted image data.")
        return decoded

    pixels = np.asarray(image)
    if pixels.ndim == 3 and pixels.shape[2] == 1:
        pixels = pixels[:, :, 0]
    # Array payloads are converted to BGR so they match what the cv2.imdecode
    # path above produces -- assumes the model was trained on cv2-loaded (BGR)
    # images; TODO confirm against the training pipeline.
    if pixels.ndim == 2:
        return cv2.cvtColor(pixels, cv2.COLOR_GRAY2BGR)
    if pixels.ndim == 3 and pixels.shape[2] == 4:
        # Alpha is dropped, as cv2.imdecode(..., IMREAD_COLOR) does.
        return cv2.cvtColor(pixels, cv2.COLOR_RGBA2BGR)
    if pixels.ndim == 3 and pixels.shape[2] == 3:
        return cv2.cvtColor(pixels, cv2.COLOR_RGB2BGR)
    raise ValueError(f"Unsupported image shape: {pixels.shape}")


def predict(image):
    """Classify a sketch into one of the labels in ``class_names``.

    Args:
        image: The Sketchpad value -- a base64/data-URL string, an image array,
            or an editor-style dict (see ``_to_bgr_image``).

    Returns:
        The predicted class name, or ``""`` when nothing was drawn.

    Raises:
        ValueError: If the payload cannot be turned into an image.
    """
    bgr = _to_bgr_image(image)
    if bgr is None:
        return ""
    # Resize to the model's input size (128x128 per the original code) and
    # scale pixel values to [0, 1]; then add a batch axis of size 1.
    resized = cv2.resize(bgr, (128, 128))
    batch = np.expand_dims(resized.astype(np.float32) / 255.0, axis=0)
    predictions = model.predict(batch, verbose=0)
    # Take the argmax of the single sample's scores, not of the whole batch.
    return class_names[int(np.argmax(predictions[0]))]
def clear_canvas():
    """Reset the Sketchpad to its empty state.

    Returning ``None`` empties a Gradio image component; returning a solid
    array would instead replace the drawing with that (black) image.
    """
    return None


# Create Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("<h1>Sketch to Draw Model</h1>")
    with gr.Row():
        with gr.Column():
            # NOTE(review): `shape` and `tool` are Gradio 3.x Sketchpad/Image
            # arguments. The documented `tool` values there are "editor",
            # "select", "sketch" and "color-sketch"; "brush" is not one of them.
            canvas = gr.Sketchpad(shape=(400, 400), label="Draw Here", tool="sketch")
            # NOTE(review): this picker is not connected to the canvas, so it
            # has no effect on the brush as written. Colour inputs expect hex.
            brush_color = gr.ColorPicker(value="#000000", label="Brush Color")
            clear_btn = gr.Button("Clear")
        with gr.Column():
            predict_btn = gr.Button("Predict")
            output_label = gr.Textbox(label="Predicted Texture")
    # Event listeners are registered inside the Blocks context.
    clear_btn.click(fn=clear_canvas, inputs=None, outputs=canvas)
    predict_btn.click(fn=predict, inputs=canvas, outputs=output_label)

# Launch the Gradio app only when executed as a script, not on import.
if __name__ == "__main__":
    demo.launch()