OmidSakaki committed on
Commit
c55b6a8
·
verified ·
1 Parent(s): 77d9d02

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -8
app.py CHANGED
@@ -17,7 +17,7 @@ class OCRProcessor:
17
 
18
  class TextCorrector:
19
  def __init__(self):
20
- model_name = "persiannlp/mt5-small-parsinlu-arc-comqa-question"
21
  try:
22
  self.tokenizer = AutoTokenizer.from_pretrained(model_name)
23
  self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
@@ -29,20 +29,19 @@ class TextCorrector:
29
  return text
30
 
31
  try:
 
32
  inputs = self.tokenizer(
33
- "اصلاح متن: " + text,
34
  return_tensors="pt",
35
  max_length=512,
36
  truncation=True
37
  )
38
-
39
  outputs = self.model.generate(
40
  **inputs,
41
  max_length=512,
42
  num_beams=5,
43
  early_stopping=True
44
  )
45
-
46
  return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
47
  except Exception as e:
48
  print(f"خطا در تصحیح متن: {e}")
@@ -50,12 +49,8 @@ class TextCorrector:
50
 
51
  def full_processing(image: np.ndarray) -> Tuple[str, str]:
52
  try:
53
- # استخراج متن از تصویر
54
  ocr_text = OCRProcessor().extract_text(image)
55
-
56
- # تصحیح متن با مدل زبانی
57
  corrected_text = TextCorrector().correct(ocr_text)
58
-
59
  return ocr_text, corrected_text
60
  except Exception as e:
61
  error_msg = f"خطا: {str(e)}"
 
17
 
18
  class TextCorrector:
19
  def __init__(self):
20
+ model_name = "HooshvareLab/mt5-small-fa"
21
  try:
22
  self.tokenizer = AutoTokenizer.from_pretrained(model_name)
23
  self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
 
29
  return text
30
 
31
  try:
32
+ prompt = "تصحیح نگارشی متن: " + text
33
  inputs = self.tokenizer(
34
+ prompt,
35
  return_tensors="pt",
36
  max_length=512,
37
  truncation=True
38
  )
 
39
  outputs = self.model.generate(
40
  **inputs,
41
  max_length=512,
42
  num_beams=5,
43
  early_stopping=True
44
  )
 
45
  return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
46
  except Exception as e:
47
  print(f"خطا در تصحیح متن: {e}")
 
49
 
50
  def full_processing(image: np.ndarray) -> Tuple[str, str]:
51
  try:
 
52
  ocr_text = OCRProcessor().extract_text(image)
 
 
53
  corrected_text = TextCorrector().correct(ocr_text)
 
54
  return ocr_text, corrected_text
55
  except Exception as e:
56
  error_msg = f"خطا: {str(e)}"