mahiatlinux committed (verified)
Commit 7c779c2 · 1 parent: 27319be

Update app.py

Files changed (1): app.py (+5, -4)
app.py CHANGED
@@ -1,3 +1,4 @@
+{"from": "human", "value": "who are you"}
 import os
 from threading import Thread
 from typing import Iterator
@@ -24,7 +25,7 @@ if not torch.cuda.is_available():
 
 
 if torch.cuda.is_available():
-    model_id = "Nexusflow/Starling-LM-7B-beta"
+    model_id = "mahiatlinux/MasherAI-v6-7B"
     model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", load_in_4bit=True)
     tokenizer = AutoTokenizer.from_pretrained(model_id)
     tokenizer.use_default_system_prompt = False
@@ -43,10 +44,10 @@ def generate(
 ) -> Iterator[str]:
     conversation = []
     if system_prompt:
-        conversation.append({"role": "system", "content": system_prompt})
+        conversation.append({"from": "human", "value": "You are an AI assistant."})
     for user, assistant in chat_history:
-        conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
-        conversation.append({"role": "user", "content": message})
+        conversation.extend([{"from": "human", "value": user}, {"from": "gpt", "value": assistant}])
+        conversation.append({"from": "human", "value": message})
 
     input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt", add_generation_prompt=True)
     if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
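
apply_chat_template renders the conversation through the Jinja chat template stored on the tokenizer, so it accepts whatever message keys that template reads. The stock Transformers templates read "role"/"content"; the ShareGPT-style "from"/"value" messages introduced by this commit only render correctly if the mahiatlinux/MasherAI-v6-7B checkpoint ships a chat template written against that schema. A minimal sketch for sanity-checking the rendered prompt, assuming that template is present on the checkpoint (the conversation contents here are illustrative):

from transformers import AutoTokenizer

# Load the tokenizer from the checkpoint this commit switches to.
tokenizer = AutoTokenizer.from_pretrained("mahiatlinux/MasherAI-v6-7B")

# Build a conversation the way generate() now does: ShareGPT-style
# "from"/"value" keys instead of the stock "role"/"content" keys.
conversation = [
    {"from": "human", "value": "You are an AI assistant."},
    {"from": "human", "value": "who are you"},
]

# tokenize=False returns the rendered prompt string, which makes it
# easy to confirm the template actually consumes "from"/"value"
# messages before wiring them into the Space.
prompt = tokenizer.apply_chat_template(
    conversation, tokenize=False, add_generation_prompt=True
)
print(prompt)

If the checkpoint's template still reads "role"/"content", rendering will fail or silently drop the message bodies, so inspecting the string output once is a cheap check before deploying the updated Space.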