OmidSakaki commited on
Commit
12a2f23
·
verified ·
1 Parent(s): 4723eed

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -42
app.py CHANGED
@@ -1,8 +1,8 @@
1
  import gradio as gr
2
  import easyocr
3
  import numpy as np
4
- from typing import Tuple
5
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
6
 
7
  class OCRProcessor:
8
  def __init__(self):
@@ -13,66 +13,69 @@ class OCRProcessor:
13
  results = self.reader.readtext(image, detail=0, paragraph=True)
14
  return "\n".join(results) if results else ""
15
  except Exception as e:
16
- raise RuntimeError(f"خطا در پردازش OCR: {str(e)}")
17
 
18
- class TextCorrector:
19
  def __init__(self):
20
- model_name = "HooshvareLab/mt5-small-parsbert-uncased"
21
  try:
22
  self.tokenizer = AutoTokenizer.from_pretrained(model_name)
23
- self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
24
  except Exception as e:
25
- raise RuntimeError(f"خطا در بارگذاری مدل زبانی: {str(e)}")
26
 
27
- def correct(self, text: str) -> str:
28
- if not text.strip():
29
- return text
30
-
31
  try:
32
- prompt = "بازنویسی و تصحیح: " + text
33
- inputs = self.tokenizer(
34
- prompt,
35
- return_tensors="pt",
36
- max_length=512,
37
- truncation=True
38
  )
39
- outputs = self.model.generate(
40
- **inputs,
41
- max_length=512,
42
- num_beams=5,
43
- early_stopping=True
 
44
  )
45
- return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
 
 
 
 
46
  except Exception as e:
47
- print(f"خطا در تصحیح متن: {e}")
48
- return text
49
 
50
- def full_processing(image: np.ndarray) -> Tuple[str, str]:
51
- try:
52
- ocr_text = OCRProcessor().extract_text(image)
53
- corrected_text = TextCorrector().correct(ocr_text)
54
- return ocr_text, corrected_text
55
- except Exception as e:
56
- error_msg = f"خطا: {str(e)}"
57
- return error_msg, error_msg
58
 
59
- with gr.Blocks(title="پایپلاین OCR + تصحیح خودکار متن فارسی") as app:
 
 
 
 
 
 
 
60
  gr.Markdown("""
61
- # سیستم استخراج و تصحیح هوشمند متن فارسی
 
 
 
62
  """)
63
-
64
  with gr.Row():
65
  with gr.Column():
66
  img_input = gr.Image(label="تصویر ورودی", type="numpy")
67
- process_btn = gr.Button("پردازش تصویر", variant="primary")
 
68
  with gr.Column():
69
- raw_output = gr.Textbox(label="متن استخراج شده (خام)", lines=8, max_lines=None)
70
- corrected_output = gr.Textbox(label="متن تصحیح شده (هوشمند)", lines=15, max_lines=None)
71
 
72
  process_btn.click(
73
- fn=full_processing,
74
- inputs=img_input,
75
- outputs=[raw_output, corrected_output]
76
  )
77
 
78
  if __name__ == "__main__":
 
1
  import gradio as gr
2
  import easyocr
3
  import numpy as np
4
+ from transformers import AutoTokenizer, AutoModelForQuestionAnswering
5
+ import torch
6
 
7
  class OCRProcessor:
8
  def __init__(self):
 
13
  results = self.reader.readtext(image, detail=0, paragraph=True)
14
  return "\n".join(results) if results else ""
15
  except Exception as e:
16
+ return f"خطا در پردازش OCR: {str(e)}"
17
 
18
+ class PersianQAModel:
19
  def __init__(self):
20
+ model_name = "OmidSakaki/roberta_Persian_QA"
21
  try:
22
  self.tokenizer = AutoTokenizer.from_pretrained(model_name)
23
+ self.model = AutoModelForQuestionAnswering.from_pretrained(model_name)
24
  except Exception as e:
25
+ raise RuntimeError(f"خطا در بارگذاری مدل پرسش و پاسخ: {str(e)}")
26
 
27
+ def answer_question(self, context: str, question: str) -> str:
28
+ if not context.strip() or not question.strip():
29
+ return "متن یا سوال وارد نشده است."
 
30
  try:
31
+ inputs = self.tokenizer.encode_plus(
32
+ question, context, return_tensors='pt', truncation=True, max_length=512
 
 
 
 
33
  )
34
+ input_ids = inputs["input_ids"].tolist()[0]
35
+ outputs = self.model(**inputs)
36
+ answer_start = torch.argmax(outputs.start_logits)
37
+ answer_end = torch.argmax(outputs.end_logits) + 1
38
+ answer = self.tokenizer.convert_tokens_to_string(
39
+ self.tokenizer.convert_ids_to_tokens(input_ids[answer_start:answer_end])
40
  )
41
+ # حذف توکن‌های اضافی یا فاصله
42
+ answer = answer.strip()
43
+ if not answer or answer in ['[CLS]', '[SEP]', '[PAD]']:
44
+ return "جوابی یافت نشد."
45
+ return answer
46
  except Exception as e:
47
+ return f"خطا در مدل پرسش و پاسخ: {str(e)}"
 
48
 
49
+ ocr_processor = OCRProcessor()
50
+ qa_model = PersianQAModel()
 
 
 
 
 
 
51
 
52
+ def pipeline(image, question):
53
+ # استخراج متن از تصویر
54
+ context = ocr_processor.extract_text(image)
55
+ # پاسخ به سوال بر اساس متن
56
+ answer = qa_model.answer_question(context, question)
57
+ return context, answer
58
+
59
+ with gr.Blocks(title="استخراج متن و پاسخ به سوال از تصویر فارسی") as app:
60
  gr.Markdown("""
61
+ # سیستم هوشمند پرسش و پاسخ از روی تصویر فارسی
62
+ 1. تصویر را بارگذاری کنید تا متن استخراج شود.
63
+ 2. سوال خود را به فارسی تایپ کنید.
64
+ 3. دکمه «پاسخ» را بزنید.
65
  """)
 
66
  with gr.Row():
67
  with gr.Column():
68
  img_input = gr.Image(label="تصویر ورودی", type="numpy")
69
+ question_input = gr.Textbox(label="سوال شما به فارسی", placeholder="مثلاً: نویسنده این متن کیست؟", lines=1)
70
+ process_btn = gr.Button("پاسخ")
71
  with gr.Column():
72
+ context_output = gr.Textbox(label="متن استخراج شده", lines=10, max_lines=None, interactive=False)
73
+ answer_output = gr.Textbox(label="پاسخ مدل", lines=3, max_lines=None, interactive=False)
74
 
75
  process_btn.click(
76
+ fn=pipeline,
77
+ inputs=[img_input, question_input],
78
+ outputs=[context_output, answer_output]
79
  )
80
 
81
  if __name__ == "__main__":