Spaces:
Sleeping
Sleeping
File size: 3,154 Bytes
0ceff0a a821828 fd4d84b 0ceff0a b60cbe3 991fbeb 7d87e8c d2cbc08 d3a4db8 7092966 d2cbc08 d3a4db8 7092966 d2cbc08 be368a6 3cb2b69 d2cbc08 a821828 991fbeb 0ceff0a 75f06df 991fbeb b60cbe3 991fbeb 93c87bc 991fbeb 074b0e2 991fbeb 7c583bd 7d87e8c 7c583bd 75f06df 70d4bb7 03cffea c63c908 b60cbe3 9afc753 991fbeb 03cffea 70d4bb7 be74628 70d4bb7 03cffea d589927 7e86f8c a821828 70d4bb7 b60cbe3 70d4bb7 bd01473 a821828 bd01473 7d65efc bd01473 93ac46b bc847c8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 |
import gradio as gr
from PIL import Image
from vit_model_test import CustomModel
from vit_Training import Custom_VIT_Model
# Initialize the models: `model` serves predictions for the UI, while
# `model_training` receives user corrections via report_feedback.
model, model_training = CustomModel(), Custom_VIT_Model()
# Custom CSS for the action buttons, passed to gr.Blocks(css=...).
# Gradio injects this string into the page inside its own <style> element,
# so it must be plain CSS. Wrapping it in <style>...</style> turned the first
# selector into "<style> #submit_btn", which is invalid, and browsers drop
# that entire rule (the green Submit styling never applied).
custom_css = """
#submit_btn {
    background-color: #4CAF50; /* Green */
    color: white;
    width: 47%;
    margin-right: 3%;
}
#feedback_btn {
    background-color: #f44336; /* Red */
    color: white;
    width: 47%;
    margin-left: 3%;
}
#clear_btn {
    background-color: #2196F3; /* Blue */
    color: white;
    width: 25%;
    float: left;
}
"""
# Store the last prediction result as an (image, label) tuple, or None when
# there is nothing to report; written by predict, read by report_feedback.
# NOTE(review): this is one module-level global shared by every request in the
# process, so with concurrent users one user's "The model was wrong" click may
# report another user's image — per-session gr.State would avoid that.
last_prediction = None
def predict(image: "Image.Image | None"):
    """Classify *image* as AI-generated or real and remember the result.

    The (image, label) pair is stored in ``last_prediction`` so that the
    "The model was wrong" button can report it later.

    Args:
        image: PIL image from the upload widget, or ``None`` when the user
            pressed Submit without uploading anything.

    Returns:
        ``(result_text, confidence_text)`` for the Prediction and Confidence
        textboxes.
    """
    global last_prediction
    if image is None:
        # Nothing to classify: forget any earlier prediction so feedback cannot
        # be filed against an image that is no longer on screen.
        last_prediction = None
        return "Please upload an image first.", ""

    label, confidence = model.predict(image)
    result = "AI image" if label == 1 else "Real image"
    last_prediction = (image, label)  # Store the image and label for feedback
    # NOTE(review): a '%' suffix is appended, so confidence is presumably already
    # on a 0-100 scale — confirm against CustomModel.predict.
    return result, f"Confidence: {confidence:.2f}%"
def report_feedback():
    """Record the most recent prediction as a misclassification.

    The stored label is flipped (0 -> 1, anything else -> 0) and the image is
    passed to ``model_training.add_data`` as a corrected sample. On success
    the stored prediction is consumed, so clicking the button again cannot add
    the same sample twice; on failure it is kept so the user can retry.

    Returns:
        A status message for the "Feedback Status" textbox.
    """
    global last_prediction
    if last_prediction is None:
        print("No prediction available to report.")  # Debugging line
        return "No prediction available to report."

    image, predicted_label = last_prediction
    correct_label = 1 if predicted_label == 0 else 0  # Invert the label
    print(f"Reporting feedback: predicted_label={predicted_label}, correct_label={correct_label}")  # Debugging line
    try:
        model_training.add_data(image, correct_label)  # Pass the corrected sample to the trainer
    except Exception as e:  # UI handler boundary: report the error instead of crashing
        print(f"Error recording feedback: {e}")  # Debugging line
        return f"Error recording feedback: {e}"

    # Consume the prediction so repeated clicks do not duplicate training data.
    last_prediction = None
    print("Feedback recorded successfully.")  # Debugging line
    return "Feedback recorded. Thank you!"
# Define the Gradio blocks: image upload, prediction/confidence/feedback
# outputs, and Submit / "The model was wrong" / Clear buttons.
with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
    gr.Markdown("### Classify image of art as real or AI generated")
    image_input = gr.Image(type="pil", label="Upload Image", height=365)

    # Create a row for prediction and confidence outputs
    with gr.Row():
        prediction_output = gr.Textbox(label="Prediction", interactive=False)
        confidence_output = gr.Textbox(label="Confidence", interactive=False)

    # Create a row for feedback_output
    with gr.Row():
        feedback_output = gr.Textbox(
            label="Feedback Status", interactive=False, scale=0, min_width=730
        )

    # Buttons: Submit and feedback share one row (their CSS widths plus
    # margins add up to 100%); Clear sits below a spacer.
    with gr.Row():
        submit_btn = gr.Button("Submit", variant="primary", elem_id="submit_btn")
        feedback_btn = gr.Button(
            "The model was wrong", variant="secondary", elem_id="feedback_btn"
        )
    gr.Markdown("<br>")
    clear_btn = gr.Button("Clear", elem_id="clear_btn")

    submit_btn.click(
        predict, inputs=image_input, outputs=[prediction_output, confidence_output]
    )
    feedback_btn.click(report_feedback, outputs=feedback_output)

    # Clear button
    def clear_all():
        """Reset the image and all text outputs, and forget the last prediction.

        Resetting ``last_prediction`` matters: otherwise pressing
        "The model was wrong" after Clear would still report the image that
        was just cleared from the screen.
        """
        global last_prediction
        last_prediction = None
        return None, "", "", ""

    clear_btn.click(
        clear_all,
        outputs=[image_input, prediction_output, confidence_output, feedback_output],
    )
if __name__ == "__main__":
    # Start the web app. share=True asks Gradio for a public share link when
    # run locally (the trailing "|" scraped onto this line was a syntax error).
    demo.launch(share=True)