TheOnlyHatem committed on
Commit d6347df · verified · 1 Parent(s): 006544e

Update app.py

Files changed (1)
  1. app.py +7 -6
app.py CHANGED
@@ -1,23 +1,24 @@
 import gradio as gr
-from transformers import AutoTokenizer, T5ForConditionalGeneration
+from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
 import torch
 
-model_name = "DenoKuso/t5-small-gec-fr"
+# Use the sdadas/byt5-text-correction model
+model_name = "sdadas/byt5-text-correction"
 
 tokenizer = AutoTokenizer.from_pretrained(model_name)
-model = T5ForConditionalGeneration.from_pretrained(model_name)
+model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
 model.to(device)
 
 def correction_grammaticale(texte):
-    # This model presumably uses the "gec: " prefix for correction.
-    input_text = "gec: " + texte
+    # For French, prepend the "<fr>" prefix to the text
+    input_text = "<fr> " + texte
     input_ids = tokenizer.encode(input_text, return_tensors="pt").to(device)
 
     outputs = model.generate(
         input_ids,
-        max_length=128,
+        max_length=512,
         num_beams=4,
         early_stopping=True
     )
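
Since the hunk stops at the generate() call, the following is a minimal sketch of what the rest of the updated app.py might look like. The decoding line, the gr.Interface inputs/outputs/labels, and the demo.launch() call are assumptions not shown in this commit; only the model name, the "<fr>" prefix, and the generation settings come from the diff above.

import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch

# Multilingual ByT5 correction model used by the commit
model_name = "sdadas/byt5-text-correction"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

def correction_grammaticale(texte):
    # For French, prepend the "<fr>" language prefix to the input text
    input_text = "<fr> " + texte
    input_ids = tokenizer.encode(input_text, return_tensors="pt").to(device)

    outputs = model.generate(
        input_ids,
        max_length=512,
        num_beams=4,
        early_stopping=True
    )
    # Assumed: decode the first beam, dropping special tokens
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

# Assumed Gradio wiring (not part of the shown hunk)
demo = gr.Interface(
    fn=correction_grammaticale,
    inputs=gr.Textbox(lines=4, label="Input text"),
    outputs=gr.Textbox(label="Corrected text"),
    title="French grammar correction",
)

if __name__ == "__main__":
    demo.launch()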