jameshhjung committed · verified
Commit 0f57722 · 1 Parent(s): f35fdc5

Update app.py

Files changed (1):
  1. app.py +149 -57
app.py CHANGED
@@ -1,64 +1,156 @@
 import gradio as gr
-from huggingface_hub import InferenceClient
 
-"""
-For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
-"""
-client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
 
 
-def respond(
-    message,
-    history: list[tuple[str, str]],
-    system_message,
-    max_tokens,
-    temperature,
-    top_p,
-):
-    messages = [{"role": "system", "content": system_message}]
-
-    for val in history:
-        if val[0]:
-            messages.append({"role": "user", "content": val[0]})
-        if val[1]:
-            messages.append({"role": "assistant", "content": val[1]})
-
-    messages.append({"role": "user", "content": message})
-
-    response = ""
-
-    for message in client.chat_completion(
-        messages,
-        max_tokens=max_tokens,
-        stream=True,
-        temperature=temperature,
-        top_p=top_p,
-    ):
-        token = message.choices[0].delta.content
-
-        response += token
-        yield response
-
-
-"""
-For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
-"""
-demo = gr.ChatInterface(
-    respond,
-    additional_inputs=[
-        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-        gr.Slider(
-            minimum=0.1,
-            maximum=1.0,
-            value=0.95,
-            step=0.05,
-            label="Top-p (nucleus sampling)",
-        ),
-    ],
-)
 
 
 if __name__ == "__main__":
-    demo.launch()
 
 import gradio as gr
+from transformers import AutoTokenizer, AutoModelForCausalLM
+import torch
+import gc
+
+class ModelManager:
+    def __init__(self):
+        self.model = None
+        self.tokenizer = None
+        self.model_name = "CohereForAI/c4ai-command-r-plus-4bit"
+
+    def load_model(self):
+        if self.model is None:
+            try:
+                print("Loading model... this may take a while.")
+                self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
+                self.model = AutoModelForCausalLM.from_pretrained(
+                    self.model_name,
+                    torch_dtype=torch.float16,
+                    device_map="auto",
+                    trust_remote_code=True,
+                    load_in_4bit=True,
+                    low_cpu_mem_usage=True
+                )
+                print("Model loaded!")
+                return True
+            except Exception as e:
+                print(f"Model loading failed: {e}")
+                return False
+        return True
+
+    def generate(self, message, history, max_tokens=1000, temperature=0.7):
+        if not self.load_model():
+            return "Failed to load the model."
+
+        try:
+            # Build the conversation from the chat history
+            conversation = []
+            for human, assistant in history:
+                conversation.append({"role": "user", "content": human})
+                if assistant:
+                    conversation.append({"role": "assistant", "content": assistant})
+            conversation.append({"role": "user", "content": message})
+
+            # Tokenize with the model's chat template
+            input_ids = self.tokenizer.apply_chat_template(
+                conversation,
+                return_tensors="pt",
+                add_generation_prompt=True
+            )
+
+            if torch.cuda.is_available():
+                input_ids = input_ids.to("cuda")
+
+            # Generate
+            with torch.no_grad():
+                outputs = self.model.generate(
+                    input_ids,
+                    max_new_tokens=max_tokens,
+                    temperature=temperature,
+                    do_sample=True,
+                    pad_token_id=self.tokenizer.eos_token_id,
+                    eos_token_id=self.tokenizer.eos_token_id
+                )
+
+            response = self.tokenizer.decode(
+                outputs[0][input_ids.shape[-1]:],
+                skip_special_tokens=True
+            )
+
+            return response
+
+        except Exception as e:
+            return f"Error during generation: {str(e)}"
+        finally:
+            # Free memory
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+            gc.collect()
+
+# Model manager instance
+model_manager = ModelManager()
+
+def chat_fn(message, history, max_tokens, temperature):
+    if not message.strip():
+        return history, ""
+
+    # Append the user message with a placeholder reply
+    history.append([message, "Generating..."])
+
+    # Generate the bot response
+    response = model_manager.generate(message, history[:-1], max_tokens, temperature)
+    history[-1][1] = response
+
+    return history, ""
+
+# Gradio interface
+with gr.Blocks(title="Command R+ Chat") as demo:
+    gr.Markdown("""
+    # 🤖 Command R+ 4bit Chatbot
+
+    Chat with Cohere's Command R+ 4-bit quantized model.
+    ⚠️ Loading the model may take a while on the first run.
+    """)
+
+    chatbot = gr.Chatbot(
+        height=500,
+        show_label=False,
+        show_copy_button=True
+    )
+
+    with gr.Row():
+        msg = gr.Textbox(
+            label="Message input",
+            placeholder="Ask Command R+...",
+            lines=2,
+            scale=4
+        )
+        submit = gr.Button("Send 📤", variant="primary", scale=1)
+
+    with gr.Row():
+        clear = gr.Button("Clear conversation 🗑️")
+
+    with gr.Accordion("Advanced settings", open=False):
+        max_tokens = gr.Slider(
+            minimum=100,
+            maximum=2000,
+            value=1000,
+            step=100,
+            label="Max tokens"
+        )
+        temperature = gr.Slider(
+            minimum=0.1,
+            maximum=1.0,
+            value=0.7,
+            step=0.1,
+            label="Temperature (creativity)"
+        )
+
+    # Event handlers
+    msg.submit(
+        chat_fn,
+        [msg, chatbot, max_tokens, temperature],
+        [chatbot, msg]
+    )
+
+    submit.click(
+        chat_fn,
+        [msg, chatbot, max_tokens, temperature],
+        [chatbot, msg]
+    )
+
+    clear.click(lambda: ([], ""), outputs=[chatbot, msg])
 
 if __name__ == "__main__":
+    demo.launch()
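
Note: recent `transformers` releases deprecate passing `load_in_4bit=True` directly to `from_pretrained` in favor of an explicit `BitsAndBytesConfig`; `device_map="auto"` also requires the `accelerate` package, and 4-bit loading requires `bitsandbytes`. A minimal sketch of an equivalent load under those assumptions (not the committed code):

    from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
    import torch

    model_name = "CohereForAI/c4ai-command-r-plus-4bit"

    # Explicit quantization config replaces the deprecated load_in_4bit kwarg
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
    )

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto",  # requires `accelerate`
        quantization_config=bnb_config,
    )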