File size: 2,615 Bytes
9b26701
 
8bd9e0b
9b26701
e2a5fc6
9b26701
e2a5fc6
 
 
9b26701
e2a5fc6
9b26701
e2a5fc6
9b26701
8bd9e0b
 
 
e2a5fc6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9b26701
 
e2a5fc6
9b26701
 
 
e2a5fc6
 
 
 
 
 
 
 
9b26701
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
# app.py
import re

import gradio as gr
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from PIL import Image
from sympy import sympify, solve, Eq, symbols

# Load the math OCR processor and model once, at import time, so every call
# to predict_math_problem reuses the same module-level instances.
_CHECKPOINT = "nlpai-lab/mathocr-htr-base"
processor = TrOCRProcessor.from_pretrained(_CHECKPOINT)
model = VisionEncoderDecoderModel.from_pretrained(_CHECKPOINT)

# Unicode math symbols the OCR model may emit, mapped to the ASCII syntax that
# sympify() understands. str.translate applies every mapping in a single pass.
_SYMBOL_TABLE = str.maketrans({
    "×": "*",
    "⋅": "*",  # dot operator
    "÷": "/",
    "−": "-",  # minus sign
    "–": "-",  # en dash, often produced for a handwritten minus
    "√": "sqrt",
    "²": "**2",
    "³": "**3",
    "½": "1/2",
    "¼": "1/4",
    "π": "pi",
    "…": "...",  # Ellipsis
})

# A radical written without parentheses ("√16", "√x") must become "sqrt(16)";
# a plain text substitution would yield "sqrt16", which SymPy parses as a
# single symbol named sqrt16.
_BARE_RADICAL = re.compile(r"√\s*(\d+(?:\.\d+)?|[^\W\d_])")

# sympify() evaluates its input with eval(), and the text here comes from an
# uploaded image, i.e. it is untrusted. Only plain math notation may reach it:
# no quotes, underscores, brackets, backslashes or statement separators.
# NOTE(review): this narrows what eval() can see; it is a mitigation, not a
# sandbox.
_SAFE_EXPRESSION = re.compile(r"(?:[^\W_]|[\s+\-*/^%().,=<>!])*")


def _normalize_symbols(text):
    """Rewrite Unicode math notation in *text* into sympify()-friendly ASCII."""
    text = _BARE_RADICAL.sub(r"sqrt(\1)", text)
    return text.translate(_SYMBOL_TABLE)


def _transcribe(image):
    """Run the OCR model on *image* and return the normalized transcription."""
    pixel_values = processor(image.convert("RGB"), return_tensors="pt").pixel_values
    generated_ids = model.generate(pixel_values)
    raw_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return _normalize_symbols(raw_text)


def _solve(transcription):
    """Solve *transcription* as an equation, or evaluate it as an expression.

    Returns a human-readable string. Anything that fails to parse or solve
    yields "Invalid or unsolvable expression".
    """
    if not _SAFE_EXPRESSION.fullmatch(transcription):
        return "Invalid or unsolvable expression"
    try:
        if "=" not in transcription:
            # Treat as an arithmetic expression
            return f"Result: {sympify(transcription)}"
        lhs, rhs = transcription.split("=", 1)
        equation = Eq(sympify(lhs.strip()), sympify(rhs.strip()))
        if not equation.free_symbols:
            return "No variables found in equation"
        # free_symbols is an unordered set; sort by name so the variable that
        # is solved for does not change from run to run.
        variable = sorted(equation.free_symbols, key=str)[0]
        roots = solve(equation, variable)
        return f"{variable} = {roots}"
    except Exception:
        # Parse/solve errors from arbitrary OCR text are expected here; unlike
        # a bare ``except:``, this does not swallow KeyboardInterrupt/SystemExit.
        return "Invalid or unsolvable expression"


def predict_math_problem(image):
    """Transcribe a handwritten math problem and try to solve it with SymPy.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            if no image was provided.

    Returns:
        A ``(transcription, solution)`` pair of strings. If OCR itself fails,
        the pair is ``("Error: <message>", "Failed to process")``.
    """
    if image is None:
        return "", "Please upload an image."
    try:
        transcription = _transcribe(image)
    except Exception as e:  # UI boundary: report the failure instead of crashing
        return f"Error: {str(e)}", "Failed to process"
    return transcription, _solve(transcription)

# Gradio front end: a single PIL image goes in; the transcription and the
# solution come back in two separate text boxes, in that order.
_image_input = gr.Image(type="pil", label="Upload Handwritten Math Problem")
_text_outputs = [
    gr.Textbox(label="Transcribed Text"),
    gr.Textbox(label="Solution"),
]

demo = gr.Interface(
    fn=predict_math_problem,
    inputs=_image_input,
    outputs=_text_outputs,
    title="Handwritten Math Solver",
    description="Upload a handwritten math problem to get its transcription and solution.",
)

if __name__ == "__main__":
    # Start the server only when executed as a script, not when imported.
    demo.launch()