grahamwhiteuk's picture
fix: latest round of design feedback
5ce66fb
raw
history blame
7.18 kB
"""Template Demo for IBM Granite Hugging Face spaces."""
from collections.abc import Iterator
from datetime import datetime
from pathlib import Path
from threading import Thread
import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from themes.research_monochrome import theme
# Render today's date as e.g. "January 5, 2025". The day number is taken from ``.day`` instead of
# the ``%-d`` strftime directive, which is a platform extension and is not accepted everywhere
# (notably on Windows).
# NOTE: this is evaluated once at import time, so a long-running process keeps its start-up date.
today_date = "{0:%B} {0.day}, {0:%Y}".format(datetime.today())  # noqa: DTZ002

# System prompt placed at the start of every conversation sent to the model.
SYS_PROMPT = f"""Knowledge Cutoff Date: April 2024.
Today's Date: {today_date}.
You are Granite, developed by IBM. You are a helpful AI assistant"""
# Page title; used both for the browser tab (gr.Blocks title) and the <h1> heading.
TITLE = "IBM Granite 3.1 8b Instruct"
# Introductory HTML rendered below the heading.
DESCRIPTION = """
<p>Granite 3.1 is a general purpose large language model released in the open under an Apache 2.0 license. Granite
models support a 128k context length. Try one of the sample prompts below or write your own. Remember, AI models can
make mistakes.
<span class="gr_docs_link">
<a href="https://www.ibm.com/granite/docs/">View Documentation <i class="fa fa-external-link"></i></a>
</span>
</p>
"""
# Total token budget; the prompt is truncated to this value minus the requested number of new tokens.
MAX_INPUT_TOKEN_LENGTH = 128_000
# Defaults for the generation parameters; they seed both generate()'s keyword defaults and the UI sliders.
MAX_NEW_TOKENS = 1024
TEMPERATURE = 0.7
TOP_P = 0.85
TOP_K = 50
REPETITION_PENALTY = 1.05
# The model is loaded in float16 with automatic device placement; warn early when no GPU is visible.
if not torch.cuda.is_available():
    print("This demo may not work on CPU.")

# Tokenizer and weights come from the same checkpoint.
tokenizer = AutoTokenizer.from_pretrained("ibm-granite/granite-3.1-8b-instruct")
tokenizer.use_default_system_prompt = False

model = AutoModelForCausalLM.from_pretrained(
    "ibm-granite/granite-3.1-8b-instruct",
    device_map="auto",
    torch_dtype=torch.float16,
)
@spaces.GPU
def generate(
    message: str,
    chat_history: list[dict],
    temperature: float = TEMPERATURE,
    top_p: float = TOP_P,
    top_k: float = TOP_K,
    repetition_penalty: float = REPETITION_PENALTY,
    max_new_tokens: int = MAX_NEW_TOKENS,
) -> Iterator[str]:
    """Stream the model's reply to ``message``.

    Generation runs on a background thread; this generator consumes the text streamer and
    yields the reply accumulated so far each time a new chunk of text arrives.

    Args:
        message: The new user message.
        chat_history: Earlier conversation turns as message dicts; they are inserted between
            the system prompt and the new user message.
        temperature: Sampling temperature. A value <= 0 selects greedy decoding instead of
            sampling (sampling requires a strictly positive temperature).
        top_p: Nucleus-sampling threshold (only used when sampling).
        top_k: Top-k cutoff (only used when sampling); coerced to ``int``.
        repetition_penalty: Repetition penalty forwarded to ``model.generate``.
        max_new_tokens: Maximum number of tokens to generate. The prompt is truncated to
            ``MAX_INPUT_TOKEN_LENGTH - max_new_tokens`` tokens.

    Yields:
        The cumulative reply text generated so far.

    Raises:
        Exception: Any error raised by ``model.generate`` on the worker thread is re-raised
            here once the stream has ended, instead of leaving the consumer blocked.
    """
    # UI sliders may deliver whole numbers as ints (or ints as floats); normalise up front.
    max_new_tokens = int(max_new_tokens)

    # Build messages
    conversation = [
        {"role": "system", "content": SYS_PROMPT},
        *chat_history,
        {"role": "user", "content": message},
    ]

    # Convert messages to prompt format
    input_ids = tokenizer.apply_chat_template(
        conversation,
        return_tensors="pt",
        add_generation_prompt=True,
        truncation=True,
        max_length=MAX_INPUT_TOKEN_LENGTH - max_new_tokens,
    )
    input_ids = input_ids.to(model.device)

    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = {
        "input_ids": input_ids,
        "streamer": streamer,
        "max_new_tokens": max_new_tokens,
        "num_beams": 1,
        "repetition_penalty": float(repetition_penalty),
    }
    if temperature > 0:
        generate_kwargs.update(
            do_sample=True,
            temperature=float(temperature),
            top_p=float(top_p),
            top_k=int(top_k),
        )
    else:
        # Sampling with a non-positive temperature is invalid; fall back to greedy decoding.
        generate_kwargs["do_sample"] = False

    failures: list[BaseException] = []

    def _run_generation() -> None:
        """Run generation; on failure record the error and close the stream."""
        try:
            model.generate(**generate_kwargs)
        except Exception as err:  # noqa: BLE001 - re-raised on the consumer side below
            failures.append(err)
            # Without an end-of-stream signal the consumer loop below would block forever.
            streamer.end()

    worker = Thread(target=_run_generation)
    worker.start()

    reply = ""
    for text in streamer:
        reply += text
        yield reply

    worker.join()
    if failures:
        raise failures[0]
