File size: 3,380 Bytes
0ceff0a
 
b60cbe3
fd4d84b
0ceff0a
 
b60cbe3
991fbeb
7d87e8c
 
d2cbc08
 
 
 
 
 
d3a4db8
7092966
d2cbc08
 
 
 
d3a4db8
7092966
d2cbc08
 
 
 
be368a6
3cb2b69
d2cbc08
 
 
 
991fbeb
 
0ceff0a
75f06df
991fbeb
 
 
b60cbe3
991fbeb
93c87bc
991fbeb
074b0e2
991fbeb
 
 
7c583bd
 
7d87e8c
7c583bd
 
 
 
 
 
 
 
75f06df
b60cbe3
03cffea
b60cbe3
 
 
 
991fbeb
03cffea
 
 
 
 
 
be74628
4eb52b0
03cffea
 
d589927
 
 
7e86f8c
03cffea
be74628
b60cbe3
 
 
 
bd01473
 
7d65efc
bd01473
7d65efc
bd01473
93ac46b
bc847c8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import gradio as gr
from PIL import Image
from vit_model_test import CustomModel  # Ensure this is the correct import for your model
from vit_Training import Custom_VIT_Model

# Build both model objects at import time: ``model`` answers predict()
# requests, while ``model_training`` receives corrected samples through
# add_data() when the user reports a wrong prediction.
model, model_training = CustomModel(), Custom_VIT_Model()


# Raw CSS passed to gr.Blocks(css=...). Gradio injects this string as
# stylesheet content, so it must NOT be wrapped in <style> tags: a leading
# "<style>" token would become part of the first selector and the browser
# would discard that rule (here, the #submit_btn styling).
# Selectors match the elem_id values given to the buttons.
custom_css = """
#submit_btn {
    background-color: #4CAF50; /* Green */
    color: white;
    width:47%;
    margin-right: 3%;
}
#feedback_btn {
    background-color: #f44336; /* Red */
    color: white;
    width:47%;
    margin-left: 3%;
}
#clear_btn {
    background-color: #2196F3; /* Blue */
    color: white;
    width: 25%;
    float: left;
}
"""

# (image, predicted_label) from the most recent predict() call, consumed by
# report_feedback(); None when there is nothing to report.
# NOTE(review): this is module-level state, so it is presumably shared by all
# concurrent sessions of the app (it is launched with share=True) — consider
# gr.State for per-session storage.
last_prediction = None

def predict(image: "Image.Image | None"):
    """Classify ``image`` as AI-generated or real and remember it for feedback.

    Args:
        image: The uploaded PIL image, or ``None`` if no image has been
            uploaded when Submit is clicked.

    Returns:
        A ``(result, confidence_text)`` pair for the Prediction and
        Confidence textboxes. ``result`` is ``"AI image"`` when the model
        returns label 1 and ``"Real image"`` otherwise. When ``image`` is
        ``None`` a prompt message and an empty confidence string are returned.
    """
    global last_prediction
    if image is None:
        # Nothing to classify. Also forget any earlier prediction so the
        # feedback button cannot attach a correction to a stale image.
        last_prediction = None
        return "Please upload an image first.", ""

    label, confidence = model.predict(image)
    result = "AI image" if label == 1 else "Real image"
    last_prediction = (image, label)  # Kept so report_feedback() can use it.
    return result, f"Confidence: {confidence:.2f}%"

def report_feedback():
    """Report the most recent prediction as wrong, submitting the flipped label.

    Reads ``last_prediction`` (set by ``predict``), flips its label and passes
    the image plus corrected label to ``model_training.add_data``. On success
    the stored prediction is cleared so repeated clicks cannot add the same
    sample more than once. On failure it is kept so the user can retry.

    Returns:
        A status message for the "Feedback Status" textbox.
    """
    global last_prediction
    if last_prediction is None:
        print("No prediction available to report.")  # Debugging line
        return "No prediction available to report."

    image, predicted_label = last_prediction
    # Flip the label: 0 -> 1, anything else -> 0.
    correct_label = 1 if predicted_label == 0 else 0
    print(f"Reporting feedback: predicted_label={predicted_label}, correct_label={correct_label}")  # Debugging line
    try:
        model_training.add_data(image, correct_label)
    except Exception as e:  # UI boundary: surface the error instead of failing the handler.
        print(f"Error recording feedback: {e}")  # Debugging line
        return f"Error recording feedback: {e}"

    # Consume the prediction so the same sample is not recorded twice.
    last_prediction = None
    print("Feedback recorded successfully.")  # Debugging line
    return "Feedback recorded. Thank you!"

# Define the Gradio interface.
# NOTE(review): the handlers below share the module-level `last_prediction`,
# so with share=True feedback from one visitor could presumably apply to
# another visitor's image — consider gr.State for per-session storage.
with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
    gr.Markdown("### Vision Transformer Model")
    gr.Markdown("Upload an image to classify it using the Vision Transformer model.")

    image_input = gr.Image(type="pil", label="Upload Image")

    # Prediction and confidence outputs side by side.
    with gr.Row():
        prediction_output = gr.Textbox(label="Prediction", interactive=False)
        confidence_output = gr.Textbox(label="Confidence", interactive=False)

    # Status line for the "model was wrong" feedback action.
    with gr.Row():
        feedback_output = gr.Textbox(label="Feedback Status", interactive=False, scale=0, min_width=700)

    # Action buttons; the elem_id values match the selectors in custom_css.
    with gr.Row():
        submit_btn = gr.Button("Submit", variant="primary", elem_id="submit_btn")
        feedback_btn = gr.Button("The model was wrong", variant="secondary", elem_id="feedback_btn")

    gr.Markdown("<br>")  # Spacer below the buttons.
    clear_btn = gr.Button("Clear", elem_id="clear_btn")

    submit_btn.click(predict, inputs=image_input, outputs=[prediction_output, confidence_output])
    feedback_btn.click(report_feedback, outputs=feedback_output)

    def clear_all():
        """Reset the image input and all three outputs, and forget the last prediction.

        Clearing ``last_prediction`` prevents the feedback button from
        reporting an image the user has already cleared from the UI.
        """
        global last_prediction
        last_prediction = None
        return None, "", "", ""

    clear_btn.click(clear_all, outputs=[image_input, prediction_output, confidence_output, feedback_output])

def _launch():
    """Start the Gradio app defined above, requesting a share link."""
    demo.launch(share=True)


if __name__ == "__main__":
    _launch()