File size: 3,630 Bytes
1a94aaa
 
 
 
 
 
 
 
7212a44
 
d796187
7212a44
1a94aaa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c2d580a
1a94aaa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7212a44
 
 
 
 
 
 
1a94aaa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7212a44
1a94aaa
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
# Standard library: used to serialize chat turns into the JSON-style prompt
# and to parse the model's reply.
import json

# Gradio provides the chat UI that is launched at the bottom of this file.
import gradio as gr

# spaCy pipeline used for named-entity recognition on the conversation.
# The small English model is downloaded at import time, every time the
# module is loaded, and then loaded into the shared `nlp` object.
import spacy
spacy.cli.download('en_core_web_sm')
nlp = spacy.load('en_core_web_sm')

# NLTK data fetched at import time.
# NOTE(review): presumably these corpora are what rake_nltk needs for
# stopword filtering and tokenization — confirm against rake_nltk's docs.
import nltk
nltk.download('stopwords')
nltk.download('punkt')

# RAKE keyword extractor with default settings; one shared instance is
# reused (and re-fed) by every chat call.
from rake_nltk import Rake
r = Rake()

# Wikipedia client for English-language pages.
# NOTE(review): the first argument looks like a user-agent/contact string and
# the address inside it appears to have been mangled by e-mail obfuscation —
# confirm the intended value.
import wikipediaapi
wiki_wiki = wikipediaapi.Wikipedia('Organika ([email protected])', 'en')

## ctransformers disabled for now
# from ctransformers import AutoModelForCausalLM
# model = AutoModelForCausalLM.from_pretrained(
#     "Colby/StarCoder-3B-WoW-JSON", 
#     model_file="StarCoder-3B-WoW-JSON-ggml.bin",
#     model_type="gpt_bigcode"
#     )

# Use a pipeline as a high-level helper
from transformers import pipeline
# Zero-shot classifier used to score how relevant each candidate topic is to
# the conversation text.
topic_model = pipeline("zero-shot-classification", model="valhalla/distilbart-mnli-12-9")
# Text-generation model that produces the assistant's reply from the prompt.
model = pipeline("text-generation", model="Colby/StarCoder-3B-WoW-JSON")

# spaCy entity labels that are skipped when collecting candidate topics.
_SKIPPED_ENT_LABELS = frozenset(
    {"DATE", "TIME", "PERCENT", "MONEY", "QUANTITY", "ORDINAL", "CARDINAL"}
)


def _candidate_topics(chat_text, max_entities=3, max_phrases=3):
    """Collect candidate topics: named entities first, then top RAKE phrases."""
    topics = []
    for ent in nlp(chat_text).ents:
        if len(topics) == max_entities:
            break
        # `label_` is the string label; `label` is an integer hash and would
        # never match the string set.
        if ent.text.isnumeric() or ent.label_ in _SKIPPED_ENT_LABELS:
            continue
        title = ent.text.title()
        # Compare the title-cased form, since that is what gets stored.
        if title in topics:
            continue
        topics.append(title)

    r.extract_keywords_from_text(chat_text)
    seen = {topic.casefold() for topic in topics}
    for phrase in r.get_ranked_phrases()[:max_phrases]:
        # Skip phrases that merely repeat an entity in a different case so the
        # same page is not looked up (and appended to the context) twice.
        if phrase.casefold() in seen:
            continue
        seen.add(phrase.casefold())
        topics.append(phrase)
    return topics


def _wikipedia_context(chat_text, topics, min_score=0.5):
    """Score topics against the chat and gather Wikipedia summaries.

    Returns a ``(context, best_topic)`` tuple; ``best_topic`` is ``None`` when
    there are no topics or none scores at least ``min_score``.
    """
    if not topics:
        # The zero-shot pipeline rejects an empty candidate-label list.
        return "", None

    result = topic_model(chat_text, topics, multi_label=True)
    context = ""
    best_topic = None
    best_score = 0
    # The pipeline returns labels re-ordered by score, so scores must be
    # paired with result['labels'], not with the order of `topics`.
    for topic, score in zip(result['labels'], result['scores']):
        if score < min_score:
            continue
        if score > best_score:
            best_score = score
            best_topic = topic
        print(f'# Looking up {topic} on Wikipedia... ', end='')
        wiki_page = wiki_wiki.page(topic)
        if not wiki_page.exists():
            print("no page found.")
            continue
        print("page found... ")
        summary = wiki_page.summary
        if "may refer to" in summary or "may also refer to" in summary:
            # Disambiguation pages carry no usable context.
            print(" ambiguous, skipping.")
            continue
        context += summary + '\n\n'
    return context, best_topic


def _extract_reply(generated):
    """Return the assistant's content from generated text, or None if unusable.

    The text is expected to start with the tail of the "recalls" turn primed
    by the prompt, followed by a JSON object holding the actual reply.
    """
    text = generated.lstrip()
    # A brace at index 0 means there was no recalled text before the reply;
    # that case is skipped, as is a missing brace (-1).
    start = text.find('{')
    if start <= 0:
        return None
    try:
        # raw_decode parses exactly one JSON value starting at `start`, so a
        # '}' inside the reply's content does not truncate the object.
        reply, _ = json.JSONDecoder().raw_decode(text, start)
    except json.JSONDecodeError:
        return None
    if not isinstance(reply, dict) or reply.get('role') != 'assistant':
        return None
    content = reply.get('content')
    return content if isinstance(content, str) else None


def merlin_chat(message, history):
    """Generate the assistant's reply for a Gradio chat turn.

    Builds a transcript from ``history`` and ``message``, extracts candidate
    topics from it, pulls Wikipedia summaries for the topics the zero-shot
    classifier scores at 0.5 or higher, and prompts the text-generation model
    with that context plus the JSON-encoded conversation.

    Args:
        message: The user's latest message.
        history: Previous turns; each item is indexed as ``[user, assistant]``.

    Returns:
        The assistant's reply text, or "🤔" if no parsable assistant reply was
        produced in three generation attempts.
    """
    text_parts = []
    json_parts = []
    for turn in history:
        text_parts.append(f"USER: {turn[0]}\n\nASSISTANT: {turn[1]}\n\n")
        json_parts.append(json.dumps({"role": "user", "content": turn[0]}))
        json_parts.append(json.dumps({"role": "assistant", "content": turn[1]}))
    chat_text = "".join(text_parts) + f"USER: {message}\n"
    chat_json = "".join(json_parts)

    context, topic = _wikipedia_context(chat_text, _candidate_topics(chat_text))

    if topic is None:
        # No topic cleared the score threshold; do not name one.
        framing = '\n\nThe following is a conversation.'
    else:
        framing = f'\n\nThe following is a conversation about {topic}.'
    system_msg = {'role': 'system', 'content': context + framing}
    user_msg = {'role': 'user', 'content': message}
    # NOTE(review): the trailing primer is kept byte-for-byte. It is missing a
    # closing quote after "assistant" and uses single quotes unlike the
    # json.dumps output before it — confirm against the format the model was
    # trained on before changing it.
    prompt = "[" + json.dumps(system_msg) + chat_json + json.dumps(user_msg) + "{'role': 'assistant, 'content': '*recalls \""

    for _ in range(3):
        result = model(
            prompt,
            return_full_text=False,
            max_length=250,
            handle_long_generation="hole",
        )
        reply = _extract_reply(result[0]['generated_text'])
        if reply is not None:
            return reply
    return "🤔"

# Wrap merlin_chat in Gradio's chat UI and start serving it.
demo = gr.ChatInterface(merlin_chat)
demo.launch()