TheOnlyHatem commited on
Commit
405c6a7
·
verified ·
1 Parent(s): d6347df

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -10
app.py CHANGED
@@ -1,28 +1,30 @@
1
  import gradio as gr
2
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
  import torch
 
4
 
5
- # Utilisation du modèle sdadas/byt5-text-correction
6
- model_name = "sdadas/byt5-text-correction"
7
 
8
- tokenizer = AutoTokenizer.from_pretrained(model_name)
9
- model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
 
10
 
11
  device = "cuda" if torch.cuda.is_available() else "cpu"
12
  model.to(device)
13
 
14
  def correction_grammaticale(texte):
15
- # Pour le français, on ajoute le préfixe "<fr>" devant le texte
16
- input_text = "<fr> " + texte
17
- input_ids = tokenizer.encode(input_text, return_tensors="pt").to(device)
18
 
 
19
  outputs = model.generate(
20
- input_ids,
21
  max_length=512,
22
  num_beams=4,
23
  early_stopping=True
24
  )
25
 
 
26
  correction = tokenizer.decode(outputs[0], skip_special_tokens=True)
27
  return correction
28
 
@@ -30,7 +32,7 @@ demo = gr.Interface(
30
  fn=correction_grammaticale,
31
  inputs=gr.Textbox(label="Texte à corriger"),
32
  outputs=gr.Textbox(label="Texte corrigé"),
33
- title="Correcteur de Texte Français"
34
  )
35
 
36
  if __name__ == "__main__":
 
1
  import gradio as gr
 
2
  import torch
3
+ from transformers import MBartForConditionalGeneration, MBartTokenizer
4
 
5
+ # Remplace par ton repo exact si besoin :
6
+ model_name = "alice/mini/mBART_french_correction"
7
 
8
+ # Chargement du tokenizer et du modèle
9
+ tokenizer = MBartTokenizer.from_pretrained(model_name)
10
+ model = MBartForConditionalGeneration.from_pretrained(model_name)
11
 
12
  device = "cuda" if torch.cuda.is_available() else "cpu"
13
  model.to(device)
14
 
15
  def correction_grammaticale(texte):
16
+ # Tokenisation
17
+ inputs = tokenizer(texte, return_tensors="pt", max_length=512, truncation=True).to(device)
 
18
 
19
+ # Génération
20
  outputs = model.generate(
21
+ **inputs,
22
  max_length=512,
23
  num_beams=4,
24
  early_stopping=True
25
  )
26
 
27
+ # Décodage
28
  correction = tokenizer.decode(outputs[0], skip_special_tokens=True)
29
  return correction
30
 
 
32
  fn=correction_grammaticale,
33
  inputs=gr.Textbox(label="Texte à corriger"),
34
  outputs=gr.Textbox(label="Texte corrigé"),
35
+ title="Correcteur MBART Français"
36
  )
37
 
38
  if __name__ == "__main__":