"""Gradio demo that answers handwritten math problems with a Pix2Struct model.

The model and processor are loaded once at import time; ``solve_math_problem``
is the Gradio callback and ``demo`` is the interface launched when this file
is run as a script.
"""

import inspect
import logging
from pathlib import Path
from typing import Optional

import gradio as gr
from PIL import Image
from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor

logger = logging.getLogger(__name__)

# Load the pre-trained Pix2Struct model and processor.
# NOTE(review): confirm this checkpoint id exists on the Hugging Face Hub;
# from_pretrained raises at import time if it cannot be resolved.
model_name = "google/pix2struct-mathqa-base"
model = Pix2StructForConditionalGeneration.from_pretrained(model_name)
processor = Pix2StructProcessor.from_pretrained(model_name)

# Text prompt passed to the processor alongside the image.
PROMPT = "Solve the following math problem:"

# Example images offered in the UI, resolved relative to the working directory.
EXAMPLE_FILES = ("example_addition.png", "example_algebra.jpg")


def solve_math_problem(image: Optional[Image.Image]) -> str:
    """Run the Pix2Struct model on *image* and format its answer.

    Args:
        image: The uploaded PIL image, or ``None`` when the user submitted
            without uploading anything.

    Returns:
        ``"Problem: <prompt>\\nSolution: <decoded model output>"`` on success,
        or a string starting with ``"Error processing image: "`` on failure.
        Failures are logged and reported as text rather than raised, so the
        Gradio UI always shows a message.
    """
    if image is None:
        return "Error processing image: no image was provided."
    try:
        # Normalize to 3-channel input before preprocessing.
        image = image.convert("RGB")

        # For non-VQA checkpoints the processor tokenizes `text` into
        # `decoder_input_ids` (a decoder prompt), not `input_ids`. Following
        # the Transformers Pix2Struct conditional-generation example,
        # add_special_tokens=False keeps an EOS token from being appended to
        # that prompt.
        inputs = processor(
            images=[image],  # Provide a list of images.
            text=PROMPT,
            return_tensors="pt",
            max_patches=2048,  # More patches preserve detail in dense handwriting.
            add_special_tokens=False,
        )

        # Deterministic beam search. `temperature` is intentionally not passed:
        # it is only used when do_sample=True and otherwise just warns.
        predictions = model.generate(
            **inputs,
            max_new_tokens=200,
            early_stopping=True,
            num_beams=4,
        )

        solution = processor.decode(
            predictions[0],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        ).strip()
        # The decoder prompt is typically echoed at the start of the generated
        # sequence; strip it so only the model's answer is reported.
        if solution.startswith(PROMPT):
            solution = solution[len(PROMPT):].strip()

        # The processor output has no `input_ids` entry to decode, so the
        # prompt text is reported directly.
        return f"Problem: {PROMPT}\nSolution: {solution}"
    except Exception as e:  # UI boundary: surface every failure to the user.
        logger.exception("Failed to solve math problem from image")
        return f"Error processing image: {str(e)}"


def _first_supported_kwarg(target, candidates):
    """Return ``{name: value}`` for the first candidate name *target* accepts.

    Gradio renamed constructor parameters across major versions
    (``gr.Image(source=...)`` became ``sources=[...]`` in 4.x, and
    ``gr.Interface(allow_flagging=...)`` became ``flagging_mode=...`` in 5.x).
    Passing a spelling the installed version does not declare can raise
    ``TypeError``, so pick the one its signature lists. Returns an empty dict
    (i.e. Gradio's default) when none match or the signature is unavailable.
    """
    try:
        params = inspect.signature(target).parameters
    except (TypeError, ValueError):
        return {}
    for name, value in candidates:
        if name in params:
            return {name: value}
    return {}


# Set up the Gradio interface.
demo = gr.Interface(
    fn=solve_math_problem,
    inputs=gr.Image(
        type="pil",
        label="Upload Handwritten Math Problem",
        image_mode="RGB",  # Force RGB conversion.
        # Restrict input to file upload using whichever keyword this Gradio
        # version supports.
        **_first_supported_kwarg(
            gr.Image.__init__, [("sources", ["upload"]), ("source", "upload")]
        ),
    ),
    outputs=gr.Textbox(label="Solution", show_copy_button=True),
    title="Handwritten Math Problem Solver",
    description="Upload an image of a handwritten math problem (algebra, arithmetic, etc.) and get the solution",
    # Only offer example files that exist in the working directory; Gradio
    # may error when it tries to load a missing example.
    examples=[[path] for path in EXAMPLE_FILES if Path(path).is_file()] or None,
    theme="soft",
    # Disable flagging under whichever keyword this Gradio version supports.
    **_first_supported_kwarg(
        gr.Interface.__init__,
        [("flagging_mode", "never"), ("allow_flagging", "never")],
    ),
)

if __name__ == "__main__":
    demo.launch()