import gradio as gr
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TextIteratorStreamer,
    BitsAndBytesConfig,
)
import os
from threading import Thread
import spaces
import time
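# A Hugging Face access token is expected in the HF_TOKEN environment variable (a Space secret).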
token = os.environ["HF_TOKEN"]
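# Quantize the model to 4-bit with bitsandbytes so the 8B checkpoint fits in GPU memory.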
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16
)
model = AutoModelForCausalLM.from_pretrained(
    "chheplo/sft_8b_2_llama3", quantization_config=quantization_config, token=token
)
tok = AutoTokenizer.from_pretrained("chheplo/sft_8b_2_llama3", token=token)
terminators = [
    tok.eos_token_id,
    tok.convert_tokens_to_ids("<|eot_id|>"),
]
if torch.cuda.is_available():
    device = torch.device("cuda")
    print(f"Using GPU: {torch.cuda.get_device_name(device)}")
else:
    device = torch.device("cpu")
    print("Using CPU")
# model = model.to(device)  # not needed: the 4-bit model is already placed by
# bitsandbytes/accelerate, and moving it manually raises dispatch errors
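# spaces.GPU() requests a GPU for the duration of each call on ZeroGPU Spaces.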
@spaces.GPU()
def chat(message, history, temperature, do_sample, max_tokens):
    prompt_template = """
You are a helpful Agricultural assistant for farmers. You are given the following input. Please complete the response briefly.
## Question:
{}
## Response:
{}"""
    start_time = time.time()
    chat = []
    # for item in history:
    #     chat.append({"role": "user", "content": item[0]})
    #     if item[1] is not None:
    #         chat.append({"role": "assistant", "content": item[1]})
    # chat.append({"role": "user", "content": message})
    # messages = tok.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
    # Conversation history is currently unused; only the latest message is inserted
    # into the instruction-style prompt instead of the chat template above.
    model_inputs = tok(prompt_template.format(
        message,  # input
        ""  # response left empty for the model to complete
    ), return_tensors="pt").to(device)
    streamer = TextIteratorStreamer(
        tok, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=max_tokens,
        do_sample=do_sample,  # honour the "Sampling" checkbox from the UI
        temperature=temperature,
        repetition_penalty=1.2,
        use_cache=False,
        eos_token_id=terminators,
    )
    if temperature == 0:
        # Temperature 0 is treated as greedy decoding.
        generate_kwargs["do_sample"] = False
    # Run generation in a background thread and stream tokens as they arrive.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
    partial_text = ""
    first_token_time = None
    for new_text in streamer:
        if not first_token_time:
            first_token_time = time.time() - start_time
        partial_text += new_text
        yield partial_text
    # Append simple latency stats (time to first token and throughput) to the reply.
    total_time = time.time() - start_time
    tokens = len(tok.tokenize(partial_text))
    tokens_per_second = tokens / total_time if total_time > 0 else 0
    timing_info = f"\n\nTime taken to first token: {first_token_time:.2f} seconds\nTokens per second: {tokens_per_second:.2f}"
    yield partial_text + timing_info
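
# Build the Gradio chat UI; the extra controls below are passed to chat() as additional inputs.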
demo = gr.ChatInterface(
    fn=chat,
    examples=[["I'm a farmer from Odisha, how do I take care of whitefly in my cotton crop?"]],
    # multimodal=False,
    additional_inputs_accordion=gr.Accordion(
        label="⚙️ Parameters", open=False, render=False
    ),
    additional_inputs=[
        gr.Slider(
            minimum=0, maximum=1, step=0.1, value=0.5, label="Temperature", render=False
        ),
        gr.Checkbox(label="Sampling", value=False),
        gr.Slider(
            minimum=128,
            maximum=4096,
            step=1,
            value=512,
            label="Max new tokens",
            render=False,
        ),
    ],
    stop_btn="Stop Generation",
    title="Chat With LLMs",
    description="Now Running [KissanAI/llama3-8b-dhenu-0.1-sft-16bit](https://huggingface.co/KissanAI/llama3-8b-dhenu-0.1-sft-16bit) in 4bit",
)
demo.launch()
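# When run locally, Gradio serves on port 7860 by default; on Spaces the app is launched automatically.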