OmidSakaki commited on
Commit
f774dbf
·
verified ·
1 Parent(s): 57fa964

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -17
app.py CHANGED
@@ -2,9 +2,8 @@ import gradio as gr
2
  import easyocr
3
  import numpy as np
4
  from typing import Tuple
5
- from transformers import pipeline
6
 
7
- # --- 1. کلاس OCR برای استخراج متن از تصویر ---
8
  class OCRProcessor:
9
  def __init__(self):
10
  self.reader = easyocr.Reader(['fa'])
@@ -16,32 +15,39 @@ class OCRProcessor:
16
  except Exception as e:
17
  raise RuntimeError(f"خطا در پردازش OCR: {str(e)}")
18
 
19
- # --- 2. کلاس تصحیح متن با مدل زبانی ---
20
  class TextCorrector:
21
  def __init__(self):
22
- # استفاده از مدل ParsBERT برای تصحیح متن فارسی
23
- self.corrector = pipeline(
24
- "text2text-generation",
25
- model="persiannlp/parsbert-uncased", # مدل زبانی فارسی
26
- tokenizer="persiannlp/parsbert-uncased"
27
- )
28
 
29
  def correct(self, text: str) -> str:
30
  if not text.strip():
31
  return text
 
32
  try:
33
- corrected = self.corrector(
34
- text,
 
 
 
 
 
 
 
35
  max_length=512,
36
  num_beams=5,
37
  early_stopping=True
38
  )
39
- return corrected[0]['generated_text']
 
40
  except Exception as e:
41
  print(f"خطا در تصحیح متن: {e}")
42
  return text
43
 
44
- # --- 3. پردازش کامل (OCR + تصحیح خودکار) ---
45
  def full_processing(image: np.ndarray) -> Tuple[str, str]:
46
  try:
47
  # استخراج متن از تصویر
@@ -55,10 +61,9 @@ def full_processing(image: np.ndarray) -> Tuple[str, str]:
55
  error_msg = f"خطا: {str(e)}"
56
  return error_msg, error_msg
57
 
58
- # --- 4. رابط کاربری Gradio ---
59
  with gr.Blocks(title="پایپلاین OCR + تصحیح خودکار متن فارسی") as app:
60
  gr.Markdown("""
61
- # استخراج و تصحیح هوشمند متن فارسی از تصویر
62
  """)
63
 
64
  with gr.Row():
@@ -66,8 +71,8 @@ with gr.Blocks(title="پایپلاین OCR + تصحیح خودکار متن فا
66
  img_input = gr.Image(label="تصویر ورودی", type="numpy")
67
  process_btn = gr.Button("پردازش تصویر", variant="primary")
68
  with gr.Column():
69
- raw_output = gr.Textbox(label="متن استخراج شده (خام)", lines=8, interactive=True)
70
- corrected_output = gr.Textbox(label="متن تصحیح شده (هوشمند)", lines=10, interactive=True)
71
 
72
  process_btn.click(
73
  fn=full_processing,
 
2
  import easyocr
3
  import numpy as np
4
  from typing import Tuple
5
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
6
 
 
7
  class OCRProcessor:
8
  def __init__(self):
9
  self.reader = easyocr.Reader(['fa'])
 
15
  except Exception as e:
16
  raise RuntimeError(f"خطا در پردازش OCR: {str(e)}")
17
 
 
18
  class TextCorrector:
19
  def __init__(self):
20
+ model_name = "persiannlp/mt5-small-parsinlu-arc-comqa-question"
21
+ try:
22
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name)
23
+ self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
24
+ except Exception as e:
25
+ raise RuntimeError(f"خطا در بارگذاری مدل زبانی: {str(e)}")
26
 
27
  def correct(self, text: str) -> str:
28
  if not text.strip():
29
  return text
30
+
31
  try:
32
+ inputs = self.tokenizer(
33
+ "اصلاح متن: " + text,
34
+ return_tensors="pt",
35
+ max_length=512,
36
+ truncation=True
37
+ )
38
+
39
+ outputs = self.model.generate(
40
+ **inputs,
41
  max_length=512,
42
  num_beams=5,
43
  early_stopping=True
44
  )
45
+
46
+ return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
47
  except Exception as e:
48
  print(f"خطا در تصحیح متن: {e}")
49
  return text
50
 
 
51
  def full_processing(image: np.ndarray) -> Tuple[str, str]:
52
  try:
53
  # استخراج متن از تصویر
 
61
  error_msg = f"خطا: {str(e)}"
62
  return error_msg, error_msg
63
 
 
64
  with gr.Blocks(title="پایپلاین OCR + تصحیح خودکار متن فارسی") as app:
65
  gr.Markdown("""
66
+ # سیستم استخراج و تصحیح هوشمند متن فارسی
67
  """)
68
 
69
  with gr.Row():
 
71
  img_input = gr.Image(label="تصویر ورودی", type="numpy")
72
  process_btn = gr.Button("پردازش تصویر", variant="primary")
73
  with gr.Column():
74
+ raw_output = gr.Textbox(label="متن استخراج شده (خام)", lines=8)
75
+ corrected_output = gr.Textbox(label="متن تصحیح شده (هوشمند)", lines=10)
76
 
77
  process_btn.click(
78
  fn=full_processing,