import json
import gradio as gr
import os

# spaCy pipeline for named-entity recognition
import spacy
spacy.cli.download('en_core_web_sm')
nlp = spacy.load('en_core_web_sm')

# NLTK data required by rake_nltk keyword extraction
import nltk
nltk.download('stopwords')
nltk.download('punkt')
from rake_nltk import Rake
r = Rake()

import time

# Wikipedia client; the user agent string identifies this Space per API etiquette
import wikipediaapi
wiki_wiki = wikipediaapi.Wikipedia('Organika ([email protected])', 'en')
## ctransformers disabled for now
# from ctransformers import AutoModelForCausalLM
# model = AutoModelForCausalLM.from_pretrained(
#     "Colby/StarCoder-3B-WoW-JSON",
#     model_file="StarCoder-3B-WoW-JSON-ggml.bin",
#     model_type="gpt_bigcode"
# )
# Use a pipeline as a high-level helper
from transformers import pipeline
topic_model = pipeline("zero-shot-classification", model="valhalla/distilbart-mnli-12-9")
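# NOTE: the zero-shot pipeline returns {'sequence', 'labels', 'scores'} with the
# labels sorted by descending score, not in the order the candidates were given.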
#model = pipeline("text-generation", model="Colby/StarCoder-3B-WoW-JSON", device=0)
import requests
# function for Huggingface API calls
def query(payload, model_path, headers):
    API_URL = "https://api-inference.huggingface.co/models/" + model_path
    for retry in range(3):
        response = requests.post(API_URL, headers=headers, json=payload)
        if response.status_code == requests.codes.ok:
            try:
                results = response.json()
                return results
            except ValueError:
                # 200 response whose body is not valid JSON
                print('Invalid response received from server')
                print(response)
                return None
        else:
            # Not connected to internet maybe?
            if response.status_code == 404:
                print('Are you connected to the internet?')
                print('URL attempted = ' + API_URL)
                break
            if response.status_code == 503:
                # model is still loading on the server; retry
                print(response.json())
                continue
            if response.status_code == 504:
                print('504 Gateway Timeout')
            else:
                print('Unsuccessful request, status code ' + str(response.status_code))
                # print(response.json()) #debug only
                print(payload)
    return None
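
# Hypothetical usage sketch (model name and input are placeholders, not part of the app):
#   result = query({"inputs": "Hello"}, "gpt2",
#                  {"Authorization": "Bearer " + os.environ['HF_TOKEN']})
#   # result is the decoded JSON on success, or None on failure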
def generate_text(prompt, model_path, text_generation_parameters, headers):
    start_time = time.time()
    options = {'use_cache': False, 'wait_for_model': True}
    payload = {"inputs": prompt, "parameters": text_generation_parameters, "options": options}
    output_list = query(payload, model_path, headers)
    if not output_list:
        print('Generation failed')
    end_time = time.time()
    duration = round(end_time - start_time, 1)
    stringlist = []
    # a successful call returns a list of dicts; an error returns a dict or None
    if output_list and isinstance(output_list, list) and 'generated_text' in output_list[0]:
        print(f'{len(output_list)} sample(s) of text generated in {duration} seconds.')
        for gendict in output_list:
            stringlist.append(gendict['generated_text'])
    else:
        print(output_list)
    return stringlist
model_path = "Colby/StarCoder-1B-WoW-JSON"
parameters = {
    "max_new_tokens": 250,
    "return_full_text": False,
    "do_sample": True,
    "temperature": 0.8,
    "top_p": 0.9,
    "top_k": 50,
    "repetition_penalty": 1.1
}
headers = {"Authorization": "Bearer " + os.environ['HF_TOKEN']}
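# HF_TOKEN is assumed to be configured as a secret in the Space settings.
# Hypothetical usage sketch of the generation helper defined above:
#   samples = generate_text("Once upon a time", model_path, parameters, headers)
#   if samples:
#       print(samples[0])  # list of generated strings; empty on failure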
def merlin_chat(message, history):
    chat_text = ""
    chat_json = ""
    for turn in history:
        chat_text += f"USER: {turn[0]}\n\nASSISTANT: {turn[1]}\n\n"
        # comma-separate serialized messages so the final prompt reads as a JSON array
        chat_json += json.dumps({"role": "user", "content": turn[0]}) + ", "
        chat_json += json.dumps({"role": "assistant", "content": turn[1]}) + ", "
    chat_text += f"USER: {message}\n"
    # collect up to three named entities worth looking up
    doc = nlp(chat_text)
    ents_found = []
    if doc.ents:
        for ent in doc.ents:
            if len(ents_found) == 3:
                break
            # skip numeric and quantitative entities
            if ent.text.isnumeric() or ent.label_ in ["DATE", "TIME", "PERCENT", "MONEY", "QUANTITY", "ORDINAL", "CARDINAL"]:
                continue
            if ent.text in ents_found:
                continue
            ents_found.append(ent.text.title())
    # add the top RAKE keyphrases as further candidate topics
    r.extract_keywords_from_text(chat_text)
    ents_found = ents_found + r.get_ranked_phrases()[:3]
    context = ""
    if ents_found:
        # the pipeline returns labels sorted by score, so pair each label with its own score
        topic_result = topic_model(chat_text, ents_found, multi_label=True)
        for entity, score in zip(topic_result['labels'], topic_result['scores']):
            if score < 0.5:
                continue
            print(f'# Looking up {entity} on Wikipedia... ', end='')
            wiki_page = wiki_wiki.page(entity)
            if wiki_page.exists():
                print("page found... ")
                entsum = wiki_page.summary
                if "may refer to" in entsum or "may also refer to" in entsum:
                    print(" ambiguous, skipping.")
                    continue
                context += entsum + '\n\n'
    system_msg = {'role': 'system', 'content': context}
    user_msg = {'role': 'user', 'content': message}
    # open an assistant message for the model to complete, in the same JSON style as json.dumps
    prompt = "[" + json.dumps(system_msg) + ", " + chat_json + json.dumps(user_msg) + ', {"role": "assistant", "content": "'
    for attempt in range(3):
        # result = model(prompt, max_new_tokens=250, return_full_text=False, handle_long_generation="hole")
        result = generate_text(prompt, model_path, parameters, headers)
        if not result:
            continue
        cleanStr = result[0].lstrip()
        start = cleanStr.find('{')  # skip over any preamble to the next JSON object
        if start < 0:
            continue
        cleanStr = cleanStr[start:]
        end = cleanStr.find('}') + 1
        if end <= 0:
            continue
        cleanStr = cleanStr[:end]
        try:
            reply = json.loads(cleanStr)
        except json.JSONDecodeError:
            continue
        if reply.get('role') != 'assistant':
            continue
        return reply['content']
    return "🤔"

gr.ChatInterface(merlin_chat).launch()
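
# Hypothetical smoke test, bypassing the Gradio UI (empty history):
#   print(merlin_chat("Tell me about Merlin", []))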