Nitin00043 commited on
Commit
e2a5fc6
·
verified ·
1 Parent(s): 8bd9e0b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -17
app.py CHANGED
@@ -2,35 +2,69 @@
2
  import gradio as gr
3
  from transformers import TrOCRProcessor, VisionEncoderDecoderModel
4
  from PIL import Image
 
5
 
6
- # Load model and processor
7
- processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
8
- model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")
9
 
10
- def predict_handwriting(image):
11
- """
12
- Function to process handwritten text image and return transcription
13
- """
14
  try:
15
- # Preprocess the image
16
  image = image.convert("RGB")
17
- # Prepare image pixel values
18
  pixel_values = processor(image, return_tensors="pt").pixel_values
19
- # Generate text
20
  generated_ids = model.generate(pixel_values)
21
  transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
22
- return transcription
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
  except Exception as e:
25
- return f"Error processing image: {str(e)}"
26
 
27
  # Create Gradio interface
28
  demo = gr.Interface(
29
- fn=predict_handwriting,
30
- inputs=gr.Image(type="pil", label="Upload Handwritten Text Image"),
31
- outputs=gr.Textbox(label="Transcription"),
32
- title="Handwritten Text to Text Converter",
33
- description="Upload a handwritten text image and get the transcribed text. Best results with clear, high-contrast images."
 
 
 
34
  )
35
 
36
  if __name__ == "__main__":
 
2
  import gradio as gr
3
  from transformers import TrOCRProcessor, VisionEncoderDecoderModel
4
  from PIL import Image
5
+ from sympy import sympify, solve, Eq, symbols
6
 
7
+ # Load the math OCR model and processor
8
+ processor = TrOCRProcessor.from_pretrained("nlpai-lab/mathocr-htr-base")
9
+ model = VisionEncoderDecoderModel.from_pretrained("nlpai-lab/mathocr-htr-base")
10
 
11
+ def predict_math_problem(image):
 
 
 
12
  try:
13
+ # Transcribe the handwritten math problem
14
  image = image.convert("RGB")
 
15
  pixel_values = processor(image, return_tensors="pt").pixel_values
 
16
  generated_ids = model.generate(pixel_values)
17
  transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
18
+
19
+ # Standardize mathematical symbols in the transcription
20
+ transcription = (transcription
21
+ .replace("×", "*")
22
+ .replace("÷", "/")
23
+ .replace("−", "-")
24
+ .replace("√", "sqrt")
25
+ .replace("²", "**2")
26
+ .replace("³", "**3")
27
+ .replace("½", "1/2")
28
+ .replace("¼", "1/4")
29
+ .replace("…", "...") # Ellipsis
30
+ )
31
+
32
+ # Attempt to solve the mathematical problem
33
+ solution = None
34
+ try:
35
+ # Check if the transcription is an equation (contains '=')
36
+ if '=' in transcription:
37
+ lhs, rhs = transcription.split('=', 1)
38
+ equation = Eq(sympify(lhs.strip()), sympify(rhs.strip()))
39
+ variables = equation.free_symbols
40
+ if variables:
41
+ variable = variables.pop()
42
+ solution = solve(equation, variable)
43
+ solution = f"{variable} = {solution}"
44
+ else:
45
+ solution = "No variables found in equation"
46
+ else:
47
+ # Treat as an arithmetic expression
48
+ solution = sympify(transcription)
49
+ solution = f"Result: {solution}"
50
+ except:
51
+ solution = "Invalid or unsolvable expression"
52
+
53
+ return transcription, solution
54
 
55
  except Exception as e:
56
+ return f"Error: {str(e)}", "Failed to process"
57
 
58
  # Create Gradio interface
59
  demo = gr.Interface(
60
+ fn=predict_math_problem,
61
+ inputs=gr.Image(type="pil", label="Upload Handwritten Math Problem"),
62
+ outputs=[
63
+ gr.Textbox(label="Transcribed Text"),
64
+ gr.Textbox(label="Solution")
65
+ ],
66
+ title="Handwritten Math Solver",
67
+ description="Upload a handwritten math problem to get its transcription and solution."
68
  )
69
 
70
  if __name__ == "__main__":