kevinpro committed
Commit d447070 · verified · 1 Parent(s): e50fa51

Update app.py

Files changed (1): app.py +67 -0
app.py CHANGED
@@ -10,6 +10,26 @@ target_languages = flores_codes  # simplified list
 
 # assume openai_client is defined, e.g.:
 
+device = "cpu" if platform.system() == "Darwin" else "cuda"
+MODEL_NAME = "ByteDance-Seed/Seed-X-PPO-7B"
+
+code_mapping = dict(sorted(code_mapping.items(), key=lambda item: item[0]))
+flores_codes = list(code_mapping.keys())
+target_languages = [language for language in flores_codes if language not in REMOVED_TARGET_LANGUAGES]
+
+def load_model():
+    model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME).to(device)
+    print(f"Model loaded on {device}")
+    return model
+
+
+model = load_model()
+
+
+# Loading the tokenizer once, because re-loading it takes about 1.5 seconds each time
+tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+
+
 
 @lru_cache(maxsize=100)
 def translate(text: str, src_lang: str, tgt_lang: str):
@@ -19,6 +39,53 @@ def translate(text: str, src_lang: str, tgt_lang: str):
         raise gr.Error("The target language is empty! Please choose it in the dropdown list.")
     return _translate(text, src_lang, tgt_lang)
 
+# Only assign GPU if cache not used
+@spaces.GPU
+def _translate(text: str, src_lang: str, tgt_lang: str):
+    src_code = code_mapping[src_lang]
+    tgt_code = code_mapping[tgt_lang]
+    tokenizer.src_lang = src_code
+    tokenizer.tgt_lang = tgt_code
+
+    # normalizing the punctuation first
+    text = punct_normalizer.normalize(text)
+
+    paragraphs = text.split("\n")
+    translated_paragraphs = []
+
+    for paragraph in paragraphs:
+        splitter = get_language_specific_sentence_splitter(src_code)
+        sentences = list(splitter(paragraph))
+        translated_sentences = []
+
+        for sentence in sentences:
+            input_tokens = (
+                tokenizer(sentence, return_tensors="pt")
+                .input_ids[0]
+                .cpu()
+                .numpy()
+                .tolist()
+            )
+            translated_chunk = model.generate(
+                input_ids=torch.tensor([input_tokens]).to(device),
+                forced_bos_token_id=tokenizer.convert_tokens_to_ids(tgt_code),
+                max_length=len(input_tokens) + 50,
+                num_return_sequences=1,
+                num_beams=5,
+                no_repeat_ngram_size=4,  # repetition blocking works better if this number is below num_beams
+                renormalize_logits=True,  # recompute token probabilities after banning the repetitions
+            )
+            translated_chunk = tokenizer.decode(
+                translated_chunk[0], skip_special_tokens=True
+            )
+            translated_sentences.append(translated_chunk)
+
+        translated_paragraph = " ".join(translated_sentences)
+        translated_paragraphs.append(translated_paragraph)
+
+    return "\n".join(translated_paragraphs)
+
+
 def _translate(text: str, src_lang: str, tgt_lang: str):
     prompt = f"Translate the following text from {src_lang} to {tgt_lang}. Direct output translation result without any explanation:\n\n{text}"
     key = os.getenv('key')
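Note that after this commit app.py contains two `def _translate` definitions; since the prompt-based one appears later in the file, it shadows the newly added `@spaces.GPU` version at import time. The second hunk also cuts off right after `key = os.getenv('key')`, and the file only says "assume openai_client is defined, e.g.:". A minimal sketch of how that API-backed path might continue, assuming an OpenAI-compatible client; the `base_url` and model id below are placeholders, not taken from the commit:

import os
from openai import OpenAI

# Hypothetical client setup: the commit only says "assume openai_client is
# defined"; the base_url and model id are placeholders.
openai_client = OpenAI(api_key=os.getenv("key"), base_url="https://example.com/v1")

def _translate(text: str, src_lang: str, tgt_lang: str) -> str:
    prompt = (
        f"Translate the following text from {src_lang} to {tgt_lang}. "
        f"Direct output translation result without any explanation:\n\n{text}"
    )
    response = openai_client.chat.completions.create(
        model="placeholder-model",  # placeholder, not from the commit
        messages=[{"role": "user", "content": prompt}],
        temperature=0,
    )
    return response.choices[0].message.content.strip()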
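The added comment "Only assign GPU if cache not used" relies on decorator layering: the public `translate` is wrapped in `functools.lru_cache`, so a repeated (text, src_lang, tgt_lang) triple is served from the cache and never reaches the `@spaces.GPU`-decorated `_translate`, meaning ZeroGPU hardware is requested only on cache misses. A self-contained sketch of that pattern, with a stand-in decorator instead of `spaces.GPU`:

from functools import lru_cache

gpu_calls = 0  # illustrative counter, not part of the commit

def fake_gpu(fn):
    # stand-in for @spaces.GPU: counts how often the expensive path actually runs
    def wrapper(*args, **kwargs):
        global gpu_calls
        gpu_calls += 1
        return fn(*args, **kwargs)
    return wrapper

@fake_gpu
def _expensive(text: str) -> str:
    return text.upper()  # placeholder for the real model.generate call

@lru_cache(maxsize=100)
def cached_translate(text: str) -> str:
    return _expensive(text)

cached_translate("hello")
cached_translate("hello")  # second call is served from the cache
assert gpu_calls == 1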
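The comment about loading the tokenizer once (re-loading takes about 1.5 seconds each time) is the usual load-at-module-scope pattern. A quick way to check the reload cost on your own hardware; timings will vary:

import time
from transformers import AutoTokenizer

MODEL_NAME = "ByteDance-Seed/Seed-X-PPO-7B"  # from the commit

start = time.perf_counter()
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)  # cold load from disk/cache
print(f"tokenizer load: {time.perf_counter() - start:.2f}s")

start = time.perf_counter()
tokenizer("a warm call")  # tokenizing afterwards is effectively free
print(f"tokenize: {time.perf_counter() - start:.4f}s")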