benjaminStreltzin commited on
Commit
991fbeb
·
verified ·
1 Parent(s): 83b587f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -12
app.py CHANGED
@@ -1,26 +1,43 @@
1
  import gradio as gr
2
  from PIL import Image
3
- from vit_model_test import CustomModel
4
 
5
  # Initialize the model
6
- model = CustomModel()
 
 
 
7
 
8
  def predict(image: Image.Image):
9
- label, confidence = model.predict(image)
10
- result = "AI image" if label == 1 else "Real image"
11
- return result, f"Confidence: {confidence:.2f}%"
 
 
12
 
 
 
 
 
 
 
 
13
 
14
- # Define the Gradio interface
15
  demo = gr.Interface(
16
- fn=predict,
17
- inputs=gr.Image(type="pil"),
18
- outputs=[gr.Textbox(), gr.Textbox()],
19
- title="Vision Transformer Model",
20
- description="Upload an image to classify it using the Vision Transformer model.",
21
- theme = gr.themes.Soft()
22
  )
23
 
 
 
 
 
24
  # Launch the Gradio interface
25
  if __name__ == "__main__":
26
  demo.launch(share=True)
 
 
1
  import gradio as gr
2
  from PIL import Image
3
+ from vit_model_test import Custom_VIT_Model # Ensure you import the correct class
4
 
5
  # Initialize the model
6
+ model = Custom_VIT_Model()
7
+
8
+ # Variable to store the last prediction result
9
+ last_prediction = None
10
 
11
  def predict(image: Image.Image):
12
+ global last_prediction
13
+ label, confidence = model.predict(image)
14
+ result = "AI image" if label == 1 else "Real image"
15
+ last_prediction = (image, label) # Store the image and prediction label
16
+ return result, f"Confidence: {confidence:.2f}%"
17
 
18
+ def report_feedback():
19
+ if last_prediction is not None:
20
+ image, predicted_label = last_prediction
21
+ correct_label = 1 if predicted_label == 0 else 0 # Invert the label
22
+ model.add_data(image, correct_label) # Add incorrect prediction to model
23
+ return "Feedback recorded. Thank you!"
24
+ return "No prediction available to report."
25
 
26
+ # Define the Gradio interface for prediction
27
  demo = gr.Interface(
28
+ fn=predict,
29
+ inputs=gr.Image(type="pil"),
30
+ outputs=[gr.Textbox(label="Prediction"), gr.Textbox(label="Confidence")],
31
+ title="Vision Transformer Model",
32
+ description="Upload an image to classify it using the Vision Transformer model.",
33
+ theme=gr.themes.Soft()
34
  )
35
 
36
+ # Define the feedback button
37
+ feedback_button = gr.Button("The model was wrong")
38
+ feedback_button.click(report_feedback)
39
+
40
  # Launch the Gradio interface
41
  if __name__ == "__main__":
42
  demo.launch(share=True)
43
+ feedback_button.launch()