Didier committed on
Commit
5df3581
·
1 Parent(s): 9f5045e

Using BitsAndBytesConfig

Browse files
Files changed (1) hide show
  1. app.py +16 -1
app.py CHANGED
@@ -9,18 +9,31 @@ Date: 2024-09-07
9
  import spaces
10
  import torch
11
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
 
12
  import gradio as gr
13
 
14
  #
15
  # Load the "small" MADLAD400 model
16
  #
17
  model_name = "google/madlad400-10b-mt"
 
 
 
 
 
 
 
 
 
 
 
 
18
  tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
19
  model = AutoModelForSeq2SeqLM.from_pretrained(
20
  model_name,
21
  device_map="auto",
22
  torch_dtype=torch.float16,
23
- load_in_8bit=True)
24
  model = torch.compile(model)
25
 
26
  #
@@ -37,6 +50,8 @@ def translate_text(
37
  Input text will be split into chunk that will be translated sequentially.
38
  We will have up to sents_per_chunk sentences in a given chunk.
39
  """
 
 
40
  input_text = f"<2{tgt_lang}> {text}"
41
  input_ids = tokenizer(
42
  input_text, return_tensors="pt",
 
9
  import spaces
10
  import torch
11
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
12
+ from transformers import BitsAndBytesConfig
13
  import gradio as gr
14
 
15
  #
16
  # Load the "small" MADLAD400 model
17
  #
18
  model_name = "google/madlad400-10b-mt"
19
+
20
+ quantization_config = BitsAndBytesConfig(
21
+ load_in_4bit=True,
22
+ bnb_4bit_compute_dtype=torch.float16,
23
+ bnb_4bit_use_double_quant=True,
24
+ bnb_4bit_quant_type="nf4"
25
+ )
26
+ #quantization_config = BitsAndBytesConfig(
27
+ # load_in_8bit=True,
28
+ # llm_int8_threshold=200.0 # https://discuss.huggingface.co/t/correct-usage-of-bitsandbytesconfig/33809/5
29
+ #)
30
+
31
  tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
32
  model = AutoModelForSeq2SeqLM.from_pretrained(
33
  model_name,
34
  device_map="auto",
35
  torch_dtype=torch.float16,
36
+ quantization_config=quantization_config)
37
  model = torch.compile(model)
38
 
39
  #
 
50
  Input text will be split into chunk that will be translated sequentially.
51
  We will have up to sents_per_chunk sentences in a given chunk.
52
  """
53
+ if not tgt_lang:
54
+ tgt_lang = "en"
55
  input_text = f"<2{tgt_lang}> {text}"
56
  input_ids = tokenizer(
57
  input_text, return_tensors="pt",