File size: 3,967 Bytes
1a94aaa
 
 
1511e8d
1a94aaa
 
 
 
 
7212a44
 
d796187
7212a44
1a94aaa
 
 
adca78b
 
1a94aaa
 
 
 
1511e8d
 
 
 
 
 
1a94aaa
 
 
03c5aaf
 
1511e8d
1a94aaa
 
3b513a4
03c5aaf
714170e
3b513a4
 
e99b177
1a94aaa
 
 
 
 
 
 
 
 
 
03c5aaf
1a94aaa
03c5aaf
 
 
 
1a94aaa
c2d580a
1a94aaa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
03c5aaf
 
1a94aaa
9496358
1a94aaa
3b513a4
1a94aaa
3b513a4
348444e
03c5aaf
1a94aaa
03c5aaf
 
 
 
1a94aaa
 
 
348444e
 
 
54033c4
1a94aaa
 
348444e
54033c4
 
 
 
 
03c5aaf
 
 
54033c4
7212a44
1a94aaa
03c5aaf
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
# --- Module-level setup: NLP tooling, knowledge lookup, and LLM pipelines. ---
# NOTE: everything below runs at import time and has side effects
# (network downloads, model loads); order matters.
import json

import gradio as gr
import os

# spaCy for named-entity recognition over the chat transcript.
# NOTE(review): downloading the model on every import is slow; consider
# guarding with a try/except around spacy.load instead.
import spacy
spacy.cli.download('en_core_web_sm')
nlp = spacy.load('en_core_web_sm')

# NLTK corpora required by rake_nltk below.
import nltk
nltk.download('stopwords')
nltk.download('punkt')

# RAKE keyword extraction — supplements spaCy entities with ranked phrases.
from rake_nltk import Rake
r = Rake()

import time

# Wikipedia client used to fetch background context for detected entities.
# First argument is the user-agent string required by the Wikimedia API.
import wikipediaapi
wiki_wiki = wikipediaapi.Wikipedia('Organika ([email protected])', 'en')

## ctransformers disabled for now
# from ctransformers import AutoModelForCausalLM
# model = AutoModelForCausalLM.from_pretrained(
#     "Colby/StarCoder-3B-WoW-JSON", 
#     model_file="StarCoder-3B-WoW-JSON-ggml.bin",
#     model_type="gpt_bigcode"
#    )

# Use a pipeline as a high-level helper
from transformers import pipeline
# Zero-shot classifier: scores how relevant each candidate entity is to the chat.
topic_model = pipeline("zero-shot-classification", model="valhalla/distilbart-mnli-12-9", device=0)
# Generator fine-tuned to continue a JSON-encoded chat transcript.
model = pipeline("text-generation", model="Colby/StarCoder-1B-WoW-JSON", device=0)

def merlin_chat(message, history):
    """Gradio chat handler: build a JSON chat prompt enriched with Wikipedia
    context for entities/keyphrases found in the recent conversation, then
    sample the generator until it yields a valid, novel assistant message.

    Parameters
    ----------
    message : str
        The user's newest message.
    history : list[[str, str]]
        Gradio tuple-style history of (user, assistant) turns.

    Returns
    -------
    str
        The assistant's reply, or a fallback emoji after 3 failed attempts.
    """
    chat_text = ""
    chat_list = []
    # Only the last 3 turns are kept, both as plain text (for NER/keywords)
    # and as role-tagged messages (for the JSON prompt).
    for turn in history[-3:]:
        chat_text += f"{turn[0]}\n\n{turn[1]}\n\n"
        chat_list.append({"role": "user", "content": turn[0]})
        chat_list.append({"role": "assistant", "content": turn[1]})
    chat_text += f"{message}\n"

    # Collect up to 3 named entities worth looking up.
    doc = nlp(chat_text)
    ents_found = []
    for ent in doc.ents:
        if len(ents_found) == 3:
            break
        # BUG FIX: spaCy's ent.label is an integer hash; ent.label_ is the
        # string tag. The original compared the int to strings, so numeric
        # entity types were never filtered out.
        if ent.text.isnumeric() or ent.label_ in ["DATE","TIME","PERCENT","MONEY","QUANTITY","ORDINAL","CARDINAL"]:
            continue
        lowered = ent.text.lower()  # .title().lower() was redundant
        # BUG FIX: compare lowercased text so duplicates are actually caught
        # (entries are stored lowercased).
        if lowered in ents_found:
            continue
        ents_found.append(lowered)

    # Supplement with RAKE keyphrases (top 3), deduplicated.
    r.extract_keywords_from_text(chat_text)
    for phrase in r.get_ranked_phrases()[:3]:
        phrase = phrase.lower()
        if phrase not in ents_found:
            ents_found.append(phrase)

    context = ""
    if ents_found:
        # BUG FIX: classifier call moved inside the guard — zero-shot
        # classification with an empty candidate-label list raises.
        scores = topic_model(chat_text, ents_found, multi_label=True)['scores']
        for entity, score in zip(ents_found, scores):
            # Only look up entities the classifier deems relevant enough.
            if score < 0.5:
                continue
            print(f'# Looking up {entity} on Wikipedia... ', end='')
            wiki_page = wiki_wiki.page(entity)
            if wiki_page.exists():
                print("page found... ")
                entsum = wiki_page.summary
                # Skip disambiguation pages — their summaries are not context.
                if "may refer to" in entsum or "may also refer to" in entsum:
                    print(" ambiguous, skipping.")
                    continue
                context += entsum + '\n\n'
            else:
                print("not found.")

    # Assemble the prompt: system context + history + new user message,
    # serialized as JSON with the closing ']' removed and an open assistant
    # message appended for the model to complete.
    system_msg = {
        'role': 'system', 'content': context
    }
    chat_list.insert(0,system_msg)
    user_msg = {'role': 'user', 'content': message}
    chat_list.append(user_msg)
    prompt = json.dumps(chat_list)[:-1] + ",{\"role\": \"assistant\", \"content\": \""
    print(f"PROMPT: {prompt}")

    for attempt in range(3):
        result = model(prompt,return_full_text=False, max_new_tokens=256, temperature=0.8, repetition_penalty=1.1)
        response = result[0]['generated_text']
        print(f"COMPLETION: {response}") # so we can see it in logs
        cleanStr = response.lstrip()
        # Keep only up to the first closing brace of the assistant object.
        end = cleanStr.find('}') + 1
        if end <= 0:
            continue
        cleanStr = cleanStr[:end]
        messageStr = prompt + cleanStr + ']'
        # BUG FIX: malformed model output previously raised JSONDecodeError
        # and crashed the handler; now it just triggers a retry.
        try:
            messages = json.loads(messageStr)
        except json.JSONDecodeError:
            continue
        reply = messages[-1]
        if reply['role'] != 'assistant':
            continue
        msg_text = reply['content']
        # Reject verbatim repeats of anything already in the transcript.
        if msg_text in chat_text:
            continue
        return msg_text
    return "🤔"

gr.ChatInterface(merlin_chat).launch(share=True)