"""Gradio chat demo ("Merlin"): mines the conversation for entities and key
phrases, pulls matching Wikipedia summaries as context, then asks a hosted
StarCoder chat model (via the Hugging Face Inference API) to reply in a
JSON chat-message format.
"""

import json
import os
import time

import gradio as gr
import nltk
import requests
import spacy
import wikipediaapi
from rake_nltk import Rake
from transformers import pipeline

# --- one-time resource setup (module-import side effects) ---
spacy.cli.download('en_core_web_sm')
nlp = spacy.load('en_core_web_sm')

nltk.download('stopwords')
nltk.download('punkt')
r = Rake()

wiki_wiki = wikipediaapi.Wikipedia('Organika (cmatthewbrown@gmail.com)', 'en')

## ctransformers disabled for now
# from ctransformers import AutoModelForCausalLM
# model = AutoModelForCausalLM.from_pretrained(
#     "Colby/StarCoder-3B-WoW-JSON",
#     model_file="StarCoder-3B-WoW-JSON-ggml.bin",
#     model_type="gpt_bigcode"
# )

# Zero-shot classifier used to score candidate topics for relevance.
topic_model = pipeline("zero-shot-classification", model="valhalla/distilbart-mnli-12-9")
# model = pipeline("text-generation", model="Colby/StarCoder-3B-WoW-JSON", device=0)


def query(payload, model_path, headers):
    """POST ``payload`` to the HF Inference API for ``model_path``.

    Retries up to 3 times (e.g. while the remote model is still loading).
    Returns the decoded JSON response on success, otherwise None.
    """
    API_URL = "https://api-inference.huggingface.co/models/" + model_path
    for retry in range(3):
        response = requests.post(API_URL, headers=headers, json=payload)
        if response.status_code == requests.codes.ok:
            try:
                return response.json()
            except ValueError:
                # FIX: was a bare `except:` — only JSON decoding can fail here.
                print('Invalid response received from server')
                print(response)
                return None
        # Not connected to internet maybe?
        if response.status_code == 404:
            print('Are you connected to the internet?')
            print('URL attempted = ' + API_URL)
            break
        if response.status_code == 503:
            # Model is loading server-side; print the ETA payload and retry.
            print(response.json())
            continue
        if response.status_code == 504:
            print('504 Gateway Timeout')
        else:
            print('Unsuccessful request, status code ' + str(response.status_code))
            # print(response.json())  # debug only
            print(payload)
    return None  # all retries exhausted (or 404 break)


def generate_text(prompt, model_path, text_generation_parameters, headers):
    """Run text generation through the HF Inference API.

    Returns a list of generated-text strings; empty on failure.
    """
    start_time = time.time()
    options = {'use_cache': False, 'wait_for_model': True}
    payload = {"inputs": prompt, "parameters": text_generation_parameters, "options": options}
    output_list = query(payload, model_path, headers)
    if not output_list:
        print('Generation failed')
    duration = round(time.time() - start_time, 1)
    stringlist = []
    if output_list and 'generated_text' in output_list[0]:
        print(f'{len(output_list)} sample(s) of text generated in {duration} seconds.')
        for gendict in output_list:
            stringlist.append(gendict['generated_text'])
    else:
        print(output_list)
    return stringlist


model_path = "Colby/StarCoder-3B-WoW-JSON"
parameters = {
    "max_new_tokens": 250,
    "return_full_text": False,
    "do_sample": True,
    "temperature": 0.8,
    "top_p": 0.9,
    "top_k": 50,
    "repetition_penalty": 1.1,
}
# FIX: requests expects a dict of headers, not a bare token string.
headers = {"Authorization": "Bearer " + os.environ['HF_TOKEN']}


