"""Gradio demo: send an uploaded image to a Pix2Struct model and show the decoded text."""

from pathlib import Path
from typing import List, Optional, Sequence

import gradio as gr
from PIL import Image
from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor

# Load the pre-trained Pix2Struct model and processor once at import time so
# that every request reuses the same objects.
# NOTE(review): confirm this checkpoint id is actually published on the
# Hugging Face Hub; from_pretrained() fails at startup if it is not.
model_name = "google/pix2struct-mathqa-base"
model = Pix2StructForConditionalGeneration.from_pretrained(model_name)
processor = Pix2StructProcessor.from_pretrained(model_name)

# Example images offered in the UI. Add your own files here; entries whose
# file is missing are skipped when the interface is built.
_EXAMPLE_IMAGES = ("example1.jpg", "example2.jpg")


def solve_math_problem(image: Optional[Image.Image]) -> str:
    """Run the model on an image of a handwritten math problem.

    Args:
        image: The uploaded picture as a PIL image (the Gradio input below is
            configured with ``type="pil"``). Gradio passes ``None`` when the
            user submits without uploading anything.

    Returns:
        The decoded model output with special tokens removed.

    Raises:
        gr.Error: If no image was provided.
    """
    # Fail with a readable message instead of an opaque error from the
    # processor when the form is submitted empty.
    if image is None:
        raise gr.Error("Please upload an image of a math problem first.")

    # Preprocess the image together with the text prompt.
    inputs = processor(
        images=image,
        text="Solve the math problem:",
        return_tensors="pt",
    )

    # Generate the solution, capped at 100 newly generated tokens.
    predictions = model.generate(**inputs, max_new_tokens=100)

    # Decode the first (only) generated sequence.
    return processor.decode(predictions[0], skip_special_tokens=True)


def _existing_examples(paths: Sequence[str]) -> Optional[List[List[str]]]:
    """Return Gradio example rows for the paths that exist, or None if none do."""
    rows = [[path] for path in paths if Path(path).is_file()]
    return rows or None


# Gradio interface
demo = gr.Interface(
    fn=solve_math_problem,
    inputs=gr.Image(type="pil", label="Upload Handwritten Math Problem"),
    outputs=gr.Textbox(label="Solution"),
    title="Handwritten Math Problem Solver",
    description="Upload an image of a handwritten math problem, and the model will solve it.",
    # Only advertise example files that are present on disk, so placeholder
    # names do not break the interface.
    examples=_existing_examples(_EXAMPLE_IMAGES),
    theme="soft",
)

# Launch the app
if __name__ == "__main__":
    demo.launch()