import json
import os
import time

import gradio as gr
import requests

# spaCy for named-entity recognition
import spacy
spacy.cli.download('en_core_web_sm')
nlp = spacy.load('en_core_web_sm')

# NLTK data + RAKE for keyword extraction
import nltk
nltk.download('stopwords')
nltk.download('punkt')
from rake_nltk import Rake
r = Rake()

# Wikipedia API client for background-context lookups
import wikipediaapi
wiki_wiki = wikipediaapi.Wikipedia('Organika ([email protected])', 'en')
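# Example (sketch, illustrative values): page lookups return plain-text summaries
# when the page exists, e.g.
#   wiki_wiki.page('Alan Turing').exists()  # -> True
#   wiki_wiki.page('Alan Turing').summary   # -> 'Alan Mathison Turing ...'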
## ctransformers disabled for now
# from ctransformers import AutoModelForCausalLM
# model = AutoModelForCausalLM.from_pretrained(
#     "Colby/StarCoder-3B-WoW-JSON",
#     model_file="StarCoder-3B-WoW-JSON-ggml.bin",
#     model_type="gpt_bigcode"
# )
# Use a pipeline as a high-level helper
from transformers import pipeline
topic_model = pipeline("zero-shot-classification", model="valhalla/distilbart-mnli-12-9")
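# Note: the zero-shot pipeline returns its labels sorted by descending score,
# not in input order, e.g. (illustrative values):
#   topic_model("I went hiking today", ["outdoors", "finance"], multi_label=True)
#   -> {'sequence': 'I went hiking today',
#       'labels': ['outdoors', 'finance'], 'scores': [0.97, 0.02]}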
# model = pipeline("text-generation", model="Colby/StarCoder-3B-WoW-JSON", device=0)
# function for Hugging Face Inference API calls
def query(payload, model_path, headers):
    API_URL = "https://api-inference.huggingface.co/models/" + model_path
    for retry in range(3):
        response = requests.post(API_URL, headers=headers, json=payload)
        if response.status_code == requests.codes.ok:
            try:
                results = response.json()
                return results
            except ValueError:
                print('Invalid response received from server')
                print(response)
                return None
        else:
            if response.status_code == 404:
                # Bad model path (or no connection) -- retrying won't help
                print('Are you connected to the internet?')
                print('URL attempted = ' + API_URL)
                break
            if response.status_code == 503:
                # Model is still loading on the server; retry
                print(response.json())
                continue
            if response.status_code == 504:
                print('504 Gateway Timeout')
            else:
                print('Unsuccessful request, status code ' + str(response.status_code))
                # print(response.json())  # debug only
                print(payload)
    return None
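# Example usage (sketch; "gpt2" is only an illustrative model path):
#   query({"inputs": "Hello"}, "gpt2", {"Authorization": "Bearer <token>"})
# returns e.g. [{'generated_text': '...'}] on success, or None after retries fail.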
def generate_text(prompt, model_path, text_generation_parameters, headers):
    start_time = time.time()
    options = {'use_cache': False, 'wait_for_model': True}
    payload = {"inputs": prompt, "parameters": text_generation_parameters, "options": options}
    output_list = query(payload, model_path, headers)
    if not output_list:
        print('Generation failed')
    end_time = time.time()
    duration = round(end_time - start_time, 1)
    stringlist = []
    if output_list and 'generated_text' in output_list[0]:
        print(f'{len(output_list)} sample(s) of text generated in {duration} seconds.')
        for gendict in output_list:
            stringlist.append(gendict['generated_text'])
    else:
        print(output_list)
    return stringlist
model_path = "Colby/StarCoder-1B-WoW-JSON"
# Sampling settings for the hosted model
parameters = {
    "max_new_tokens": 250,
    "return_full_text": False,
    "do_sample": True,
    "temperature": 0.8,
    "top_p": 0.9,
    "top_k": 50,
    "repetition_penalty": 1.1
}
headers = {"Authorization": "Bearer " + os.environ['HF_TOKEN']}
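# Example usage (sketch): returns a list of completion strings, one per sample.
#   samples = generate_text('USER: Hello\n', model_path, parameters, headers)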
def merlin_chat(message, history):
    # Flatten the conversation into plain text for entity/keyword extraction
    chat_text = ""
    for turn in history:
        chat_text += f"{turn[0]}\n\n{turn[1]}\n\n"
    chat_text += f"USER: {message}\n"
    # Collect up to three named entities worth looking up
    doc = nlp(chat_text)
    ents_found = []
    for ent in doc.ents:
        if len(ents_found) == 3:
            break
        # Skip purely numeric or quantitative entities
        if ent.text.isnumeric() or ent.label_ in ["DATE", "TIME", "PERCENT", "MONEY", "QUANTITY", "ORDINAL", "CARDINAL"]:
            continue
        if ent.text in ents_found:
            continue
        ents_found.append(ent.text.title())
    # Add the top-ranked RAKE keyword phrases as extra topic candidates
    r.extract_keywords_from_text(chat_text)
    ents_found = ents_found + r.get_ranked_phrases()[:3]
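    # e.g. (illustrative): for "Merlin lived in Wales", doc.ents would typically
    # include ("Merlin", PERSON) and ("Wales", GPE), giving ['Merlin', 'Wales']
    # here, plus whatever phrases RAKE ranks highest.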
    # Score each candidate topic against the conversation, then pull Wikipedia
    # summaries for the relevant ones to use as context
    context = ""
    if ents_found:
        result = topic_model(chat_text, ents_found, multi_label=True)
        # labels and scores come back sorted together by descending score
        for entity, score in zip(result['labels'], result['scores']):
            if score < 0.5:
                continue
            print(f'# Looking up {entity} on Wikipedia... ', end='')
            wiki_page = wiki_wiki.page(entity)
            if wiki_page.exists():
                print("page found... ")
                entsum = wiki_page.summary
                if "may refer to" in entsum or "may also refer to" in entsum:
                    print(" ambiguous, skipping.")
                    continue
                else:
                    context += entsum + '\n\n'
    # Build the conversation as a JSON message list, ending with an
    # unterminated assistant message for the model to complete
    system_msg = {'role': 'system', 'content': context}
    user_msg = {'role': 'user', 'content': message}
    prompt_data = [system_msg]
    for turn in history:
        prompt_data.append({'role': 'user', 'content': turn[0]})
        prompt_data.append({'role': 'assistant', 'content': turn[1]})
    prompt_data.append(user_msg)
    prompt = json.dumps(prompt_data)[:-1] + ', {"role": "assistant", "content": "'
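    # The prompt is now a JSON list missing its final quote, brace, and bracket,
    # e.g. (illustrative):
    #   [{"role": "system", "content": "..."}, {"role": "user", "content": "Hi"}, {"role": "assistant", "content": "
    # and the model is expected to finish the last message, e.g.: Hi there!"}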
    for attempt in range(3):
        # result = model(prompt, max_new_tokens=250, return_full_text=False, handle_long_generation="hole")
        result = generate_text(prompt, model_path, parameters, headers)
        if not result:
            continue
        response = result[0]
        print(response)  # so we can see it in logs
        cleanStr = response.lstrip()
        # cleanStr = cleanStr.replace(prompt, "")
        # Keep only the completion up to the closing brace of the assistant message
        end = cleanStr.find('}') + 1
        if end <= 0:
            continue
        cleanStr = cleanStr[:end]
        messageStr = prompt + cleanStr + "]"
        try:
            messages = json.loads(messageStr)
        except ValueError:
            # Malformed JSON from the model; try again
            continue
        reply = messages[-1]
        if reply['role'] != 'assistant':
            continue
        return reply['content']
    return "🤔"
gr.ChatInterface(merlin_chat).launch()
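# To run locally (assumption: this file is saved as app.py and a valid token
# is available):
#   HF_TOKEN=<your token> python app.py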