benjaminStreltzin commited on
Commit
b60cbe3
·
verified ·
1 Parent(s): fd4d84b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -24
app.py CHANGED
@@ -1,10 +1,10 @@
1
  import gradio as gr
2
  from PIL import Image
3
- from vit_model_test import CustomModel # Ensure you import the correct class
4
  from vit_Training import Custom_VIT_Model
5
 
6
  # Initialize the model
7
- model = Custom_VIT_Model()
8
 
9
  # Variable to store the last prediction result
10
  last_prediction = None
@@ -13,7 +13,7 @@ def predict(image: Image.Image):
13
  global last_prediction
14
  label, confidence = model.predict(image)
15
  result = "AI image" if label == 1 else "Real image"
16
- last_prediction = (image, label) # Store the image and prediction label
17
  return result, f"Confidence: {confidence:.2f}%"
18
 
19
  def report_feedback():
@@ -21,29 +21,26 @@ def report_feedback():
21
  if last_prediction is not None:
22
  image, predicted_label = last_prediction
23
  correct_label = 1 if predicted_label == 0 else 0 # Invert the label
24
- model.add_data(image, correct_label) # Add incorrect prediction to model
25
  return "Feedback recorded. Thank you!"
26
  return "No prediction available to report."
27
 
28
- # Define the Gradio interface for prediction and feedback
29
- def main():
30
- with gr.Blocks() as demo:
31
- gr.Markdown("### Vision Transformer Model")
32
- gr.Markdown("Upload an image to classify it using the Vision Transformer model.")
33
-
34
- image_input = gr.Image(type="pil", label="Upload Image")
35
- prediction_output = gr.Textbox(label="Prediction", interactive=False)
36
- confidence_output = gr.Textbox(label="Confidence", interactive=False)
37
- feedback_output = gr.Textbox(label="Feedback Status", interactive=False)
38
-
39
- submit_btn = gr.Button("Submit")
40
- feedback_btn = gr.Button("The model was wrong")
41
-
42
- submit_btn.click(predict, inputs=image_input, outputs=[prediction_output, confidence_output])
43
- feedback_btn.click(report_feedback, outputs=feedback_output)
44
-
45
- # Launch the Gradio interface
46
- demo.launch(share=True)
47
 
 
 
 
 
 
 
 
48
  if __name__ == "__main__":
49
- main()
 
1
  import gradio as gr
2
  from PIL import Image
3
+ from vit_model_test import CustomModel # Ensure this is the correct import for your model
4
  from vit_Training import Custom_VIT_Model
5
 
6
  # Initialize the model
7
+ model = CustomModel()
8
 
9
  # Variable to store the last prediction result
10
  last_prediction = None
 
13
  global last_prediction
14
  label, confidence = model.predict(image)
15
  result = "AI image" if label == 1 else "Real image"
16
+ last_prediction = (image, label) # Store the image and label for feedback
17
  return result, f"Confidence: {confidence:.2f}%"
18
 
19
  def report_feedback():
 
21
  if last_prediction is not None:
22
  image, predicted_label = last_prediction
23
  correct_label = 1 if predicted_label == 0 else 0 # Invert the label
24
+ model.add_data(image, correct_label) # Pass the incorrect prediction to the model
25
  return "Feedback recorded. Thank you!"
26
  return "No prediction available to report."
27
 
28
+ # Define the Gradio interface
29
+ with gr.Blocks() as demo:
30
+ gr.Markdown("### Vision Transformer Model")
31
+ gr.Markdown("Upload an image to classify it using the Vision Transformer model.")
32
+
33
+ image_input = gr.Image(type="pil", label="Upload Image")
34
+ prediction_output = gr.Textbox(label="Prediction", interactive=False)
35
+ confidence_output = gr.Textbox(label="Confidence", interactive=False)
36
+ feedback_output = gr.Textbox(label="Feedback Status", interactive=False)
 
 
 
 
 
 
 
 
 
 
37
 
38
+ submit_btn = gr.Button("Submit")
39
+ feedback_btn = gr.Button("The model was wrong")
40
+
41
+ submit_btn.click(predict, inputs=image_input, outputs=[prediction_output, confidence_output])
42
+ feedback_btn.click(report_feedback, outputs=feedback_output)
43
+
44
+ # Launch the Gradio interface
45
  if __name__ == "__main__":
46
+ demo.launch(share=True)