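# Merlin chat: a Gradio chat app that grounds a quantized StarCoder chat model
# with Wikipedia summaries of entities and keyphrases found in the conversation.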
import json
import gradio as gr
import os
import spacy
spacy.cli.download('en_core_web_sm')
nlp = spacy.load('en_core_web_sm')
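# RAKE (rake_nltk) relies on NLTK's stopword list and punkt tokenizer data.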
import nltk
nltk.download('stopwords')
nltk.download('punkt')
from rake_nltk import Rake
r = Rake()
import time
import wikipediaapi
wiki_wiki = wikipediaapi.Wikipedia('Organika ([email protected])', 'en')
## ctransformers disabled for now
# from ctransformers import AutoModelForCausalLM
# model = AutoModelForCausalLM.from_pretrained(
# "Colby/StarCoder-3B-WoW-JSON",
# model_file="StarCoder-3B-WoW-JSON-ggml.bin",
# model_type="gpt_bigcode"
# )
# Use a pipeline as a high-level helper
from transformers import pipeline
topic_model = pipeline("zero-shot-classification", model="valhalla/distilbart-mnli-12-9", device=0)
#model = pipeline("text-generation", model="Organika/StarCoder-7B-WoW-JSON_1", device=0)
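# Load the chat model with 4-bit NF4 quantization (bitsandbytes) so the 7B
# model fits in GPU memory; matmuls are computed in bfloat16.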
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
model_name = "umm-maybe/StarCoder-7B-WoW-JSON_1"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=bnb_config, device_map="auto")
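# Sample a completion from the quantized model (temperature + nucleus/top-k
# sampling with a mild repetition penalty).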
def generate_text(prompt):
    inputs = tokenizer.encode(prompt, return_tensors="pt").to("cuda")
    outputs = model.generate(
        inputs,
        do_sample=True,
        max_new_tokens=200,
        temperature=0.6,
        top_p=0.9,
        top_k=40,
        repetition_penalty=1.1,
    )
    results = tokenizer.decode(outputs[0], clean_up_tokenization_spaces=False)
    return results
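# Chat handler: gathers recent turns, extracts entities (spaCy NER) and
# keyphrases (RAKE), scores them against the chat with zero-shot
# classification, pulls matching Wikipedia summaries into a system message,
# then asks the model to complete the chat serialized as JSON.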
def merlin_chat(message, history):
    chat_text = ""
    chat_list = []
    for turn in history[-3:]:
        chat_text += f"{turn[0]}\n\n{turn[1]}\n\n"
        chat_list.append({"role": "user", "content": turn[0]})
        chat_list.append({"role": "assistant", "content": turn[1]})
    chat_text += f"{message}\n"
    doc = nlp(chat_text)
    ents_found = []
    if doc.ents:
        for ent in doc.ents:
            if len(ents_found) == 3:
                break
            # skip numeric/quantity-type entities (.label_ is the string label)
            if ent.text.isnumeric() or ent.label_ in ["DATE", "TIME", "PERCENT", "MONEY", "QUANTITY", "ORDINAL", "CARDINAL"]:
                continue
            if ent.text.lower() in ents_found:
                continue
            ents_found.append(ent.text.lower())
    r.extract_keywords_from_text(chat_text)
    for phrase in r.get_ranked_phrases()[:3]:
        phrase = phrase.lower()
        if phrase not in ents_found:
            ents_found.append(phrase)
context = ""
scores = topic_model(chat_text, ents_found, multi_label=True)['scores']
if ents_found:
max_score = 0
for k in range(len(ents_found)):
if scores[k] < 0.5:
continue
entity = ents_found[k]
if scores[k] > max_score:
max_score = scores[k]
max_topic = entity
print(f'# Looking up {entity} on Wikipedia... ', end='')
wiki_page = wiki_wiki.page(entity)
if wiki_page.exists():
print("page found... ")
entsum = wiki_page.summary
if "may refer to" in entsum or "may also refer to" in entsum:
print(" ambiguous, skipping.")
continue
else:
context += entsum + '\n\n'
else:
print("not found.")
    system_msg = {'role': 'system', 'content': context}
    chat_list.insert(0, system_msg)
    user_msg = {'role': 'user', 'content': message}
    chat_list.append(user_msg)
    # Serialize the chat as a JSON list, drop the closing ']', and open an
    # assistant turn for the model to complete.
    prompt = json.dumps(chat_list)[:-1] + ',{"role": "assistant", "content": "'
    print(f"PROMPT: {prompt}")
    for attempt in range(3):
        result = generate_text(prompt)
        response = result.replace(prompt, "")
        print(f"COMPLETION: {response}")  # so we can see it in logs
        cleanStr = response.lstrip()
        # keep the completion up to and including the first '}', which should
        # close the assistant message object
        end = cleanStr.find('}') + 1
        if end <= 0:
            continue
        cleanStr = cleanStr[:end]
        messageStr = prompt + cleanStr + ']'
        try:
            messages = json.loads(messageStr)
        except json.JSONDecodeError:
            continue  # completion did not close the JSON properly; retry
        message = messages[-1]
        if message['role'] != 'assistant':
            continue
        msg_text = message['content']
        if chat_text.find(msg_text) >= 0:
            continue  # reject completions that just repeat earlier chat text
        return message['content']
    return "🤔"
gr.ChatInterface(merlin_chat).launch()