| """Template Demo for IBM Granite Hugging Face spaces.""" | |
| from collections.abc import Iterator | |
| from datetime import datetime | |
| from pathlib import Path | |
| from threading import Thread | |
| import gradio as gr | |
| import spaces | |
| import torch | |
| from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer | |
| from themes.carbon import carbon_theme | |
today_date = datetime.today().strftime("%B %-d, %Y")  # noqa: DTZ002

SYS_PROMPT = f"""Knowledge Cutoff Date: April 2024.
Today's Date: {today_date}.
You are Granite, developed by IBM. You are a helpful AI assistant"""
TITLE = "IBM Granite 3.1 8b Instruct"
DESCRIPTION = "Try one of the sample prompts below or write your own. Remember, just like developers, \
AI models can make mistakes."

MAX_INPUT_TOKEN_LENGTH = 128_000
MAX_NEW_TOKENS = 1024
TEMPERATURE = 0.7
TOP_P = 0.85
TOP_K = 50
REPETITION_PENALTY = 1.05
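# Decoding settings used below: sampling is restricted to the TOP_K most likely
# tokens and to the nucleus of cumulative probability TOP_P, softened by
# TEMPERATURE; a REPETITION_PENALTY above 1.0 discourages repeated phrases.
# MAX_INPUT_TOKEN_LENGTH matches the 128k-token context window of
# granite-3.1-8b-instruct.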
if not torch.cuda.is_available():
    DESCRIPTION += "\nThis demo does not work on CPU."

model = AutoModelForCausalLM.from_pretrained(
    "ibm-granite/granite-3.1-8b-instruct", torch_dtype=torch.float16, device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained("ibm-granite/granite-3.1-8b-instruct")
tokenizer.use_default_system_prompt = False
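# With the default system prompt disabled, SYS_PROMPT is injected manually in
# generate() below. For reference, the rendered prompt looks roughly like this
# (the exact special tokens come from the model's chat template, not this file):
#
#   <|start_of_role|>system<|end_of_role|>{SYS_PROMPT}<|end_of_text|>
#   <|start_of_role|>user<|end_of_role|>{message}<|end_of_text|>
#   <|start_of_role|>assistant<|end_of_role|>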
@spaces.GPU
def generate(message: str, chat_history: list[dict]) -> Iterator[str]:
    """Generate function for chat demo."""
    # Build messages: system prompt, then prior turns, then the new user turn
    conversation = [{"role": "system", "content": SYS_PROMPT}]
    conversation += chat_history
    conversation.append({"role": "user", "content": message})

    # Convert messages to prompt format
    input_ids = tokenizer.apply_chat_template(
        conversation,
        return_tensors="pt",
        add_generation_prompt=True,
        truncation=True,
        max_length=MAX_INPUT_TOKEN_LENGTH,
    )
    input_ids = input_ids.to(model.device)

    # Run generation in a background thread so tokens can be streamed
    # back to the UI as they are produced
    streamer = TextIteratorStreamer(tokenizer, timeout=30.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = {
        "input_ids": input_ids,
        "streamer": streamer,
        "max_new_tokens": MAX_NEW_TOKENS,
        "do_sample": True,
        "top_p": TOP_P,
        "top_k": TOP_K,
        "temperature": TEMPERATURE,
        "num_beams": 1,
        "repetition_penalty": REPETITION_PENALTY,
    }
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)
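# Illustrative standalone use of generate(), outside Gradio (not executed here):
#
#   for partial in generate("What is OpenShift?", chat_history=[]):
#       print(partial)
#
# Each yielded string is the full response generated so far; this cumulative
# format is what gr.ChatInterface expects from a streaming generator.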
css_file_path = Path(__file__).parent / "app.css"
head_file_path = Path(__file__).parent / "app_head.html"
with gr.Blocks(
    fill_height=True, css_paths=css_file_path, head_paths=head_file_path, theme=carbon_theme, title=TITLE
) as demo:
    gr.HTML(
        f"<img src='https://www.ibm.com/granite/docs/images/granite-cubes-352x368.webp'/><h1>{TITLE}</h1>",
        elem_classes=["gr_title"],
    )
    gr.HTML(DESCRIPTION)
    gr.HTML(
        value='<a href="https://www.ibm.com/granite/docs/">View Documentation</a> <i class="fa fa-external-link"></i>',
        elem_classes=["gr_docs_link"],
    )
    chat_interface = gr.ChatInterface(
        fn=generate,
        examples=[
            ["Explain quantum computing"],
            ["What is OpenShift?"],
            ["Importance of low latency inference"],
            ["Boosting productivity habits"],
        ],
        cache_examples=False,
        type="messages",
    )
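    # type="messages" makes ChatInterface pass chat_history as a list of
    # {"role": ..., "content": ...} dicts, the format generate() forwards
    # straight to the chat template.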
if __name__ == "__main__":
    demo.queue().launch()