mhtamim commited on
Commit
c4120c8
·
verified ·
1 Parent(s): 8e41d4c

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +132 -0
app.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from threading import Thread
3
+ from typing import Iterator
4
+
5
+ import gradio as gr
6
+ import spaces
7
+ import torch
8
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
9
+
10
+ DESCRIPTION = """\
11
+ # Sahabat-AI
12
+
13
+ Sahabat-AI (Indonesian language for “close friends”) is a collection of Large Language Models (LLMs) which has been pretrained and instruct-tuned for Indonesian language and its various dialects. Sahabat-AI ecosystem is co-initiated by Indonesian tech and telecommunication companies: GoTo Group and Indosat Ooredoo Hutchison.
14
+
15
+ Gemma2 9B CPT Sahabat-AI v1 Instruct is an Indonesian-focused model which has been fine-tuned with around 448,000 Indonesian instruction-completion pairs alongside an Indonesian-dialect pool consisting of 96,000 instruction-completion pairs in Javanese and 98,000 instruction-completion pairs in Sundanese. Additionally, we added a pool of 129,000 instruction-completion pairs in English.
16
+ """
17
+
18
+ MAX_MAX_NEW_TOKENS = 2048
19
+ DEFAULT_MAX_NEW_TOKENS = 1024
20
+ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
21
+
22
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
23
+
24
+ model_id = "GoToCompany/gemma2-9b-cpt-sahabatai-v1-instruct"
25
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
26
+ model = AutoModelForCausalLM.from_pretrained(
27
+ model_id,
28
+ device_map="auto",
29
+ torch_dtype=torch.bfloat16,
30
+ )
31
+ model.config.sliding_window = 4096
32
+ model.eval()
33
+
34
+
35
+ @spaces.GPU(duration=90)
36
+ def generate(
37
+ message: str,
38
+ chat_history: list[dict],
39
+ max_new_tokens: int = 1024,
40
+ temperature: float = 0.6,
41
+ top_p: float = 0.9,
42
+ top_k: int = 50,
43
+ repetition_penalty: float = 1.2,
44
+ ) -> Iterator[str]:
45
+ conversation = chat_history.copy()
46
+ conversation.append({"role": "user", "content": message})
47
+
48
+ input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
49
+ if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
50
+ input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
51
+ gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
52
+ input_ids = input_ids.to(model.device)
53
+
54
+ streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
55
+ generate_kwargs = dict(
56
+ {"input_ids": input_ids},
57
+ streamer=streamer,
58
+ max_new_tokens=max_new_tokens,
59
+ do_sample=True,
60
+ top_p=top_p,
61
+ top_k=top_k,
62
+ temperature=temperature,
63
+ num_beams=1,
64
+ repetition_penalty=repetition_penalty,
65
+ )
66
+ t = Thread(target=model.generate, kwargs=generate_kwargs)
67
+ t.start()
68
+
69
+ outputs = []
70
+ for text in streamer:
71
+ outputs.append(text)
72
+ yield "".join(outputs)
73
+
74
+
75
+ chat_interface = gr.ChatInterface(
76
+ fn=generate,
77
+ additional_inputs=[
78
+ gr.Slider(
79
+ label="Max new tokens",
80
+ minimum=1,
81
+ maximum=MAX_MAX_NEW_TOKENS,
82
+ step=1,
83
+ value=DEFAULT_MAX_NEW_TOKENS,
84
+ ),
85
+ gr.Slider(
86
+ label="Temperature",
87
+ minimum=0.1,
88
+ maximum=4.0,
89
+ step=0.1,
90
+ value=0.6,
91
+ ),
92
+ gr.Slider(
93
+ label="Top-p (nucleus sampling)",
94
+ minimum=0.05,
95
+ maximum=1.0,
96
+ step=0.05,
97
+ value=0.9,
98
+ ),
99
+ gr.Slider(
100
+ label="Top-k",
101
+ minimum=1,
102
+ maximum=1000,
103
+ step=1,
104
+ value=50,
105
+ ),
106
+ gr.Slider(
107
+ label="Repetition penalty",
108
+ minimum=1.0,
109
+ maximum=2.0,
110
+ step=0.05,
111
+ value=1.2,
112
+ ),
113
+ ],
114
+ stop_btn=None,
115
+ examples=[
116
+ ["Halo, apa kabar?"],
117
+ ["Bisakah anjeun ngajelaskeun singkat naon ari basa pamrograman Python?"],
118
+ ["Jelaskna cerita Cinderella ing sak ukara."],
119
+ ["How many hours does it take a man to eat a Helicopter?"],
120
+ ["Tulislah artikel 100 kata tentang 'Manfaat Open-Source dalam Penelitian AI."],
121
+ ],
122
+ cache_examples=False,
123
+ type="messages",
124
+ )
125
+
126
+ with gr.Blocks(css_paths="style.css", fill_height=True) as demo:
127
+ gr.Markdown(DESCRIPTION)
128
+ gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
129
+ chat_interface.render()
130
+
131
+ if __name__ == "__main__":
132
+ demo.queue(max_size=20).launch()