File size: 2,054 Bytes
e50f30c
ec7d971
623c9e7
ec7d971
623c9e7
e50f30c
4a01533
 
ec7d971
 
 
191e2cd
e50f30c
 
 
 
ec7d971
e50f30c
 
ec7d971
e50f30c
 
adc05de
e50f30c
 
 
 
 
 
 
 
 
 
 
 
 
adc05de
e50f30c
ec7d971
 
623c9e7
e50f30c
1ee9cdc
ec7d971
 
 
 
e50f30c
ec7d971
623c9e7
 
 
4a01533
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import torch
from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor
import gradio as gr
from PIL import Image

# Hugging Face Hub checkpoint to load; it is consumed below through the
# Pix2Struct model/processor classes.
model_name = "google/matcha-base"

# Pick the compute device first: CUDA when available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the preprocessing pipeline and the pretrained weights, then place the
# model on the chosen device (nn.Module.to returns the module itself).
processor = Pix2StructProcessor.from_pretrained(model_name)
model = Pix2StructForConditionalGeneration.from_pretrained(model_name).to(device)

def solve_math_problem(image):
    """Run the loaded Pix2Struct model on an uploaded image and return its text output.

    Args:
        image: The uploaded picture. The Gradio input below is declared with
            ``type="pil"``, so this is a PIL image, or ``None`` when the form is
            submitted without an image.

    Returns:
        The decoded model output with special tokens removed, or a short hint
        string when no image was provided.
    """
    # Gradio passes None when the image slot is empty; there is nothing to run
    # the model on, so reply with a hint instead of calling the processor.
    if image is None:
        return "Please upload an image of a handwritten math problem."

    # Preprocess the image together with a text prompt.
    # Adjust the prompt to better match your task if needed.
    inputs = processor(images=image, text="Solve the math problem:", return_tensors="pt")
    # Move every tensor onto the same device as the model.
    inputs = {key: value.to(device) for key, value in inputs.items()}

    # Deterministic beam-search decoding:
    # - max_new_tokens: upper bound on the length of the generated answer.
    # - num_beams: number of hypotheses explored in parallel.
    # - early_stopping: finish once num_beams complete candidates exist.
    # `temperature` is intentionally not passed: it only applies when sampling
    # (do_sample=True) and is ignored, with a warning, by plain beam search.
    predictions = model.generate(
        **inputs,
        max_new_tokens=150,
        num_beams=5,
        early_stopping=True,
    )

    # Decode the single returned sequence, dropping special tokens.
    return processor.decode(predictions[0], skip_special_tokens=True)

# Gradio UI: a single PIL image upload in, a single text box out.
_problem_image = gr.Image(type="pil", label="Upload Handwritten Math Problem")
_solution_box = gr.Textbox(label="Solution")

demo = gr.Interface(
    fn=solve_math_problem,
    inputs=_problem_image,
    outputs=_solution_box,
    title="Handwritten Math Problem Solver",
    description="Upload an image of a handwritten math problem and the model will attempt to solve it.",
    theme="soft",
)

# Start the web server only when run as a script, not on import.
if __name__ == "__main__":
    demo.launch()