File size: 4,905 Bytes
1a94aaa
 
 
1511e8d
1a94aaa
 
 
 
 
7212a44
 
d796187
7212a44
1a94aaa
 
 
adca78b
 
1a94aaa
 
 
 
1511e8d
 
 
 
 
 
1a94aaa
 
 
03c5aaf
03c3dac
 
 
bea12d1
03c3dac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d9e1efb
03c3dac
1511e8d
1a94aaa
 
3b513a4
03c5aaf
714170e
3b513a4
 
e99b177
1a94aaa
 
 
 
 
 
 
 
 
 
03c5aaf
1a94aaa
03c5aaf
 
 
 
1a94aaa
c2d580a
1a94aaa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
03c5aaf
 
1a94aaa
9496358
1a94aaa
3b513a4
1a94aaa
3b513a4
348444e
03c5aaf
1a94aaa
03c5aaf
03c3dac
82151fa
d9e1efb
 
03c5aaf
1a94aaa
 
 
348444e
 
 
54033c4
1a94aaa
 
348444e
d9e1efb
54033c4
 
 
 
03c5aaf
 
 
54033c4
7212a44
1a94aaa
b97fa02
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
# --- Module-level setup: NLP tooling, Wikipedia client, and the chat LLM. ---
# NOTE(review): everything below runs at import time (downloads models, loads
# weights onto GPU) — starting this app requires network access and CUDA.
import json

import gradio as gr
import os

# spaCy: named-entity recognition over the chat transcript.
import spacy
spacy.cli.download('en_core_web_sm')
nlp = spacy.load('en_core_web_sm')

# NLTK corpora required by rake_nltk below.
import nltk
nltk.download('stopwords')
nltk.download('punkt')

# RAKE keyphrase extractor (shared instance reused across requests).
from rake_nltk import Rake
r = Rake()

import time

# Wikipedia client used to pull topic summaries as system-prompt context.
import wikipediaapi
wiki_wiki = wikipediaapi.Wikipedia('Organika ([email protected])', 'en')

## ctransformers disabled for now
# from ctransformers import AutoModelForCausalLM
# model = AutoModelForCausalLM.from_pretrained(
#     "Colby/StarCoder-3B-WoW-JSON", 
#     model_file="StarCoder-3B-WoW-JSON-ggml.bin",
#     model_type="gpt_bigcode"
#    )

# Use a pipeline as a high-level helper
from transformers import pipeline
# Zero-shot classifier: scores how relevant each candidate topic is to the chat.
topic_model = pipeline("zero-shot-classification", model="valhalla/distilbart-mnli-12-9", device=0)
#model = pipeline("text-generation", model="Organika/StarCoder-7B-WoW-JSON_1", device=0)

from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch

# 4-bit NF4 quantization config so the 7B model fits in limited VRAM.
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16
)

# Chat model fine-tuned to emit JSON-formatted dialogue turns.
model_name = "umm-maybe/StarCoder-7B-WoW-JSON_1"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=bnb_config, device_map="auto")

def generate_text(prompt):
    """Sample a completion for *prompt* from the quantized StarCoder model.

    Returns the full decoded sequence — the prompt text is included in the
    return value, so callers are expected to strip it themselves.
    """
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to("cuda")
    sampled = model.generate(
        input_ids,
        do_sample=True,
        max_new_tokens=200,
        temperature=0.6,
        top_p=0.9,
        top_k=40,
        repetition_penalty=1.1,
    )
    return tokenizer.decode(sampled[0], clean_up_tokenization_spaces=False)

def merlin_chat(message, history):
    """Gradio chat handler: ground the reply in Wikipedia context, then generate.

    Pipeline: extract entities/keyphrases from the recent transcript, score
    their topical relevance, fetch Wikipedia summaries for relevant topics as
    a system message, and ask the LLM to complete a JSON-encoded assistant turn.

    Parameters
    ----------
    message : str
        The latest user message.
    history : list of (user, assistant) pairs
        Prior turns supplied by gr.ChatInterface.

    Returns
    -------
    str
        The assistant reply, or a fallback emoji if 3 attempts all fail.
    """
    chat_text = ""
    chat_list = []
    # Keep only the last 3 turns to bound prompt length.
    for turn in history[-3:]:
        chat_text += f"{turn[0]}\n\n{turn[1]}\n\n"
        chat_list.append({"role": "user", "content": turn[0]})
        chat_list.append({"role": "assistant", "content": turn[1]})
    chat_text += f"{message}\n"
    # Collect up to 3 distinct named entities, skipping purely numeric or
    # quantitative ones that make poor Wikipedia lookups.
    doc = nlp(chat_text)
    ents_found = []
    for ent in doc.ents:
        if len(ents_found) == 3:
            break
        # BUG FIX: spaCy's ent.label is an integer hash; the string tag lives
        # in ent.label_ — the original comparison against strings never matched.
        if ent.text.isnumeric() or ent.label_ in ["DATE", "TIME", "PERCENT", "MONEY", "QUANTITY", "ORDINAL", "CARDINAL"]:
            continue
        if ent.text in ents_found:
            continue
        ents_found.append(ent.text.lower())  # .title().lower() was a no-op round-trip
    # Add up to 3 RAKE keyphrases not already captured as entities.
    r.extract_keywords_from_text(chat_text)
    for phrase in r.get_ranked_phrases()[:3]:
        phrase = phrase.lower()
        if phrase not in ents_found:
            ents_found.append(phrase)
    context = ""
    if ents_found:
        # BUG FIX: the classifier call now sits inside the guard — invoking the
        # zero-shot pipeline with an empty label list raises an exception.
        scores = topic_model(chat_text, ents_found, multi_label=True)['scores']
        for entity, score in zip(ents_found, scores):
            if score < 0.5:
                continue  # not topical enough to be worth a lookup
            print(f'# Looking up {entity} on Wikipedia... ', end='')
            wiki_page = wiki_wiki.page(entity)
            if wiki_page.exists():
                print("page found... ")
                entsum = wiki_page.summary
                if "may refer to" in entsum or "may also refer to" in entsum:
                    print(" ambiguous, skipping.")  # disambiguation page, useless as context
                    continue
                context += entsum + '\n\n'
            else:
                print("not found.")
    system_msg = {'role': 'system', 'content': context}
    chat_list.insert(0, system_msg)
    chat_list.append({'role': 'user', 'content': message})
    # Drop the closing ']' and open an assistant turn for the model to finish.
    prompt = json.dumps(chat_list)[:-1] + ",{\"role\": \"assistant\", \"content\": \""
    print(f"PROMPT: {prompt}")
    for attempt in range(3):
        result = generate_text(prompt)
        response = result.replace(prompt, "")
        print(f"COMPLETION: {response}")  # so we can see it in logs
        cleanStr = response.lstrip()
        end = cleanStr.find('}') + 1
        if end <= 0:
            continue  # no closing brace: malformed completion, retry
        cleanStr = cleanStr[:end]
        messageStr = prompt + cleanStr + ']'
        # BUG FIX: the model can emit invalid JSON; retry instead of letting
        # json.loads crash the whole request handler.
        try:
            messages = json.loads(messageStr)
        except json.JSONDecodeError:
            continue
        # Use a distinct name — the original shadowed the `message` parameter.
        reply = messages[-1]
        if reply['role'] != 'assistant':
            continue
        # Reject replies that merely echo text already in the conversation.
        if chat_text.find(reply['content']) >= 0:
            continue
        return reply['content']
    return "🤔"  # all attempts produced unusable output

# Start the Gradio chat UI, dispatching each user message to merlin_chat.
gr.ChatInterface(merlin_chat).launch()