vilarin committed
Commit 2fec857 · verified · 1 Parent(s): 9ad1f27

Update app.py

Files changed (1)
  1. app.py +5 -3
app.py CHANGED
@@ -17,7 +17,7 @@ quantization_config = BitsAndBytesConfig(load_in_8bit=True)
 
 model = AutoModelForCausalLM.from_pretrained(
     MODEL,
-    torch_dtype=torch.bfloat16,
+    torch_dtype=torch.float16,
     device_map="auto",
     quantization_config=quantization_config)
 tokenizer = AutoTokenizer.from_pretrained(MODEL)
@@ -61,7 +61,7 @@ def translate(
     print(f'Text is - {source_text}')
 
     prompt = Prompt_template(source_text, source_lang, target_lang)
-    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
+    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
 
     generate_kwargs = dict(
         input_ids=input_ids,
@@ -73,8 +73,10 @@ def translate(
     generate_ids = model.generate(**generate_kwargs)
 
     resp = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+
+    print(resp)
 
-    return resp
+    yield resp
 
 CSS = """
 h1 {
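
A note on what the first two hunks do. With load_in_8bit=True, bitsandbytes quantizes the linear layers, and torch_dtype only sets the dtype of the modules left unquantized; float16 matches the fp16 compute path bitsandbytes uses, whereas bfloat16 is unsupported on pre-Ampere GPUs. Separately, device_map="auto" may place the model on GPU while the tokenizer always returns CPU tensors, so the input ids must be moved to model.device before generate() or CUDA setups fail with a device-mismatch error. A minimal sketch of the resulting pattern (the model id, prompt, and max_new_tokens below are placeholders, not the values in app.py):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

MODEL = "facebook/opt-125m"  # placeholder id; app.py defines its own MODEL

# 8-bit quantization covers the linear layers; torch_dtype covers the rest.
quantization_config = BitsAndBytesConfig(load_in_8bit=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL,
    torch_dtype=torch.float16,
    device_map="auto",  # may place the model (or parts of it) on GPU
    quantization_config=quantization_config)
tokenizer = AutoTokenizer.from_pretrained(MODEL)

prompt = "Translate to French: Hello, world."  # stand-in for Prompt_template(...)
# The tokenizer returns CPU tensors; move them to wherever the model landed.
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
generate_ids = model.generate(input_ids=input_ids, max_new_tokens=64)
print(tokenizer.batch_decode(generate_ids, skip_special_tokens=True)[0])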
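And on the last hunk: replacing return resp with yield resp turns translate() into a generator function. Gradio consumes generator handlers by pushing each yielded value to the output component, so this keeps the current one-shot behavior while leaving room to stream partial translations later. A minimal sketch of that shape (the diff doesn't show this Space's UI code, so the gr.Interface wiring here is illustrative):

import gradio as gr

def translate(source_text, source_lang, target_lang):
    # Model call elided; a generator handler may yield any number of
    # intermediate values, each one replacing the output's contents.
    resp = f"[{source_lang} -> {target_lang}] {source_text}"  # stand-in result
    yield resp

demo = gr.Interface(
    fn=translate,  # Gradio accepts generator functions and streams their yields
    inputs=[
        gr.Textbox(label="Source text"),
        gr.Textbox(label="Source language"),
        gr.Textbox(label="Target language"),
    ],
    outputs=gr.Textbox(label="Translation"),
)

if __name__ == "__main__":
    demo.launch()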