umm-maybe committed on
Commit 03c3dac · verified · 1 Parent(s): b97fa02

Update app.py

Upgrade model to 7B version

Files changed (1)
  1. app.py +31 -3
app.py CHANGED
@@ -30,7 +30,34 @@ wiki_wiki = wikipediaapi.Wikipedia('Organika ([email protected])', 'en')
 # Use a pipeline as a high-level helper
 from transformers import pipeline
 topic_model = pipeline("zero-shot-classification", model="valhalla/distilbart-mnli-12-9", device=0)
-model = pipeline("text-generation", model="Organika/StarCoder-1B-WoW-JSON", device=0)
+#model = pipeline("text-generation", model="Organika/StarCoder-7B-WoW-JSON_1", device=0)
+
+from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
+
+bnb_config = BitsAndBytesConfig(
+    load_in_4bit=True,
+    bnb_4bit_use_double_quant=True,
+    bnb_4bit_quant_type="nf4",
+    bnb_4bit_compute_dtype=torch.bfloat16
+)
+
+model_name = "umm-maybe/StarCoder-7B-WoW-JSON_1"
+tokenizer = AutoTokenizer.from_pretrained(model_name)
+model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=bnb_config, device_map="auto")
+
+def generate_text(prompt):
+    inputs = tokenizer.encode(prompt, return_tensors="pt").to("cuda")
+    outputs = model.generate(
+        inputs,
+        do_sample=True,
+        max_new_tokens=200,
+        temperature=0.6,
+        top_p=0.9,
+        top_k=40,
+        repetition_penalty=1.1
+    )
+    results = self.tokenizer.decode(outputs[0], clean_up_tokenization_spaces=False) #.replace(prompt,"")
+    return results
 
 def merlin_chat(message, history):
     chat_text = ""
@@ -89,7 +116,8 @@ def merlin_chat(message, history):
     print(f"PROMPT: {prompt}")
     for attempt in range(3):
         #result = generate_text(prompt, model_path, parameters, headers)
-        result = model(prompt,return_full_text=False, max_new_tokens=256, temperature=0.8, repetition_penalty=1.1)
+        #result = model(prompt,return_full_text=False, max_new_tokens=256, temperature=0.8, repetition_penalty=1.1)
+        result = generate_text(prompt)
         response = result[0]['generated_text']
         print(f"COMPLETION: {response}") # so we can see it in logs
         start = 0
@@ -102,7 +130,7 @@ def merlin_chat(message, history):
         if end<=0:
             continue
         cleanStr = cleanStr[:end]
-        messageStr = prompt + cleanStr + ']'
+        messageStr = cleanStr + ']'
         messages = json.loads(messageStr)
         message = messages[-1]
         if message['role'] != 'assistant':
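
For reference (not part of this commit): the retired pipeline call returned a list of dicts, which is why merlin_chat still reads result[0]['generated_text'], while the new generate_text returns the decoded string directly (and its self.tokenizer reference would need to be the module-level tokenizer defined above). A minimal, hypothetical shim that restores the pipeline-style return shape could look like this:

# Hypothetical helper, not in the commit: wraps generate_text() so callers that
# expect the pipeline-style [{'generated_text': ...}] result keep working.
def generate_text_pipeline_style(prompt):
    text = generate_text(prompt)            # decoded string (prompt + completion)
    completion = text.replace(prompt, "")   # drop the echoed prompt, mirroring the commented-out .replace
    return [{"generated_text": completion}]

# usage inside the retry loop:
# result = generate_text_pipeline_style(prompt)
# response = result[0]['generated_text']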