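# Merlin chat: a Gradio ChatInterface that mines entities and key phrases from
# the conversation, grounds its reply in Wikipedia summaries of those topics,
# and asks a StarCoder-based model on the Hugging Face Inference API to
# continue the chat in JSON format.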
import json

import gradio as gr
import os

import spacy
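# Download the small English pipeline on first run, then load it for NER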
spacy.cli.download('en_core_web_sm')
nlp = spacy.load('en_core_web_sm')

import nltk
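# NLTK data needed by RAKE: stopwords and the punkt sentence tokenizer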
nltk.download('stopwords')
nltk.download('punkt')

from rake_nltk import Rake
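# RAKE extracts ranked key phrases to use as candidate topics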
r = Rake()

import time

import wikipediaapi
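# Wikipedia client; the arguments are the user-agent string and the language code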
wiki_wiki = wikipediaapi.Wikipedia('Organika ([email protected])', 'en')

## ctransformers disabled for now
# from ctransformers import AutoModelForCausalLM
# model = AutoModelForCausalLM.from_pretrained(
#     "Colby/StarCoder-3B-WoW-JSON", 
#     model_file="StarCoder-3B-WoW-JSON-ggml.bin",
#     model_type="gpt_bigcode"
#    )

# Use a pipeline as a high-level helper
from transformers import pipeline
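# Zero-shot classifier used to score candidate topics against the chat text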
topic_model = pipeline("zero-shot-classification", model="valhalla/distilbart-mnli-12-9")
#model = pipeline("text-generation", model="Colby/StarCoder-3B-WoW-JSON", device=0)

import requests

# function for Huggingface API calls
def query(payload, model_path, headers):
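    """POST a payload to the HF Inference API for model_path, retrying up to 3 times."""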
    API_URL = "https://api-inference.huggingface.co/models/" + model_path
    for retry in range(3):
        response = requests.post(API_URL, headers=headers, json=payload)
        if response.status_code == requests.codes.ok:
            try:
                results = response.json()
                return results
            except ValueError:
                # Response body was not valid JSON
                print('Invalid response received from server')
                print(response)
                return None
        else:
            if response.status_code == 404:
                # 404 usually means a bad model path (or no internet route)
                print('Model not found. Are you connected to the internet?')
                print('URL attempted = ' + API_URL)
                break
            if response.status_code == 503:
                # Model is still loading on the server; retry
                print(response.json())
                continue
            if response.status_code == 504:
                print('504 Gateway Timeout')
            else:
                print('Unsuccessful request, status code ' + str(response.status_code))
                # print(response.json()) # debug only
                print(payload)
    return None

def generate_text(prompt, model_path, text_generation_parameters, headers):
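    """Request text generation from the Inference API; return a list of generated strings."""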
    start_time = time.time()
    options = {'use_cache': False, 'wait_for_model': True}
    payload = {"inputs": prompt, "parameters": text_generation_parameters, "options": options}
    output_list = query(payload, model_path, headers)
    if not output_list:
        print('Generation failed')
    end_time = time.time()
    duration = round(end_time - start_time, 1)
    stringlist = []
    if isinstance(output_list, list) and output_list and 'generated_text' in output_list[0]:
        print(f'{len(output_list)} sample(s) of text generated in {duration} seconds.')
        for gendict in output_list:
            stringlist.append(gendict['generated_text'])
    else:
        # Error responses come back as a dict such as {"error": "..."}
        print(output_list)
    return stringlist

model_path = "Colby/StarCoder-1B-WoW-JSON"
parameters = {
    "max_new_tokens": 250,
    "return_full_text": False,
    "do_sample": True,
    "temperature": 0.8,
    "top_p": 0.9,
    "top_k": 50,
    "repetition_penalty": 1.1
}
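# HF API token must be set in the environment (e.g. as a Space secret)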
headers = {"Authorization": "Bearer " + os.environ['HF_TOKEN']}

def merlin_chat(message, history):
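    """Chat handler for gr.ChatInterface: look up detected topics on Wikipedia,
    put the summaries in a system message, and generate a JSON assistant reply."""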
    chat_text = ""
    chat_json = ""
    for turn in history:
        chat_text += f"USER: {turn[0]}\n\nASSISTANT: {turn[1]}\n\n"
        chat_json += json.dumps({"role": "user", "content": turn[0]}) + ", "
        chat_json += json.dumps({"role": "assistant", "content": turn[1]}) + ", "
    chat_text += f"USER: {message}\n"
    doc = nlp(chat_text)
    ents_found = []
    if doc.ents:
        for ent in doc.ents:
            if len(ents_found) == 3:
                break
            # ent.label_ is the string label; skip numeric/quantitative entities
            if ent.text.isnumeric() or ent.label_ in ["DATE", "TIME", "PERCENT", "MONEY", "QUANTITY", "ORDINAL", "CARDINAL"]:
                continue
            if ent.text in ents_found:
                continue
            ents_found.append(ent.text.title())
    r.extract_keywords_from_text(chat_text)
    ents_found = ents_found + r.get_ranked_phrases()[:3]
    context = ""
    if ents_found:
        # The zero-shot pipeline returns labels sorted by score (descending),
        # so pair each label with its own score rather than indexing by position
        topics = topic_model(chat_text, ents_found, multi_label=True)
        for entity, score in zip(topics['labels'], topics['scores']):
            if score < 0.5:
                continue
            print(f'# Looking up {entity} on Wikipedia... ', end='')
            wiki_page = wiki_wiki.page(entity)
            if wiki_page.exists():
                print("page found... ")
                entsum = wiki_page.summary
                if "may refer to" in entsum or "may also refer to" in entsum:
                    print(" ambiguous, skipping.")
                    continue
                else:
                    context += entsum + '\n\n'
    system_msg = {
        'role': 'system', 'content': context
    }
    user_msg = {'role': 'user', 'content': message}
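    # Serialize the conversation as a JSON-style list, leaving the assistant
    # object open for the model to complete:
    # [{"role": "system", ...}, {"role": "user", ...}, {"role": "assistant", "content": "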
    prompt = "[" + json.dumps(system_msg) + ", " + chat_json + json.dumps(user_msg) + ', {"role": "assistant", "content": "'
    for attempt in range(3):
        # result = model(prompt, max_new_tokens=250, return_full_text=False, handle_long_generation="hole")
        result = generate_text(prompt, model_path, parameters, headers)
        if not result:
            continue
        response = result[0]  # generate_text returns a list of strings
        cleanStr = response.lstrip()
        # Skip over whatever the model echoes before the next complete JSON object
        start = cleanStr.find('{')
        if start < 0:
            continue
        cleanStr = cleanStr[start:]
        end = cleanStr.find('}') + 1
        if end <= 0:
            continue
        cleanStr = cleanStr[:end]
        try:
            msg = json.loads(cleanStr)
        except json.JSONDecodeError:
            continue
        if msg.get('role') != 'assistant':
            continue
        return msg['content']
    return "🤔"

gr.ChatInterface(merlin_chat).launch()