import gradio as gr
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TextIteratorStreamer,
    BitsAndBytesConfig,
)
import os
from threading import Thread
import spaces
import time

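# Hugging Face access token, used when downloading the model from the Hub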
token = os.environ["HF_TOKEN"]

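# Quantize to 4-bit with bitsandbytes so the 8B model fits in limited GPU memory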
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16
)

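# Load the fine-tuned model and its tokenizer from the Hub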
model = AutoModelForCausalLM.from_pretrained(
    "chheplo/sft_8b_2_llama3", quantization_config=quantization_config, token=token
)
tok = AutoTokenizer.from_pretrained("chheplo/sft_8b_2_llama3", token=token)
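# Stop generation at EOS or at Llama 3's <|eot_id|> end-of-turn marker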
terminators = [
    tok.eos_token_id,
    tok.convert_tokens_to_ids("<|eot_id|>")
]

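# Report the available device; input tensors are moved here before generation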
if torch.cuda.is_available():
    device = torch.device("cuda")
    print(f"Using GPU: {torch.cuda.get_device_name(device)}")
else:
    device = torch.device("cpu")
    print("Using CPU")

# model = model.to(device)
# Disabled: the 4-bit model is already placed on the GPU at load time, and
# calling .to() on a bitsandbytes-quantized model raises a dispatch error.


@spaces.GPU()
def chat(message, history, temperature, do_sample, max_tokens):
    # Alpaca-style prompt; dedented so the model sees clean "## Question:" /
    # "## Response:" headers rather than indented lines
    prompt_template = """You are a helpful Agricultural assistant for farmers. You are given the following input. Please complete the response briefly.
## Question:
{}

## Response:
{}"""
    start_time = time.time()
    # Chat-template path, disabled in favor of the raw prompt template above:
    # chat = []
    # for item in history:
    #     chat.append({"role": "user", "content": item[0]})
    #     if item[1] is not None:
    #         chat.append({"role": "assistant", "content": item[1]})
    # chat.append({"role": "user", "content": message})
    # messages = tok.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
    
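    # Fill the prompt template with the user message and tokenize it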
    model_inputs = tok(prompt_template.format(
        message,  # question
        "",       # response slot, left empty for the model to complete
    ), return_tensors="pt").to(device)
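    # Stream decoded text incrementally, skipping the prompt and special tokens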
    streamer = TextIteratorStreamer(
        tok, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=max_tokens,
        do_sample=do_sample,
        temperature=temperature,
        repetition_penalty=1.2,
        use_cache=False,  # note: generation is slower without the KV cache
        eos_token_id=terminators,
    )

    # Sampling with temperature 0 is undefined, so fall back to greedy decoding
    if temperature == 0:
        generate_kwargs["do_sample"] = False
    
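    # Run generation on a background thread so tokens can be yielded as they arrive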
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    partial_text = ""
    first_token_time = None
    for new_text in streamer:
        if first_token_time is None:
            first_token_time = time.time() - start_time
        partial_text += new_text
        yield partial_text

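    # Summarize latency: time to first token and decode throughput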
    total_time = time.time() - start_time
    tokens = len(tok.tokenize(partial_text))
    tokens_per_second = tokens / total_time if total_time > 0 else 0

    timing_info = f"\n\nTime taken to first token: {first_token_time:.2f} seconds\nTokens per second: {tokens_per_second:.2f}"
    yield partial_text + timing_info


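# Gradio chat UI exposing temperature, sampling, and max-token controls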
demo = gr.ChatInterface(
    fn=chat,
    examples=[["I'm a farmer from Odisha, how do I take care of whitefly in my cotton crop?"]],
    # multimodal=False,
    additional_inputs_accordion=gr.Accordion(
        label="⚙️ Parameters", open=False, render=False
    ),
    additional_inputs=[
        gr.Slider(
            minimum=0, maximum=1, step=0.1, value=0.5, label="Temperature", render=False
        ),
        gr.Checkbox(label="Sampling", value=False, render=False),
        gr.Slider(
            minimum=128,
            maximum=4096,
            step=1,
            value=512,
            label="Max new tokens",
            render=False,
        ),
    ],
    stop_btn="Stop Generation",
    title="Chat With LLMs",
    description="Now Running [KissanAI/llama3-8b-dhenu-0.1-sft-16bit](https://huggingface.co/KissanAI/llama3-8b-dhenu-0.1-sft-16bit) in 4bit",
)
demo.launch()