---
base_model: IlyaGusev/saiga_llama3_8b
model_type: llama
pipeline_tag: text-generation
quantized_by: Compressa
language:
  - ru
license: other
license_name: llama3
license_link: https://llama.meta.com/llama3/license
tags:
  - saiga
  - llama3
  - omniquant
  - gptq
  - triton
---


# Saiga – Llama 3 8B – OmniQuant

Based on [Saiga Llama 3 8B](https://huggingface.co/IlyaGusev/saiga_llama3_8b).

Quantized with [OmniQuant](https://github.com/OpenGVLab/OmniQuant).
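
The quantization settings are stored in the `quantization_config` section of the checkpoint's `config.json`; the loader in the Examples section below reads `bits` and `group_size` from it. A minimal sketch to inspect them before loading (the repo id is the one used in the Examples section, and the field names are assumed from the loading code, not from separate documentation):

```python
from transformers import AutoConfig

# Assumption: the checkpoint exposes `bits` and `group_size` in
# `quantization_config`, as expected by `load_model` below.
config = AutoConfig.from_pretrained(
    "compressa-ai/Saiga-Llama-3-8B-OmniQuant", trust_remote_code=True
)
print(config.quantization_config["bits"], config.quantization_config["group_size"])
```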


## Evaluation

### PPL (↓)

|               | wiki  |
| ------------- | ----- |
| FP            | 7.862 |
| **Quantized** | 8.615 |
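
The exact evaluation setup behind these numbers (dataset variant, context length, stride) is not stated on this card. Below is a minimal non-overlapping sliding-window sketch in the usual Hugging Face style, assuming WikiText-2 (raw) test data and 2048-token windows, with `model` and `tokenizer` loaded as in the Examples section:

```python
import torch
from datasets import load_dataset

# Assumptions: WikiText-2 raw test split, 2048-token non-overlapping windows.
test = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
encodings = tokenizer("\n\n".join(test["text"]), return_tensors="pt")

window = 2048
nlls, n_tokens = [], 0

for begin in range(0, encodings.input_ids.size(1), window):
    end = min(begin + window, encodings.input_ids.size(1))
    if end - begin < 2:
        break

    input_ids = encodings.input_ids[:, begin:end].cuda()

    with torch.no_grad():
        # With labels == inputs the model returns the mean NLL over the
        # window's predicted tokens; re-weight it by the token count.
        loss = model(input_ids, labels=input_ids).loss

    nlls.append(loss.float() * (end - begin - 1))
    n_tokens += end - begin - 1

ppl = torch.exp(torch.stack(nlls).sum() / n_tokens)
print(f"wiki PPL: {ppl.item():.3f}")
```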


### Accuracy on English Benchmarks, % (↑)

|               | piqa | arc_easy | arc_challenge | boolq | hellaswag | winogrande | mmlu_humanities | mmlu_social_sciences | mmlu_stem | mmlu_other |
| ------------- | ---- | -------- | ------------- | ----- | --------- | ---------- | --------------- | -------------------- | --------- | ---------- |
| FP            | 78.5 | 82.2     | 50.4          | 82.7  | 58.1      | 72.4       | 65.5            | 72.6                 | 53.8      | 68.4       |
| **Quantized** | 78.5 | 80.8     | 47.6          | 81.7  | 56.9      | 71.2       | 62.3            | 68.9                 | 49.7      | 63.3       |


### Accuracy on Russian Benchmarks, % (↑)

|               | danetqa | terra | rwsd | muserc | rucos | lidirus | parus | rcb  | russe | rucola |
| ------------- | ------- | ----- | ---- | ------ | ----- | ------- | ----- | ---- | ----- | ------ |
| FP            | 74.9    | 52.1  | 51.5 | 55.9   | 58.1  | 59.5    | 69.0  | 34.1 | 38.8  | 67.5   |
| **Quantized** | 65.4    | 50.5  | 49.5 | 60.7   | 53.7  | 50.9    | 71.0  | 33.6 | 40.8  | 67.5   |


### Summary

|               | Avg acc diff on Eng, % (↑) | Avg acc diff on Rus, % (↑) | Occupied disk space, % (↓) |
| ------------- | -------------------------- | -------------------------- | -------------------------- |
| FP            | 0                          | 0                          | 100                        |
| **Quantized** | -2.4                       | -1.8                       | 35.7                       |
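
As a sanity check, the averaged accuracy differences in this summary can be recomputed directly from the two benchmark tables above:

```python
fp_eng = [78.5, 82.2, 50.4, 82.7, 58.1, 72.4, 65.5, 72.6, 53.8, 68.4]
q_eng  = [78.5, 80.8, 47.6, 81.7, 56.9, 71.2, 62.3, 68.9, 49.7, 63.3]
fp_rus = [74.9, 52.1, 51.5, 55.9, 58.1, 59.5, 69.0, 34.1, 38.8, 67.5]
q_rus  = [65.4, 50.5, 49.5, 60.7, 53.7, 50.9, 71.0, 33.6, 40.8, 67.5]

avg_diff = lambda q, fp: sum(b - a for a, b in zip(fp, q)) / len(fp)

print(round(avg_diff(q_eng, fp_eng), 1))  # -2.4
print(round(avg_diff(q_rus, fp_rus), 1))  # -1.8
```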


## Examples

### Imports and Model Loading

<details>
  <summary>Expand</summary>
  
  ```python
  import gc

  import auto_gptq.nn_modules.qlinear.qlinear_cuda as qlinear_cuda
  import auto_gptq.nn_modules.qlinear.qlinear_triton as qlinear_triton
  import torch

  from accelerate import (
      init_empty_weights,
      infer_auto_device_map,
      load_checkpoint_in_model,
  )
  from tqdm import tqdm
  from transformers import (
      AutoConfig,
      AutoModelForCausalLM,
      AutoTokenizer,
      pipeline,
  )


  def get_named_linears(model):
      return {
          name: module for name, module in model.named_modules()
          if isinstance(module, torch.nn.Linear)
      }


  def set_module(model, name, module):
      parent = model
      levels = name.split('.')

      for i in range(len(levels) - 1):
          cur_name = levels[i]

          if cur_name.isdigit():
              parent = parent[int(cur_name)]
          else:
              parent = getattr(parent, cur_name)

      setattr(parent, levels[-1], module)


  def load_model(model_path):
      # Based on: https://github.com/OpenGVLab/OmniQuant/blob/main/runing_quantized_mixtral_7bx8.ipynb

      config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)

      if not hasattr(config, 'quantization_config'):
          raise AttributeError(
              f'No quantization info found in model config "{model_path}"'
              f' (`quantization_config` section is missing).'
          )

      wbits = config.quantization_config['bits']
      group_size = config.quantization_config['group_size']

      # We are going to init an ordinary model and then manually replace all Linears with QuantLinears
      del config.quantization_config

      with init_empty_weights():
          model = AutoModelForCausalLM.from_config(config=config, torch_dtype=torch.float16, trust_remote_code=True)

      layers = model.model.layers

      for i in tqdm(range(len(layers))):
          layer = layers[i]
          named_linears = get_named_linears(layer)

          for name, module in named_linears.items():
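              # Replace each full-precision Linear with a packed QuantLinear of the
              # same shape; 2-/4-bit layers use auto-gptq's Triton kernel, 3-bit the CUDA kernel.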
              params = (
                  wbits, group_size,
                  module.in_features, module.out_features,
                  module.bias is not None
              )

              if wbits in [2, 4]:
                  q_linear = qlinear_triton.QuantLinear(*params)
              elif wbits == 3:
                  q_linear = qlinear_cuda.QuantLinear(*params)
              else:
                  raise NotImplementedError("Only 2, 3 and 4 bits are supported.")

              q_linear.to(next(layer.parameters()).device)
              set_module(layer, name, q_linear)

      torch.cuda.empty_cache()
      gc.collect()

      model.tie_weights()
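      # Infer a device map over the rebuilt quantized modules so the checkpoint
      # can be loaded directly onto the right devices.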
      device_map = infer_auto_device_map(model)

      print("Loading pre-computed quantized weights...")

      load_checkpoint_in_model(
          model, checkpoint=model_path,
          device_map=device_map, offload_state_dict=True,
      )

      print("Model loaded successfully!")

      return model
  ```
</details>


### Inference

```python
model_path = "compressa-ai/Saiga-Llama-3-8B-OmniQuant"

model = load_model(model_path).cuda()
tokenizer = AutoTokenizer.from_pretrained(
    model_path, use_fast=False, trust_remote_code=True
)

system_message = "Ты — дружелюбный чат-бот, который всегда отвечает как пират."
user_message = "Куда мы направляемся, капитан?"
messages = [
    {"role": "system", "content": system_message},
    {"role": "user", "content": user_message},
]
prompt = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)

inputs = tokenizer(prompt, return_tensors="pt")
inputs = {k: v.cuda() for k, v in inputs.items()}

outputs = model.generate(
    **inputs, max_new_tokens=512,
    do_sample=True, temperature=0.7, top_p=0.95,
)

response = tokenizer.decode(outputs[0])
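# The chat template rendered its special tokens as plain text, so the prompt
# can be stripped from the decoded output as a string prefix.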
continuation = response.removeprefix(prompt).removesuffix(tokenizer.eos_token)

print(f'Prompt:\n{prompt}')
print(f'Continuation:\n{continuation}\n')
```


### Inference Using Pipeline

```python
pipe = pipeline(
    "text-generation",
    model=model, tokenizer=tokenizer,
    max_new_tokens=512, do_sample=True,
    temperature=0.7, top_p=0.95,
    device=0,
)

prompt = pipe.tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)

outputs = pipe(prompt)

response = outputs[0]["generated_text"]
continuation = response.removeprefix(prompt)

print(f'Prompt:\n{prompt}')
print(f'Continuation:\n{continuation}\n')
```