Moleys commited on
Commit
04aeedf
·
verified ·
1 Parent(s): 172828d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -6
app.py CHANGED
@@ -1,13 +1,36 @@
 
 
1
  import gradio as gr
2
- from transformers import pipeline
3
 
4
- # Load translation pipeline
5
- pipe = pipeline("translation", model="chi-vi/hirashiba-mt-tiny-zh-vi")
 
 
 
 
6
 
7
  def translate_text(input_text):
8
  lines = input_text.split('\n') # Tách từng dòng
9
- translated_lines = [pipe(line, max_length=512)[0]['translation_text'] if line.strip() else '' for line in lines]
10
- return '\n'.join(translated_lines) # Gộp lại với xuống dòng
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
  if __name__ == '__main__':
13
  with gr.Blocks() as app:
@@ -18,7 +41,7 @@ if __name__ == '__main__':
18
  input_text = gr.Textbox(label='Input Chinese Text', lines=5, placeholder='Enter Chinese text here...')
19
  translate_button = gr.Button('Translate')
20
  output_text = gr.Textbox(label='Output Vietnamese Text', lines=5, interactive=False)
21
-
22
  translate_button.click(
23
  fn=translate_text,
24
  inputs=input_text,
 
1
+ import torch
2
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
3
  import gradio as gr
 
4
 
5
+ # Load model và tokenizer
6
+ model_name = "chi-vi/hirashiba-mt-tiny-zh-vi"
7
+ device = "cuda" if torch.cuda.is_available() else "cpu"
8
+
9
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
10
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
11
 
12
  def translate_text(input_text):
13
  lines = input_text.split('\n') # Tách từng dòng
14
+ translated_lines = []
15
+
16
+ for line in lines:
17
+ raw_text = line.strip()
18
+ if not raw_text:
19
+ translated_lines.append('') # Giữ dòng trống
20
+ continue
21
+
22
+ # Tokenize input
23
+ inputs = tokenizer(raw_text, return_tensors="pt", padding=True, truncation=True).to(device)
24
+
25
+ # Dịch với mô hình (không cần tính gradient)
26
+ with torch.no_grad():
27
+ output_tokens = model.generate(**inputs, max_length=512)
28
+
29
+ # Giải mã kết quả
30
+ translated_text = tokenizer.decode(output_tokens[0], skip_special_tokens=True)
31
+ translated_lines.append(translated_text)
32
+
33
+ return '\n'.join(translated_lines)
34
 
35
  if __name__ == '__main__':
36
  with gr.Blocks() as app:
 
41
  input_text = gr.Textbox(label='Input Chinese Text', lines=5, placeholder='Enter Chinese text here...')
42
  translate_button = gr.Button('Translate')
43
  output_text = gr.Textbox(label='Output Vietnamese Text', lines=5, interactive=False)
44
+
45
  translate_button.click(
46
  fn=translate_text,
47
  inputs=input_text,