TiberiuCristianLeon committed on
Commit
8ba2e89
·
verified ·
1 Parent(s): 2a41ea0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -30
app.py CHANGED
@@ -21,36 +21,38 @@ def download_argos_model(from_code, to_code):
21
  argostranslate.package.install_from_path(package_to_install.download())
22
 
23
def wingpt(model_name, sl, tl, input_text):
    """Translate ``input_text`` from ``sl`` to ``tl`` with a causal chat LM.

    Args:
        model_name: Hugging Face model id to load (e.g. a WiNGPT checkpoint).
        sl: Source language name/code, interpolated into the system prompt.
        tl: Target language name/code.
        input_text: The text to translate.

    Returns:
        The decoded translation string (special tokens stripped).
    """
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype="auto",
        device_map="auto",
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    messages = [
        {"role": "system", "content": f"Translate this from {sl} to {tl} language"},
        {"role": "user", "content": input_text},
    ]

    # add_generation_prompt=True appends the assistant turn header so the
    # model produces a reply rather than continuing the user's message.
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

    # Deterministic (greedy) decoding: do_sample=False is the supported way
    # to disable sampling; temperature=0 is invalid with sampling enabled
    # and merely warns when sampling is off.
    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=512,
        do_sample=False,
    )

    # Drop the prompt tokens from each sequence, keeping only new tokens.
    generated_ids = [
        output_ids[len(input_ids):]
        for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]

    return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
 
 
54
 
55
  # App layout
56
  st.header("Text Machine Translation")
 
21
  argostranslate.package.install_from_path(package_to_install.download())
22
 
23
def wingpt(model_name, sl, tl, input_text):
    """Translate ``input_text`` from ``sl`` to ``tl`` with a causal chat LM.

    Args:
        model_name: Hugging Face model id to load (e.g. a WiNGPT checkpoint).
        sl: Source language name/code, interpolated into the system prompt.
        tl: Target language name/code.
        input_text: The text to translate.

    Returns:
        The decoded translation with any echoed source text removed and
        surrounding whitespace stripped.
    """
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype="auto",
        device_map="auto",
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    messages = [
        {"role": "system", "content": f"Translate this from {sl} to {tl} language"},
        {"role": "user", "content": input_text},
    ]

    # add_generation_prompt=True appends the assistant turn header so the
    # model produces a reply rather than continuing the user's message.
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

    # Deterministic (greedy) decoding: do_sample=False is the supported way
    # to disable sampling; temperature=0 is invalid with sampling enabled
    # and merely warns when sampling is off. (Debug print removed.)
    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=512,
        do_sample=False,
    )

    # Drop the prompt tokens from each sequence, keeping only new tokens.
    generated_ids = [
        output_ids[len(input_ids):]
        for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]
    result = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

    # Defensive cleanup: some models echo the source text in the output.
    return result.replace(input_text, '').strip()
56
 
57
  # App layout
58
  st.header("Text Machine Translation")