import gradio as gr
from PIL import Image

from vit_model_test import CustomModel
from vit_Training import Custom_VIT_Model

custom_css = """ """

# Initialize the model used for inference and the trainer that collects
# user-corrected samples.
model = CustomModel()
model_training = Custom_VIT_Model()

# Store the last prediction result as (image, predicted_label), or None when
# there is nothing that can be reported.
# NOTE(review): this is a module-level global, so every Gradio session
# (including users reaching the app via share=True) shares it; consider
# gr.State for per-session storage.
last_prediction = None


def predict(image: Image.Image):
    """Classify *image* as AI-generated or real.

    Returns a ``(prediction_text, confidence_text)`` pair for the two output
    textboxes and remembers the image/label so the user can report a wrong
    prediction afterwards.
    """
    global last_prediction
    if image is None:
        # Submit was clicked with no upload: don't call the model with None,
        # and drop any older prediction so feedback can't target a stale image.
        last_prediction = None
        return "Please upload an image first.", ""

    label, confidence = model.predict(image)
    result = "AI image" if label == 1 else "Real image"
    last_prediction = (image, label)  # Store the image and label for feedback
    return result, f"Confidence: {confidence:.2f}%"


def report_feedback():
    """Record the last prediction as wrong by sending the inverted label to the trainer.

    Each prediction can be reported at most once: after a successful report
    the stored prediction is cleared so repeated clicks don't add duplicate
    samples to the training data. Returns a status message for the UI.
    """
    global last_prediction
    if last_prediction is None:
        print("No prediction available to report.")  # Debugging line
        return "No prediction available to report."

    image, predicted_label = last_prediction
    correct_label = 1 if predicted_label == 0 else 0  # Invert the label
    print(f"Reporting feedback: predicted_label={predicted_label}, correct_label={correct_label}")  # Debugging line
    try:
        model_training.add_data(image, correct_label)  # Pass the incorrect prediction to the model
    except Exception as e:
        # UI boundary: surface the error to the user instead of crashing the
        # handler. The prediction is kept so the user can retry.
        print(f"Error recording feedback: {e}")  # Debugging line
        return f"Error recording feedback: {e}"

    # Consume the prediction so the same sample isn't recorded twice.
    last_prediction = None
    print("Feedback recorded successfully.")  # Debugging line
    return "Feedback recorded. Thank you!"
# Define the Gradio blocks
with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
    gr.Markdown("### Classify image of art as real or AI generated")
    image_input = gr.Image(type="pil", label="Upload Image", height=365)

    # Create a row for prediction and confidence outputs
    with gr.Row():
        prediction_output = gr.Textbox(label="Prediction", interactive=False)
        confidence_output = gr.Textbox(label="Confidence", interactive=False)

    # Create a row for feedback_output
    with gr.Row():
        feedback_output = gr.Textbox(label="Feedback Status", interactive=False, scale=0, min_width=730)

    # Buttons
    with gr.Row():
        submit_btn = gr.Button("Submit", variant="primary", elem_id="submit_btn")
        feedback_btn = gr.Button("The model was wrong", variant="secondary", elem_id="feedback_btn")
        # Spacer between the feedback and clear buttons. The original string
        # literal spanned two lines (a SyntaxError); NOTE(review): confirm the
        # intended spacer content.
        gr.Markdown("<br>")
        clear_btn = gr.Button("Clear", elem_id="clear_btn")

    submit_btn.click(predict, inputs=image_input, outputs=[prediction_output, confidence_output])
    feedback_btn.click(report_feedback, outputs=feedback_output)

    # Clear button
    def clear_all():
        """Reset every widget and forget the stored prediction.

        Forgetting the prediction prevents "The model was wrong" from
        reporting an image that is no longer shown after Clear.
        """
        global last_prediction
        last_prediction = None
        return None, "", "", ""

    clear_btn.click(clear_all, outputs=[image_input, prediction_output, confidence_output, feedback_output])

if __name__ == "__main__":
    demo.launch(share=True)