File size: 3,262 Bytes
b376f12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
"""Template Demo for IBM Granite Hugging Face spaces."""

from collections.abc import Iterator
from datetime import datetime
from pathlib import Path
from threading import Thread

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

from themes.carbon import carbon_theme

_now = datetime.today()  # noqa: DTZ002
# "Month D, YYYY" without a zero-padded day. strftime's "%-d" does this only on glibc and raises
# ValueError on Windows, so the day and year are taken from the integer fields instead.
today_date = f"{_now:%B} {_now.day}, {_now.year}"

MODEL_ID = "ibm-granite/granite-3.1-8b-instruct"
# System prompt sent at the start of every conversation; embeds today's date.
SYS_PROMPT = f"""Knowledge Cutoff Date: April 2024.
Today's Date: {today_date}.
You are Granite, developed by IBM. You are a helpful AI assistant"""
TITLE = "IBM Granite 3.1 8b Instruct"
# Implicit concatenation (rather than a backslash continuation inside the literal) so the next
# line's source indentation does not leak into the displayed text.
DESCRIPTION = (
    "Try one of the sample prompts below or write your own. Remember, just like developers, "
    "AI models can make mistakes."
)
# Prompt length limit (in tokens) and sampling settings for text generation.
MAX_INPUT_TOKEN_LENGTH = 4096
MAX_NEW_TOKENS = 1024
TEMPERATURE = 0.7
TOP_P = 0.85
TOP_K = 50
REPETITION_PENALTY = 1.05

if not torch.cuda.is_available():
    # DESCRIPTION is rendered with gr.Markdown, where a single "\n" is a soft break that gets
    # folded into the preceding sentence; a blank line makes the notice its own paragraph.
    DESCRIPTION += "\n\nThis demo does not work on CPU."

# Load half-precision weights; device_map="auto" lets transformers choose device placement.
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=torch.float16, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# Conversations in this app carry SYS_PROMPT explicitly, so don't let the tokenizer add its own.
tokenizer.use_default_system_prompt = False


def _encode_conversation(message: str, history: list[dict]) -> torch.Tensor:
    """Apply the chat template to the system prompt, ``history`` and the new user ``message``.

    Returns the prompt token ids (with the assistant generation prompt appended) as a
    ``return_tensors="pt"`` tensor whose second dimension is the token count.
    """
    conversation = [
        {"role": "system", "content": SYS_PROMPT},
        *history,
        {"role": "user", "content": message},
    ]
    return tokenizer.apply_chat_template(conversation, return_tensors="pt", add_generation_prompt=True)


def _fit_conversation(message: str, chat_history: list[dict]) -> torch.Tensor:
    """Return prompt token ids no longer than ``MAX_INPUT_TOKEN_LENGTH``.

    The oldest messages of ``chat_history`` are dropped whole, so the system prompt and the
    chat-template markup stay intact. Left-truncation of the token sequence is used only as a
    last resort, when the system prompt plus the new message alone are still too long.
    Emits a Gradio warning whenever anything is trimmed.
    """
    input_ids = _encode_conversation(message, chat_history)
    if input_ids.shape[1] <= MAX_INPUT_TOKEN_LENGTH:
        return input_ids

    # Find the smallest number of leading messages to drop. Dropping more messages never makes
    # the prompt longer, so "fits" is monotonic in the start index and a binary search is valid.
    # ``hi == len(chat_history)`` (no history at all) is accepted without checking.
    lo, hi = 1, len(chat_history)
    while lo < hi:
        mid = (lo + hi) // 2
        if _encode_conversation(message, chat_history[mid:]).shape[1] <= MAX_INPUT_TOKEN_LENGTH:
            hi = mid
        else:
            lo = mid + 1
    input_ids = _encode_conversation(message, chat_history[lo:])

    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        # Even without any history the prompt is too long: keep only the most recent tokens.
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]

    gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    return input_ids


@spaces.GPU
def generate(message: str, chat_history: list[dict]) -> Iterator[str]:
    """Stream the model's reply to ``message`` given the prior ``chat_history``.

    Args:
        message: The new user message.
        chat_history: Previous turns as role/content dicts (Gradio "messages" format).

    Yields:
        The cumulative response text after each streamed chunk, so the chat UI can render
        the answer incrementally.
    """
    input_ids = _fit_conversation(message, chat_history).to(model.device)

    streamer = TextIteratorStreamer(tokenizer, timeout=30.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = {
        "input_ids": input_ids,
        "streamer": streamer,
        "max_new_tokens": MAX_NEW_TOKENS,
        "do_sample": True,
        "top_p": TOP_P,
        "top_k": TOP_K,
        "temperature": TEMPERATURE,
        "num_beams": 1,
        "repetition_penalty": REPETITION_PENALTY,
    }

    # model.generate blocks until generation finishes, so run it on a worker thread and consume
    # decoded text from the streamer on this one.
    worker = Thread(target=model.generate, kwargs=generate_kwargs)
    worker.start()

    response = ""
    for text in streamer:
        response += text
        yield response

    # The streamer is exhausted only after generation ends; join so the thread is reaped.
    worker.join()


# Chat UI that streams replies from ``generate``. With type="messages", history arrives as
# role/content dicts, and the sample prompts are shown as clickable examples.
chat_interface = gr.ChatInterface(
    fn=generate,
    type="messages",
    examples=[
        [prompt]
        for prompt in (
            "Explain quantum computing",
            "What is OpenShift?",
            "Importance of low latency inference",
            "Boosting productivity habits",
        )
    ],
    cache_examples=False,
    stop_btn=None,
)

# Static assets (custom CSS and extra <head> HTML) live next to this script.
app_dir = Path(__file__).parent
css_file_path = app_dir / "app.css"
head_file_path = app_dir / "app_head.html"

with gr.Blocks(
    fill_height=True, css_paths=css_file_path, head_paths=head_file_path, theme=carbon_theme, title=TITLE
) as demo:
    gr.Markdown(f"# {TITLE}")
    gr.Markdown(DESCRIPTION)
    chat_interface.render()

if __name__ == "__main__":
    # max_size bounds how many events the queue holds at once.
    demo.queue(max_size=20).launch()