SuperSl6 committed on
Commit
05cb4d5
·
verified ·
1 Parent(s): 416370c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -5
app.py CHANGED
@@ -1,6 +1,7 @@
1
  from transformers import pipeline, AutoTokenizer
2
  import gradio as gr
3
  import re
 
4
 
5
  # Load tokenizer with use_fast=False
6
  tokenizer = AutoTokenizer.from_pretrained("SuperSl6/Arabic-Text-Correction", use_fast=False)
@@ -10,19 +11,49 @@ model = pipeline(
10
  tokenizer=tokenizer
11
  )
12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  def correct_text(input_text):
14
  result = model(
15
  input_text,
16
  max_length=50,
17
  no_repeat_ngram_size=2,
18
  repetition_penalty=1.5,
19
- num_return_sequences=1
 
 
 
20
  )[0]['generated_text']
21
 
22
- # Extract the first occurrence of corrected Arabic word(s)
23
- matches = re.findall(r'[\u0600-\u06FF]+', result)
24
- corrected_text = matches[0] if matches else result
25
-
26
  return corrected_text
27
 
28
  # Gradio Interface
 
1
  from transformers import pipeline, AutoTokenizer
2
  import gradio as gr
3
  import re
4
+ import difflib
5
 
6
  # Load tokenizer with use_fast=False
7
  tokenizer = AutoTokenizer.from_pretrained("SuperSl6/Arabic-Text-Correction", use_fast=False)
 
11
  tokenizer=tokenizer
12
  )
13
 
14
def extract_corrected_version(original, generated):
    """Pick the corrected Arabic text out of the model's generated output.

    The generated text is split on the literal separator ``' . '`` and the
    piece most similar to ``original`` (by ``difflib.SequenceMatcher`` ratio)
    is selected. Runs of characters in the Unicode range U+0600-U+06FF are
    then extracted from that piece and joined with single spaces.

    Args:
        original: The text the user submitted.
        generated: The raw text produced by the model.

    Returns:
        The space-joined Arabic words when that string occurs verbatim in the
        selected piece; otherwise ``original`` unchanged. ``original`` is also
        returned when the selected piece contains no characters in the
        U+0600-U+06FF range.
    """
    # Split generated text into sentences.
    sentences = generated.split(' . ')

    # Find the sentence most similar to the original.
    best_match = max(
        sentences,
        key=lambda s: difflib.SequenceMatcher(None, original, s).ratio(),
    )

    # Extract the Arabic-script words from that sentence.
    corrected_words = re.findall(r'[\u0600-\u06FF]+', best_match)

    # Nothing usable was generated: fall back to the original input.
    if not corrected_words:
        return original

    # Only accept the joined words when they appear contiguously in the
    # selected sentence (i.e. nothing but single spaces separated them).
    # The earlier prefix-search loop always returned this same string on its
    # first iteration, so it has been folded into this single check.
    corrected_text = ' '.join(corrected_words)
    if corrected_text in best_match:
        return corrected_text

    # The words were separated by other characters: keep the original input.
    return original
42
+
43
  def correct_text(input_text):
44
  result = model(
45
  input_text,
46
  max_length=50,
47
  no_repeat_ngram_size=2,
48
  repetition_penalty=1.5,
49
+ num_return_sequences=1,
50
+ temperature=0.7,
51
+ top_p=0.9,
52
+ do_sample=True
53
  )[0]['generated_text']
54
 
55
+ # Extract the corrected version
56
+ corrected_text = extract_corrected_version(input_text, result)
 
 
57
  return corrected_text
58
 
59
  # Gradio Interface