def merlin_chat(message, history):
    """Gradio ChatInterface callback: answer ``message`` given ``history``.

    history is a list of (user, assistant) string pairs.
    Returns the assistant's reply, or a fallback emoji after 3 failed tries.
    """
    chat_text = ""
    chat_json = ""
    for turn in history:
        chat_text += f"USER: {turn[0]}\n\nASSISTANT: {turn[1]}\n\n"
        chat_json += json.dumps({"role": "user", "content": turn[0]})
        chat_json += json.dumps({"role": "assistant", "content": turn[1]})
    chat_text += f"USER: {message}\n"

    # Collect up to 3 distinct named entities, skipping purely numeric /
    # quantity-like entity types that make poor lookup topics.
    doc = nlp(chat_text)
    ents_found = []
    for ent in doc.ents:
        if len(ents_found) == 3:
            break
        # FIX: `ent.label` is an integer hash; the type string is `ent.label_`.
        if ent.text.isnumeric() or ent.label_ in [
            "DATE", "TIME", "PERCENT", "MONEY", "QUANTITY", "ORDINAL", "CARDINAL"
        ]:
            continue
        if ent.text in ents_found:
            continue
        ents_found.append(ent.text.title())

    # Add the top 3 RAKE key phrases as additional topic candidates.
    r.extract_keywords_from_text(chat_text)
    ents_found = ents_found + r.get_ranked_phrases()[:3]

    context = ""
    max_topic = None
    if ents_found:
        # FIX: only call the classifier when there are candidate labels —
        # an empty label list crashes the zero-shot pipeline.
        scores = topic_model(chat_text, ents_found, multi_label=True)['scores']
        max_score = 0
        for k in range(len(ents_found)):
            if scores[k] < 0.5:
                continue  # not relevant enough to look up
            entity = ents_found[k]
            if scores[k] > max_score:
                max_score = scores[k]
                max_topic = entity
            print(f'# Looking up {entity} on Wikipedia... ', end='')
            wiki_page = wiki_wiki.page(entity)
            if wiki_page.exists():
                print("page found... ")
                entsum = wiki_page.summary
                if "may refer to" in entsum or "may also refer to" in entsum:
                    print(" ambiguous, skipping.")
                    continue
                context += entsum + '\n\n'

    # FIX: `max_topic` was previously referenced even when never assigned
    # (NameError); only mention a topic when one was actually selected.
    topic_line = f'\n\nThe following is a conversation about {max_topic}.' if max_topic else ''
    system_msg = {'role': 'system', 'content': context + topic_line}
    user_msg = {'role': 'user', 'content': message}
    # Prime the model with the opening of an assistant turn that begins by
    # "recalling" a quoted fact. FIX: the stub previously used malformed
    # single-quoted pseudo-JSON ("{'role': 'assistant, ..."); it now matches
    # the double-quoted json.dumps format used for the rest of the prompt.
    # NOTE(review): messages are concatenated without commas, as in the
    # original — presumably the fine-tune format; confirm against the model.
    prompt = (
        "[" + json.dumps(system_msg) + chat_json + json.dumps(user_msg)
        + '{"role": "assistant", "content": "*recalls \\"'
    )

    for attempt in range(3):
        # result = model(prompt, max_new_tokens=250, return_full_text=False, handle_long_generation="hole")
        result = generate_text(prompt, model_path, parameters, headers)
        if not result:
            # FIX: generate_text can return [] on failure; retrying beats
            # the IndexError the original would raise here.
            continue
        # FIX: generate_text already returns plain strings, not dicts —
        # result[0]['generated_text'] was a TypeError.
        response = result[0]
        cleanStr = response.lstrip()
        # this should skip over whatever it recalls to what it says next
        start = cleanStr.find('{')
        if start <= 0:
            continue
        cleanStr = cleanStr[start:]
        end = cleanStr.find('}') + 1  # 0 here means '}' was not found
        if end <= 0:
            continue
        cleanStr = cleanStr[:end]
        try:
            reply = json.loads(cleanStr)
        except json.JSONDecodeError:
            # FIX: a malformed generation used to crash the app; retry instead.
            continue
        if reply.get('role') != 'assistant':
            continue
        return reply['content']
    return "🤔"


gr.ChatInterface(merlin_chat).launch()