pabloce committed on
Commit
fa52871
·
verified ·
1 Parent(s): adb6844

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -20
app.py CHANGED
@@ -16,29 +16,42 @@ def respond(
16
  temperature,
17
  top_p,
18
  ):
19
- messages = [{"role": "system", "content": system_message}]
20
 
21
- for val in history:
22
- if val[0]:
23
- messages.append({"role": "user", "content": val[0]})
24
- if val[1]:
25
- messages.append({"role": "assistant", "content": val[1]})
26
-
27
- messages.append({"role": "user", "content": message})
28
-
29
- response = ""
 
 
30
 
31
- for message in client.chat_completion(
32
- messages,
33
- max_tokens=max_tokens,
34
- stream=True,
35
- temperature=temperature,
 
 
 
 
36
  top_p=top_p,
37
- ):
38
- token = message.choices[0].delta.content
39
-
40
- response += token
41
- yield response
 
 
 
 
 
 
 
42
 
43
  """
44
  For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
 
16
  temperature,
17
  top_p,
18
  ):
19
+ torch.set_default_device("cuda")
20
 
21
+ tokenizer = AutoTokenizer.from_pretrained(
22
+ "cognitivecomputations/dolphin-2.8-mistral-7b-v02",
23
+ trust_remote_code=True
24
+ )
25
+ model = AutoModelForCausalLM.from_pretrained(
26
+ "cognitivecomputations/dolphin-2.8-mistral-7b-v02",
27
+ torch_dtype="auto",
28
+ load_in_4bit=True,
29
+ trust_remote_code=True
30
+ )
31
+ history_transformer_format = history + [[message, ""]]
32
 
33
+ system_prompt = f"<|im_start|>system\n{system_message}.<|im_end|>"
34
+ messages = system_prompt + "".join(["".join(["\n<|im_start|>user\n" + item[0], "<|im_end|>\n<|im_start|>assistant\n" + item[1]]) for item in history_transformer_format])
35
+ input_ids = tokenizer([messages], return_tensors="pt").to('cuda')
36
+ streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
37
+ generate_kwargs = dict(
38
+ input_ids,
39
+ streamer=streamer,
40
+ max_new_tokens=max_tokens,
41
+ do_sample=True,
42
  top_p=top_p,
43
+ top_k=50,
44
+ temperature=temperature,
45
+ num_beams=1
46
+ )
47
+ t = Thread(target=model.generate, kwargs=generate_kwargs)
48
+ t.start()
49
+ partial_message = ""
50
+ for new_token in streamer:
51
+ partial_message += new_token
52
+ if '<|im_end|>' in partial_message:
53
+ break
54
+ yield partial_message
55
 
56
  """
57
  For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface