import os
import json
import subprocess
from threading import Thread

import torch
import spaces
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer

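# Install FlashAttention at runtime; FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE tells
# its setup to skip compiling CUDA kernels at install time, so the install
# succeeds on build machines without a CUDA toolchain.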
subprocess.run('pip install flash-attn --no-build-isolation',
               env={**os.environ, 'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},  # merge with, don't wipe, the parent env
               shell=True)

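# Deployment-specific settings are read from environment variables, configured
# in the Space settings.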
MODEL_ID = os.environ.get("MODEL_ID")
CHAT_TEMPLATE = os.environ.get("CHAT_TEMPLATE")
MODEL_NAME = MODEL_ID.split("/")[-1]
CONTEXT_LENGTH = int(os.environ.get("CONTEXT_LENGTH"))
COLOR = os.environ.get("COLOR")
EMOJI = os.environ.get("EMOJI")
DESCRIPTION = os.environ.get("DESCRIPTION")


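# Streaming chat handler. @spaces.GPU() requests a ZeroGPU slot for the
# duration of each call; the response is generated in a background thread and
# yielded token by token.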
@spaces.GPU()
def predict(message, history, system_prompt, temperature, max_new_tokens, top_k, repetition_penalty, top_p):
    # Build the raw prompt string according to the configured chat template.
    if CHAT_TEMPLATE == "ChatML":
        stop_tokens = ["<|endoftext|>", "<|im_end|>"]
        instruction = '<|im_start|>system\n' + system_prompt + '\n<|im_end|>\n'
        for human, assistant in history:
            # Each assistant turn must be closed with <|im_end|>.
            instruction += '<|im_start|>user\n' + human + '\n<|im_end|>\n<|im_start|>assistant\n' + assistant + '\n<|im_end|>\n'
        instruction += '<|im_start|>user\n' + message + '\n<|im_end|>\n<|im_start|>assistant\n'
    elif CHAT_TEMPLATE == "Mistral Instruct":
        stop_tokens = ["</s>", "[INST]", "[INST] ", "<s>", "[/INST]", "[/INST] "]
        instruction = '<s>[INST] ' + system_prompt
        for human, assistant in history:
            instruction += human + ' [/INST] ' + assistant + '</s>[INST]'
        instruction += ' ' + message + ' [/INST]'
    elif CHAT_TEMPLATE == "Bielik":
        stop_tokens = ["</s>"]
        prompt_builder = ["<s>"]
        for human, assistant in history:
            if system_prompt:
                prompt_builder.append(f"[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n{human} [/INST] {assistant}</s>")
                system_prompt = None  # inject the system prompt only once
            else:
                prompt_builder.append(f"[INST] {human} [/INST] {assistant}</s>")
        if system_prompt:
            # No history yet: attach the system prompt to the current message.
            prompt_builder.append(f"[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n{message} [/INST]")
        else:
            prompt_builder.append(f"[INST] {message} [/INST]")
        instruction = ''.join(prompt_builder)
    else:
        raise Exception("Incorrect chat template, select 'ChatML', 'Mistral Instruct' or 'Bielik'")
    print(instruction)

    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    enc = tokenizer([instruction], return_tensors="pt", padding=True, truncation=True)
    input_ids, attention_mask = enc.input_ids, enc.attention_mask

    # Keep only the most recent CONTEXT_LENGTH tokens; truncate the attention
    # mask together with the input ids so the two stay aligned.
    if input_ids.shape[1] > CONTEXT_LENGTH:
        input_ids = input_ids[:, -CONTEXT_LENGTH:]
        attention_mask = attention_mask[:, -CONTEXT_LENGTH:]

    generate_kwargs = dict(
        {"input_ids": input_ids.to(device), "attention_mask": attention_mask.to(device)},
        streamer=streamer,
        do_sample=True,
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        top_p=top_p
    )
    # model.generate blocks, so it runs in a thread while we consume the streamer.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
    outputs = []
    for new_token in streamer:
        outputs.append(new_token)
        if new_token in stop_tokens:
            break
        yield "".join(outputs)


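# One-time setup: load the tokenizer and the 4-bit quantized model
# (bitsandbytes) with FlashAttention 2.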
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
tokenizer.pad_token = tokenizer.eos_token  # some checkpoints define no pad token
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    quantization_config=quantization_config,
    attn_implementation="flash_attention_2",
)


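# Assemble and launch the chat UI; sampling parameters live in a collapsible
# "⚙️ Parameters" accordion.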
gr.ChatInterface(
    predict,
    title=EMOJI + " " + MODEL_NAME,
    description=DESCRIPTION,
    examples=[
        ["Kim jesteś?"],  # "Who are you?"
        ["Ile to jest 9+2-1?"],  # "What is 9+2-1?"
        ["Napisz mi coś miłego."]  # "Write me something nice."
    ],
    additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False),
    additional_inputs=[
        # Default system prompt: "You are a helpful assistant named Bielik."
        gr.Textbox("Jesteś pomocnym asystentem o imieniu Bielik.", label="System prompt"),
        gr.Slider(0, 1, 0.6, label="Temperature"),
        gr.Slider(128, 4096, 1024, label="Max new tokens"),
        gr.Slider(1, 80, 40, label="Top K sampling"),
        gr.Slider(0, 2, 1.1, label="Repetition penalty"),
        gr.Slider(0, 1, 0.95, label="Top P sampling"),
    ],
    theme=gr.themes.Soft(primary_hue=COLOR),
).queue().launch()
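
# Example configuration for a local run (assumed values; the actual Space sets
# these in its settings, and the model id below is only an illustration):
#
#   export MODEL_ID="speakleash/Bielik-7B-Instruct-v0.1"
#   export CHAT_TEMPLATE="Bielik"          # or "ChatML" / "Mistral Instruct"
#   export CONTEXT_LENGTH="4096"
#   export EMOJI="🦅"
#   export COLOR="blue"                    # a Gradio theme hue name
#   export DESCRIPTION="Chat demo for the Bielik model."
#   python app.py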