# Static assets live next to this module.
css_file_path = Path(__file__).parent.joinpath("app.css")
head_file_path = Path(__file__).parent.joinpath("app_head.html")
# advanced settings (displayed in Accordion)
# Each slider is created up front and later handed to the chat interface as an additional input;
# its initial value is the matching module-level default.
temperature_slider = gr.Slider(
    minimum=0, maximum=1.0, value=TEMPERATURE, step=0.1, label="Temperature", elem_classes=["gr_accordion_element"]
)
top_p_slider = gr.Slider(
    minimum=0, maximum=1.0, value=TOP_P, step=0.05, label="Top P", elem_classes=["gr_accordion_element"]
)
top_k_slider = gr.Slider(
    minimum=0, maximum=100, value=TOP_K, step=1, label="Top K", elem_classes=["gr_accordion_element"]
)
repetition_penalty_slider = gr.Slider(
    # The lower bound stays above zero: a repetition penalty must be strictly positive, so a
    # value of 0 would be rejected at generation time.
    minimum=0.1,
    maximum=2.0,
    value=REPETITION_PENALTY,
    step=0.1,
    label="Repetition Penalty",
    elem_classes=["gr_accordion_element"],
)
max_new_tokens_slider = gr.Slider(
    minimum=1,
    maximum=2000,
    value=MAX_NEW_TOKENS,
    step=1,
    label="Max New Tokens",
    elem_classes=["gr_accordion_element"],
)
# Collapsed-by-default container that holds the sliders above.
chat_interface_accordion = gr.Accordion(label="Advanced Settings", open=False)
# Page layout: title, description, then the chat interface with sample prompts and the
# "Advanced Settings" accordion.
with gr.Blocks(fill_height=True, css_paths=css_file_path, head_paths=head_file_path, theme=theme, title=TITLE) as demo:
    gr.HTML(f"<h1>{TITLE}</h1>", elem_classes=["gr_title"])
    gr.HTML(DESCRIPTION)
    chat_interface = gr.ChatInterface(
        fn=generate,
        # One single-element list per example; labels are given separately in example_labels.
        examples=[
            ["Explain the concept of quantum computing to someone with no background in physics or computer science."],
            ["What is OpenShift?"],
            ["Why is low latency inference important?"],
            [
                """Imagine you're a productivity expert. How can I develop effective habits to increase my productivity and efficiency? Provide me with specific strategies and actionable steps I can implement in my daily routine."""  # noqa: E501
            ],
            [
                """Explain the following code in a concise manner:
```java
import java.util.ArrayList;
import java.util.List;
public class Main {
public static void main(String[] args) {
int[] arr = {1, 5, 3, 4, 2};
int diff = 3;
List<Pair> pairs = findPairs(arr, diff);
for (Pair pair : pairs) {
System.out.println(pair.x + " " + pair.y);
}
}
public static List<Pair> findPairs(int[] arr, int diff) {
List<Pair> pairs = new ArrayList<>();
for (int i = 0; i < arr.length; i++) {
for (int j = i + 1; j < arr.length; j++) {
if (Math.abs(arr[i] - arr[j]) < diff) {
pairs.add(new Pair(arr[i], arr[j]));
}
}
}
return pairs;
}
}
class Pair {
int x;
int y;
public Pair(int x, int y) {
this.x = x;
this.y = y;
}
}
```"""
            ],
            [
                """Generate a Java code block from the following explanation:
The code in the Main class finds all pairs in an array whose absolute difference is less than a given value.
The findPairs method takes two arguments: an array of integers and a difference value. It iterates over the array and compares each element to every other element in the array. If the absolute difference between the two elements is less than the difference value, a new Pair object is created and added to a list.
The Pair class is a simple data structure that stores two integers.
The main method creates an array of integers, initializes the difference value, and calls the findPairs method to find all pairs in the array. Finally, the code iterates over the list of pairs and prints each pair to the console."""  # noqa: E501
            ],
        ],
        example_labels=[
            "Explain quantum computing",
            "What is OpenShift?",
            "Importance of low latency inference",
            "Boosting productivity habits",
            "Explain and document your code",
            "Generate Java Code",
        ],
        cache_examples=False,
        type="messages",
        # Additional inputs are passed to `fn` positionally after (message, history), so this
        # list must follow generate()'s parameter order:
        # temperature, top_p, top_k, repetition_penalty, max_new_tokens.
        additional_inputs=[
            temperature_slider,
            top_p_slider,
            top_k_slider,
            repetition_penalty_slider,
            max_new_tokens_slider,
        ],
        additional_inputs_accordion=chat_interface_accordion,
    )
# Enable the request queue and start the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.queue().launch()