"""Gradio demo that feeds an uploaded image and a text prompt to a Pix2Struct model."""

from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor
import gradio as gr
from PIL import Image

# Use a public model identifier. If you need a private model, remember to authenticate.
model_name = "google/pix2struct-textcaps-base"
model = Pix2StructForConditionalGeneration.from_pretrained(model_name)
processor = Pix2StructProcessor.from_pretrained(model_name)

# Text prompt used as the decoder prefix for every request.
PROMPT_TEXT = "Solve the following math problem:"


def solve_math_problem(image):
    """Run the Pix2Struct model on an uploaded image and format the result.

    Args:
        image: The PIL image supplied by the Gradio ``Image`` component, or
            ``None`` when the user submits without uploading anything.

    Returns:
        A string of the form ``"Problem: <prompt>\\nSolution: <generated text>"``
        on success, or ``"Error processing image: <reason>"`` on failure. The
        function never raises, so the UI always has text to display.
    """
    # Gradio passes None when no image was provided; report that explicitly
    # instead of surfacing an opaque AttributeError message.
    if image is None:
        return "Error processing image: no image provided"

    try:
        # Ensure the image is in RGB format.
        rgb_image = image.convert("RGB")

        # Preprocess the image and text. header_text is omitted as it is only
        # used for VQA checkpoints. add_special_tokens=False keeps the prompt
        # from ending in an EOS token, so generation continues from the prompt.
        inputs = processor(
            images=[rgb_image],
            text=PROMPT_TEXT,
            return_tensors="pt",
            max_patches=2048,
            add_special_tokens=False,
        )

        # Beam search is deterministic; `temperature` only applies when
        # sampling is enabled, so it is intentionally not passed here.
        predictions = model.generate(
            **inputs,
            max_new_tokens=200,
            early_stopping=True,
            num_beams=4,
        )

        solution = processor.decode(
            predictions[0],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )

        # For non-VQA checkpoints the processor exposes the tokenized prompt
        # as `decoder_input_ids` (there is no `input_ids` key), so the prompt
        # is reported directly rather than decoded from the encoding.
        return f"Problem: {PROMPT_TEXT}\nSolution: {solution}"
    except Exception as e:  # UI boundary: turn any failure into a message.
        return f"Error processing image: {e}"


# Set up the Gradio interface.
demo = gr.Interface(
    fn=solve_math_problem,
    inputs=gr.Image(
        type="pil",
        label="Upload Handwritten Math Problem",
        image_mode="RGB",  # This forces the input to be RGB.
    ),
    outputs=gr.Textbox(label="Solution", show_copy_button=True),
    title="Handwritten Math Problem Solver",
    description="Upload an image of a handwritten math problem (algebra, arithmetic, etc.) and get the solution",
    # NOTE(review): these example files are expected next to the script;
    # confirm they are shipped with the app.
    examples=[
        ["example_addition.png"],
        ["example_algebra.jpg"],
    ],
    theme="soft",
    allow_flagging="never",
)

if __name__ == "__main__":
    demo.launch()