Colby committed on
Commit
1a94aaa
·
verified ·
1 Parent(s): 675d880

Upload 2 files

Browse files

Add app and requirements

Files changed (2) hide show
  1. app.py +95 -0
  2. requirements.txt +5 -0
app.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import json

import gradio as gr
import spacy
import wikipediaapi
from rake_nltk import Rake
from transformers import pipeline

# spaCy English pipeline for named-entity extraction.
# BUG FIX: the original unconditionally ran spacy.cli.download on every
# startup; only download when the model is not already installed.
try:
    nlp = spacy.load('en_core_web_sm')
except OSError:
    spacy.cli.download('en_core_web_sm')
    nlp = spacy.load('en_core_web_sm')

# RAKE key-phrase extractor (NOTE(review): requires the NLTK stopwords
# corpus to be available — confirm it is provisioned in the Space).
r = Rake()

# Wikipedia API client (descriptive user agent, English wiki).
wiki_wiki = wikipediaapi.Wikipedia('Organika ([email protected])', 'en')

## ctransformers disabled for now
# from ctransformers import AutoModelForCausalLM
# model = AutoModelForCausalLM.from_pretrained(
#     "Colby/StarCoder-3B-WoW-JSON",
#     model_file="StarCoder-3B-WoW-JSON-ggml.bin",
#     model_type="gpt_bigcode"
# )

# Use a pipeline as a high-level helper:
# - topic_model scores candidate topics against the conversation text,
# - model generates the assistant reply in a JSON chat format.
topic_model = pipeline("zero-shot-classification", model="valhalla/distilbart-mnli-12-9")
model = pipeline("text-generation", model="Colby/StarCoder-3B-WoW-JSON")
+
def merlin_chat(message, history):
    """Gradio chat handler that grounds replies in Wikipedia context.

    Builds a plain-text and a JSON transcript of the conversation, extracts
    candidate topics (spaCy named entities plus RAKE key phrases), scores
    them with the zero-shot classifier, pulls Wikipedia summaries for
    relevant topics into a system message, and asks the generator for a
    JSON-formatted assistant turn.

    Parameters:
        message: the new user utterance (str).
        history: list of (user_text, assistant_text) pairs from Gradio.

    Returns:
        The assistant's reply text, or a fallback apology after 3 failed
        generation attempts.
    """
    chat_text = ""
    chat_json = ""
    for user_turn, assistant_turn in history:
        chat_text += f"USER: {user_turn}\n\nASSISTANT: {assistant_turn}\n\n"
        chat_json += json.dumps({"role": "user", "content": user_turn})
        chat_json += json.dumps({"role": "assistant", "content": assistant_turn})
    chat_text += f"USER: {message}\n"

    # Collect up to 3 distinct named entities, skipping purely numeric or
    # quantity-like ones.
    # BUG FIX: ent.label is an integer id; the string tag is ent.label_, so
    # the original membership test never matched and nothing was filtered.
    doc = nlp(chat_text)
    ents_found = []
    skip_labels = ["DATE", "TIME", "PERCENT", "MONEY", "QUANTITY", "ORDINAL", "CARDINAL"]
    for ent in doc.ents:
        if len(ents_found) == 3:
            break
        if ent.text.isnumeric() or ent.label_ in skip_labels:
            continue
        if ent.text in ents_found:
            continue
        ents_found.append(ent.text.title())

    # Add the top 3 RAKE key phrases as additional topic candidates.
    r.extract_keywords_from_text(chat_text)
    ents_found = ents_found + r.get_ranked_phrases()[:3]

    context = ""
    max_topic = None  # BUG FIX: was unbound when no candidate scored >= 0.5
    if ents_found:
        # BUG FIX: the original scored the candidates against an empty
        # string (and called the classifier even with no candidates);
        # score them against the conversation text instead.
        scores = topic_model(chat_text, ents_found, multi_label=True)['scores']
        max_score = 0
        for entity, score in zip(ents_found, scores):
            if score < 0.5:
                continue
            if score > max_score:
                max_score = score
                max_topic = entity
            print(f'# Looking up {entity} on Wikipedia... ', end='')
            wiki_page = wiki_wiki.page(entity)
            if wiki_page.exists():
                print("page found... ")
                entsum = wiki_page.summary
                # Skip disambiguation pages.
                if "may refer to" in entsum or "may also refer to" in entsum:
                    print(" ambiguous, skipping.")
                    continue
                context += entsum + '\n\n'

    # Only claim a conversation topic when one was actually selected.
    topic_note = f'\n\nThe following is a conversation about {max_topic}.' if max_topic else ''
    system_msg = {
        'role': 'system', 'content': context + topic_note
    }
    user_msg = {'role': 'user', 'content': message}
    # NOTE(review): the assistant stub below is malformed JSON ("'assistant,"
    # is missing a closing quote) but presumably matches the generator's
    # training format — confirm against the model card before changing it.
    prompt = "[" + json.dumps(system_msg) + chat_json + json.dumps(user_msg) + "{'role': 'assistant, 'content': '*recalls \""
    for attempt in range(3):
        # BUG FIX: a transformers text-generation pipeline returns
        # [{'generated_text': ...}], not a plain string (the original
        # response.lstrip() raised AttributeError), and `stop` is not a
        # valid kwarg (leftover from the disabled ctransformers path).
        # return_full_text=False yields only the newly generated tokens,
        # which is what the '{'-scanning below expects.
        outputs = model(prompt, max_new_tokens=250, return_full_text=False)
        cleanStr = outputs[0]['generated_text'].lstrip()
        # Skip over whatever it recalls to the JSON object it says next.
        start = cleanStr.find('{')
        if start <= 0:
            continue
        cleanStr = cleanStr[start:]
        end = cleanStr.find('}') + 1
        if end <= 0:
            continue
        cleanStr = cleanStr[:end]
        try:
            reply = json.loads(cleanStr)
        except json.JSONDecodeError:
            # BUG FIX: malformed generations used to crash the handler;
            # retry instead.
            continue
        if reply.get('role') != 'assistant':
            continue
        return reply['content']
    return "Sorry, I don't know what to say."
# Build the Gradio chat UI around merlin_chat and start serving it.
chat_ui = gr.ChatInterface(merlin_chat)
chat_ui.launch()
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
transformers
torch
spacy
rake_nltk
wikipedia-api
gradio