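# Gradio chat demo: extracts entities and keyphrases from the conversation,
# scores their relevance with a zero-shot classifier, pulls Wikipedia summaries
# for relevant topics into a system message, then asks a StarCoder-based model
# (via the Hugging Face Inference API) to continue the chat as a JSON transcript.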
import json
import os
import time

import gradio as gr
import requests

# spaCy for named-entity recognition (model downloaded at startup)
import spacy
spacy.cli.download('en_core_web_sm')
nlp = spacy.load('en_core_web_sm')

# NLTK + RAKE for keyphrase extraction
import nltk
nltk.download('stopwords')
nltk.download('punkt')
from rake_nltk import Rake
r = Rake()

# Wikipedia client used to fetch background summaries for detected topics
import wikipediaapi
wiki_wiki = wikipediaapi.Wikipedia('Organika ([email protected])', 'en')
## ctransformers disabled for now
# from ctransformers import AutoModelForCausalLM
# model = AutoModelForCausalLM.from_pretrained(
# "Colby/StarCoder-3B-WoW-JSON",
# model_file="StarCoder-3B-WoW-JSON-ggml.bin",
# model_type="gpt_bigcode"
# )
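# The hosted Inference API is used instead; see query() and generate_text() below.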
# Zero-shot classifier used to score how relevant each candidate topic is to the chat
from transformers import pipeline
topic_model = pipeline("zero-shot-classification", model="valhalla/distilbart-mnli-12-9")
# model = pipeline("text-generation", model="Colby/StarCoder-3B-WoW-JSON", device=0)
# Helper for Hugging Face Inference API calls (retries transient failures)
def query(payload, model_path, headers):
    API_URL = "https://api-inference.huggingface.co/models/" + model_path
    for retry in range(3):
        response = requests.post(API_URL, headers=headers, json=payload)
        if response.status_code == requests.codes.ok:
            try:
                return response.json()
            except ValueError:
                print('Invalid response received from server')
                print(response)
                return None
        elif response.status_code == 404:
            # The model path does not exist; retrying will not help
            print('Model not found at ' + API_URL)
            break
        elif response.status_code == 503:
            # Model is still loading on the server; retry
            print(response.json())
            continue
        elif response.status_code == 504:
            print('504 Gateway Timeout')
        else:
            print('Unsuccessful request, status code ' + str(response.status_code))
            # print(response.json())  # debug only
            print(payload)
    return None
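# On success query() returns the decoded JSON response; for text-generation
# models this is a list of {"generated_text": ...} dicts.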
def generate_text(prompt, model_path, text_generation_parameters, headers):
    start_time = time.time()
    options = {'use_cache': False, 'wait_for_model': True}
    payload = {"inputs": prompt, "parameters": text_generation_parameters, "options": options}
    output_list = query(payload, model_path, headers)
    if not output_list:
        print('Generation failed')
    end_time = time.time()
    duration = round(end_time - start_time, 1)
    stringlist = []
    if output_list and 'generated_text' in output_list[0].keys():
        print(f'{len(output_list)} sample(s) of text generated in {duration} seconds.')
        for gendict in output_list:
            stringlist.append(gendict['generated_text'])
    else:
        print(output_list)
    return stringlist
# Remote model and sampling settings for the Inference API
model_path = "Colby/StarCoder-1B-WoW-JSON"
parameters = {
    "max_new_tokens": 250,
    "return_full_text": False,
    "do_sample": True,
    "temperature": 0.8,
    "top_p": 0.9,
    "top_k": 50,
    "repetition_penalty": 1.1
}
headers = {"Authorization": "Bearer " + os.environ['HF_TOKEN']}
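# Example call (illustrative only; needs network access and a valid HF_TOKEN):
#   samples = generate_text('[{"role": "user", "content": "Hi"}, {"role": "assistant", "content": "',
#                           model_path, parameters, headers)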
def merlin_chat(message, history):
    chat_text = ""
    chat_json = ""
    for turn in history:
        chat_text += f"USER: {turn[0]}\n\nASSISTANT: {turn[1]}\n\n"
        # Build the JSON transcript the model expects, one object per turn,
        # separated by commas so the final list parses as valid JSON
        chat_json += json.dumps({"role": "user", "content": turn[0]}) + ", "
        chat_json += json.dumps({"role": "assistant", "content": turn[1]}) + ", "
    chat_text += f"USER: {message}\n"
    # Collect up to three named entities, skipping numeric and date-like ones
    doc = nlp(chat_text)
    ents_found = []
    for ent in doc.ents:
        if len(ents_found) == 3:
            break
        if ent.text.isnumeric() or ent.label_ in ["DATE", "TIME", "PERCENT", "MONEY", "QUANTITY", "ORDINAL", "CARDINAL"]:
            continue
        if ent.text in ents_found:
            continue
        ents_found.append(ent.text.title())
    # Add the top three RAKE keyphrases as further topic candidates
    r.extract_keywords_from_text(chat_text)
    ents_found = ents_found + r.get_ranked_phrases()[:3]
context = ""
scores = topic_model(chat_text, ents_found, multi_label=True)['scores']
if ents_found:
max_score = 0
for k in range(len(ents_found)):
if scores[k] < 0.5:
continue
entity = ents_found[k]
if scores[k] > max_score:
max_score = scores[k]
max_topic = entity
print(f'# Looking up {entity} on Wikipedia... ', end='')
wiki_page = wiki_wiki.page(entity)
if wiki_page.exists():
print("page found... ")
entsum = wiki_page.summary
if "may refer to" in entsum or "may also refer to" in entsum:
print(" ambiguous, skipping.")
continue
else:
context += entsum + '\n\n'
context
    system_msg = {'role': 'system', 'content': context}
    user_msg = {'role': 'user', 'content': message}
    # Leave the assistant turn open for the model to complete
    prompt = "[" + json.dumps(system_msg) + ", " + chat_json + json.dumps(user_msg) + ', {"role": "assistant", "content": "'
    for attempt in range(3):
        # result = model(prompt, max_new_tokens=250, return_full_text=False, handle_long_generation="hole")
        result = generate_text(prompt, model_path, parameters, headers)
        if not result:
            continue
        # generate_text() returns a list of strings
        cleanStr = result[0].lstrip()
        # Skip past anything the model recalls and grab its first complete {...} object
        start = cleanStr.find('{')
        if start < 0:
            continue
        cleanStr = cleanStr[start:]
        end = cleanStr.find('}') + 1
        if end <= 0:
            continue
        cleanStr = cleanStr[:end]
        try:
            msg = json.loads(cleanStr)
        except ValueError:
            continue
        if msg.get('role') != 'assistant':
            continue
        return msg['content']
    return "🤔"  # fall back if no valid assistant turn was produced
gr.ChatInterface(merlin_chat).launch()