beingcognitive committed on
Commit 1d7e7b8 · 1 Parent(s): ddf360b

gemma-2-2b-it

Files changed (1)
  1. app.py +40 -15
app.py CHANGED
@@ -2,39 +2,63 @@ import os
 from datetime import datetime
 import uuid
 import gradio as gr
-from transformers import AutoTokenizer, AutoModelForCausalLM
+from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
 import torch
 from huggingface_hub import login
+from threading import Thread
 
 from dotenv import load_dotenv
+
 # Load environment variables
 load_dotenv()
 
-# Authenticate with Hugging Face
-login(token=os.getenv("HUGGINGFACE_TOKEN"))
+# Get the Hugging Face token from environment variables
+hf_token = os.getenv("HUGGINGFACE_TOKEN")
 
 # Load model and tokenizer
-model_name = "meta-llama/Meta-Llama-3.1-8B-Instruct"
-tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=True)
-model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="auto", use_auth_token=True)
+model_name = "google/gemma-2-2b-it"
+
+tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)
+model = AutoModelForCausalLM.from_pretrained(
+    model_name,
+    torch_dtype=torch.float16,
+    device_map="auto",
+    token=hf_token
+)
 
 def chat_with_model(messages):
-    input_ids = tokenizer.encode(str(messages), return_tensors="pt").to(model.device)
-    output = model.generate(input_ids, max_length=1000, num_return_sequences=1, temperature=0.7)
-    response = tokenizer.decode(output[0], skip_special_tokens=True)
-    return response
+    # Prepare the input
+    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+
+    # Generate response
+    streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)
+    generation_kwargs = dict(
+        inputs,
+        max_new_tokens=1000,
+        temperature=0.7,
+        do_sample=True,
+        streamer=streamer,
+    )
+
+    thread = Thread(target=model.generate, kwargs=generation_kwargs)
+    thread.start()
+
+    return streamer
 
 def chat_with_model_gradio(message, history, session_id):
     messages = [
         {"role": "system", "content": f"너의 이름은 ChatMBTI. 사람들의 MBTI유형에 알맞은 상담을 진행할 수 있어. 상대방의 MBTI 유형을 먼저 물어보고, 그 유형에 알맞게 상담을 진행해줘. 참고로 현재 시각은 {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}이야."},
     ]
-    messages.extend([{"role": "user" if i % 2 == 0 else "assistant", "content": m} for i, m in enumerate(history)])
+    messages.extend([{"role": "user" if i % 2 == 0 else "model", "content": m} for i, (m, _) in enumerate(history)])
     messages.append({"role": "user", "content": message})
 
-    response = chat_with_model(messages)
-    history.append((message, response))
-
-    return "", history
+    streamer = chat_with_model(messages)
+
+    partial_message = ""
+    for new_token in streamer:
+        partial_message += new_token
+        yield "", history + [(message, partial_message)]
 
 def main():
     session_id = str(uuid.uuid4())
@@ -46,6 +70,7 @@ def main():
         msg.submit(chat_with_model_gradio, [msg, chatbot, gr.State(session_id)], [msg, chatbot])
         clear.click(lambda: None, None, chatbot, queue=False)
 
+    demo.queue()
     demo.launch()
 
 if __name__ == "__main__":
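
For context, the diff swaps meta-llama/Meta-Llama-3.1-8B-Instruct for google/gemma-2-2b-it, passes the token via token= instead of calling login(), builds the prompt with tokenizer.apply_chat_template() rather than encoding str(messages), relabels history turns "model" instead of "assistant" to match Gemma's chat-template role names, and streams the reply instead of returning it in one piece. (The Korean system prompt reads, in English: "Your name is ChatMBTI. You can provide counseling suited to people's MBTI types. First ask for the other person's MBTI type, then counsel in a way that fits that type. For reference, the current time is {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}.")

The generation change follows the standard TextIteratorStreamer pattern from transformers: model.generate() blocks until generation completes, so it runs on a background thread and pushes decoded text into the streamer, which the caller consumes as an ordinary iterator. Below is a minimal standalone sketch of that pattern, assuming the same gemma-2-2b-it checkpoint; the prompt and generation settings are illustrative rather than taken from the commit, and skip_prompt=True is added here so the echoed prompt stays out of the stream:

# Minimal sketch of the TextIteratorStreamer pattern used in chat_with_model.
from threading import Thread
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b-it")
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-2b-it", torch_dtype=torch.float16, device_map="auto"
)

messages = [{"role": "user", "content": "Hello!"}]  # illustrative prompt
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

# skip_prompt=True keeps the echoed prompt out of the streamed text
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

# generate() blocks, so run it on a background thread and read from the streamer
thread = Thread(target=model.generate, kwargs=dict(inputs, streamer=streamer, max_new_tokens=100))
thread.start()

for text in streamer:  # yields decoded chunks as they are generated
    print(text, end="", flush=True)
thread.join()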
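
On the Gradio side, chat_with_model_gradio is now a generator: every yield clears the textbox and re-renders the Chatbot with the partial reply, and the added demo.queue() call is what enables generator (streaming) event handlers. A stripped-down sketch of that mechanism, with a trivial character-by-character echo standing in for the model call (the real app also threads a session_id through gr.State):

import time
import gradio as gr

def stream_reply(message, history):
    # Stand-in for the model: emit the reply one character at a time
    partial = ""
    for ch in f"You said: {message}":
        partial += ch
        time.sleep(0.02)
        yield "", history + [(message, partial)]  # clear textbox, update chat

with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    msg = gr.Textbox()
    msg.submit(stream_reply, [msg, chatbot], [msg, chatbot])

demo.queue()  # required for streaming/generator handlers
demo.launch()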