benjaminStreltzin commited on
Commit
f5dd171
·
verified ·
1 Parent(s): 074b0e2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -16
app.py CHANGED
@@ -24,21 +24,25 @@ def report_feedback():
24
  return "Feedback recorded. Thank you!"
25
  return "No prediction available to report."
26
 
27
- # Define the Gradio interface for prediction
28
- with gr.Blocks() as demo:
29
- gr.Markdown("### Vision Transformer Model")
30
- gr.Markdown("Upload an image to classify it using the Vision Transformer model.")
31
-
32
- image_input = gr.Image(type="pil", label="Upload Image")
33
- prediction_output = gr.Textbox(label="Prediction")
34
- confidence_output = gr.Textbox(label="Confidence")
35
-
36
- submit_btn = gr.Button("Submit")
37
- feedback_btn = gr.Button("The model was wrong")
38
-
39
- submit_btn.click(predict, inputs=image_input, outputs=[prediction_output, confidence_output])
40
- feedback_btn.click(report_feedback)
 
 
 
 
 
41
 
42
- # Launch the Gradio interface
43
  if __name__ == "__main__":
44
- demo.launch(share=True)
 
24
  return "Feedback recorded. Thank you!"
25
  return "No prediction available to report."
26
 
27
+ # Define the Gradio interface for prediction and feedback
28
+ def main():
29
+ with gr.Blocks() as demo:
30
+ gr.Markdown("### Vision Transformer Model")
31
+ gr.Markdown("Upload an image to classify it using the Vision Transformer model.")
32
+
33
+ image_input = gr.Image(type="pil", label="Upload Image")
34
+ prediction_output = gr.Textbox(label="Prediction", interactive=False)
35
+ confidence_output = gr.Textbox(label="Confidence", interactive=False)
36
+ feedback_output = gr.Textbox(label="Feedback Status", interactive=False)
37
+
38
+ submit_btn = gr.Button("Submit")
39
+ feedback_btn = gr.Button("The model was wrong")
40
+
41
+ submit_btn.click(predict, inputs=image_input, outputs=[prediction_output, confidence_output])
42
+ feedback_btn.click(report_feedback, outputs=feedback_output)
43
+
44
+ # Launch the Gradio interface
45
+ demo.launch(share=True)
46
 
 
47
  if __name__ == "__main__":
48
+ main()