szili2011 commited on
Commit
ce971f1
·
verified ·
1 Parent(s): f77bf53

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -53
app.py CHANGED
@@ -1,69 +1,53 @@
1
  import os
2
  import numpy as np
3
  import cv2
4
- from fastapi import FastAPI
5
- from fastapi.responses import HTMLResponse
6
- from pydantic import BaseModel
7
  from tensorflow import keras
8
- from starlette.responses import FileResponse
9
- from starlette.middleware.cors import CORSMiddleware
10
-
11
- # Define the FastAPI app
12
- app = FastAPI()
13
-
14
- # Add CORS middleware
15
- app.add_middleware(
16
- CORSMiddleware,
17
- allow_origins=["*"],
18
- allow_credentials=True,
19
- allow_methods=["*"],
20
- allow_headers=["*"],
21
- )
22
 
23
  # Load the model
24
  model_path = 'sketch2draw_model.h5' # Update with your model path
25
  model = keras.models.load_model(model_path)
26
 
27
- # Define the request body
28
- class TextureRequest(BaseModel):
29
- texture: str
30
-
31
  # Load class names for predictions
32
  class_names = ['grass', 'dirt', 'wood', 'water', 'sky', 'clouds']
33
 
34
- @app.get("/", response_class=HTMLResponse)
35
- async def read_root():
36
- return """
37
- <html>
38
- <head>
39
- <title>Sketch to Draw</title>
40
- </head>
41
- <body>
42
- <h1>Sketch to Draw Model</h1>
43
- <form action="/predict" method="post">
44
- <input type="text" name="texture" placeholder="Enter texture name (grass, dirt, wood, water, sky, clouds)">
45
- <button type="submit">Predict</button>
46
- </form>
47
- </body>
48
- </html>
49
- """
50
 
51
- @app.post("/predict")
52
- async def predict_texture(request: TextureRequest):
53
- texture_name = request.texture
54
-
55
- # Process the input texture (you can modify this part)
56
- # Example: Load image and preprocess it
57
- # image = cv2.imread(f'path_to_your_texture_images/{texture_name}.png')
58
- # image = cv2.resize(image, (128, 128)) # Resize as per your model's input
59
- # image = np.expand_dims(image, axis=0) / 255.0 # Normalize if needed
60
 
61
  # Make prediction
62
- predictions = model.predict(image) # Add your processed image here
63
  predicted_class = class_names[np.argmax(predictions)]
64
-
65
- return {"predicted_texture": predicted_class}
66
-
67
- if __name__ == "__main__":
68
- import uvicorn
69
- uvicorn.run(app, host="0.0.0.0", port=8080)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
  import numpy as np
3
  import cv2
4
+ import base64
5
+ import gradio as gr
 
6
  from tensorflow import keras
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
  # Load the model
9
  model_path = 'sketch2draw_model.h5' # Update with your model path
10
  model = keras.models.load_model(model_path)
11
 
 
 
 
 
12
  # Load class names for predictions
13
  class_names = ['grass', 'dirt', 'wood', 'water', 'sky', 'clouds']
14
 
15
+ def predict(image):
16
+ # Decode the image from base64
17
+ image_data = np.frombuffer(base64.b64decode(image.split(",")[1]), np.uint8)
18
+ image = cv2.imdecode(image_data, cv2.IMREAD_COLOR)
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
+ # Resize and normalize the image
21
+ image = cv2.resize(image, (128, 128)) # Resize as per your model's input
22
+ image = np.expand_dims(image, axis=0) / 255.0 # Normalize if needed
 
 
 
 
 
 
23
 
24
  # Make prediction
25
+ predictions = model.predict(image)
26
  predicted_class = class_names[np.argmax(predictions)]
27
+
28
+ return predicted_class
29
+
30
+ # Create Gradio interface
31
+ with gr.Blocks() as demo:
32
+ gr.Markdown("<h1>Sketch to Draw Model</h1>")
33
+
34
+ with gr.Row():
35
+ with gr.Column():
36
+ canvas = gr.Sketchpad(shape=(400, 400), label="Draw Here", tool="brush")
37
+ brush_color = gr.ColorPicker(value="black", label="Brush Color")
38
+ clear_btn = gr.Button("Clear")
39
+
40
+ with gr.Column():
41
+ predict_btn = gr.Button("Predict")
42
+ output_label = gr.Textbox(label="Predicted Texture")
43
+
44
+ # Define the actions for buttons
45
+ def clear_canvas():
46
+ return np.zeros((400, 400, 3), dtype=np.uint8)
47
+
48
+ clear_btn.click(fn=clear_canvas, inputs=None, outputs=canvas)
49
+
50
+ predict_btn.click(fn=predict, inputs=canvas, outputs=output_label)
51
+
52
+ # Launch the Gradio app
53
+ demo.launch()