mfajcik committed on
Commit
cef3c43
·
verified ·
1 Parent(s): 6a6778e

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +3 -3
README.md CHANGED
@@ -143,9 +143,10 @@ input_args = {
143
  import torch
144
  import transformers
145
 
146
- model_name = "BUT-FIT/csmpt-6.7B-RAGsum"
147
  tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
148
  config = transformers.AutoConfig.from_pretrained(model_name, trust_remote_code=True)
 
149
 
150
  formatted_input = f"""Shrň následující výsledky pro dotaz "{input_args['query']}".
151
  |Výsledky|: {input_args['input']}
@@ -157,6 +158,7 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
157
  model = transformers.AutoModelForCausalLM.from_pretrained(
158
  model_name,
159
  config=config,
 
160
  trust_remote_code=True
161
  ).cuda()
162
  with torch.autocast('cuda', dtype=torch.bfloat16):
@@ -173,8 +175,6 @@ with torch.autocast('cuda', dtype=torch.bfloat16):
173
  input_length = inputs['input_ids'].shape[1]
174
  generated_text = tokenizer.decode(outputs[0][input_length:], skip_special_tokens=True)
175
 
176
- print("RAG Summary:", generated_text)
177
-
178
  ```
179
 
180
  Example of generated summary
 
143
  import torch
144
  import transformers
145
 
146
+ model_name = "BUT-FIT/csmpt-7B-RAGsum"
147
  tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
148
  config = transformers.AutoConfig.from_pretrained(model_name, trust_remote_code=True)
149
+ config.init_device = 'cuda:0' # For fast initialization directly on GPU!
150
 
151
  formatted_input = f"""Shrň následující výsledky pro dotaz "{input_args['query']}".
152
  |Výsledky|: {input_args['input']}
 
158
  model = transformers.AutoModelForCausalLM.from_pretrained(
159
  model_name,
160
  config=config,
161
+ torch_dtype=torch.bfloat16, # Load model weights in bfloat16
162
  trust_remote_code=True
163
  ).cuda()
164
  with torch.autocast('cuda', dtype=torch.bfloat16):
 
175
  input_length = inputs['input_ids'].shape[1]
176
  generated_text = tokenizer.decode(outputs[0][input_length:], skip_special_tokens=True)
177
 
 
 
178
  ```
179
 
180
  Example of generated summary