calmgoose commited on
Commit
8f9d28b
Β·
1 Parent(s): 728ab84

add option to use `togethercomputer/GPT-NeoXT-Chat-Base-20B`

Browse files
Files changed (1) hide show
  1. app.py +37 -13
app.py CHANGED
@@ -1,4 +1,5 @@
1
  import os
 
2
  import streamlit as st
3
 
4
  from langchain.embeddings import HuggingFaceInstructEmbeddings
@@ -7,6 +8,7 @@ from langchain.chains import VectorDBQA
7
  from huggingface_hub import snapshot_download
8
  from langchain import OpenAI
9
  from langchain import PromptTemplate
 
10
 
11
 
12
  BOOK_NAME = "1984"
@@ -74,9 +76,20 @@ def load_prompt(book_name, author_name):
74
 
75
 
76
  @st.experimental_singleton(show_spinner=False)
77
- def load_chain():
78
- llm = OpenAI(temperature=0.2)
79
 
 
 
 
 
 
 
 
 
 
 
 
 
80
  chain = VectorDBQA.from_chain_type(
81
  chain_type_kwargs = {"prompt": load_prompt(book_name=BOOK_NAME, author_name=AUTHOR_NAME)},
82
  llm=llm,
@@ -89,8 +102,8 @@ def load_chain():
89
  return chain
90
 
91
 
92
- def get_answer(question):
93
- chain = load_chain()
94
  result = chain({"query": question})
95
 
96
  answer = result["result"]
@@ -124,11 +137,18 @@ def get_answer(question):
124
 
125
  ##### sidebar ####
126
  with st.sidebar:
127
- api_key = st.text_input(label = "Paste your OpenAI API key here to get started",
128
- type = "password",
129
- help = "This isn't saved πŸ™ˆ"
130
- )
131
- os.environ["OPENAI_API_KEY"] = api_key
 
 
 
 
 
 
 
132
 
133
  st.markdown("---")
134
 
@@ -151,16 +171,20 @@ ask = col2.button("Ask")
151
 
152
  if ask:
153
 
154
- if api_key is "":
155
  st.write(f"**{BOOK_NAME}:** Whoops looks like you forgot your API key buddy")
156
  st.stop()
157
  else:
158
  with st.spinner("Um... excuse me but... this can take about a minute for your first question because some stuff have to be downloaded πŸ₯ΊπŸ‘‰πŸ»πŸ‘ˆπŸ»"):
159
  try:
160
- answer, pages, extract = get_answer(question=user_input)
161
  except:
162
- st.write(f"**{BOOK_NAME}:** What\'s going on? That's not the right API key")
163
- st.stop()
 
 
 
 
164
 
165
  st.write(f"**{BOOK_NAME}:** {answer}")
166
 
 
1
  import os
2
+ from typing import Literal
3
  import streamlit as st
4
 
5
  from langchain.embeddings import HuggingFaceInstructEmbeddings
 
8
  from huggingface_hub import snapshot_download
9
  from langchain import OpenAI
10
  from langchain import PromptTemplate
11
+ from langchain.llms import HuggingFacePipeline
12
 
13
 
14
  BOOK_NAME = "1984"
 
76
 
77
 
78
  @st.experimental_singleton(show_spinner=False)
79
+ def load_chain(model: Literal["openai", "GPT-NeoXT-Chat-Base-20B"] ="openai"):
 
80
 
81
+ # choose model
82
+ if model=="openai":
83
+ llm = OpenAI(temperature=0.2)
84
+
85
+ if model=="GPT-NeoXT-Chat-Base-20B":
86
+ llm = HuggingFacePipeline.from_model_id(
87
+ model_id="togethercomputer/GPT-NeoXT-Chat-Base-20B",
88
+ task="text-generation",
89
+ model_kwargs={"temperature":0.2, "max_length":400}
90
+ )
91
+
92
+ # load chain
93
  chain = VectorDBQA.from_chain_type(
94
  chain_type_kwargs = {"prompt": load_prompt(book_name=BOOK_NAME, author_name=AUTHOR_NAME)},
95
  llm=llm,
 
102
  return chain
103
 
104
 
105
+ def get_answer(question, model="openai"):
106
+ chain = load_chain(model=model)
107
  result = chain({"query": question})
108
 
109
  answer = result["result"]
 
137
 
138
  ##### sidebar ####
139
  with st.sidebar:
140
+
141
+ choice= st.radio("Choose your API:",
142
+ ["OpenAI", "togethercomputer/GPT-NeoXT-Chat-Base-20B"],
143
+ help="GPT-NeoXT-Chat-Base-20B doesn't need an API Key"
144
+ )
145
+
146
+ if choice == "OpenAI":
147
+ api_key = st.text_input(label = "Paste your OpenAI API key here to get started",
148
+ type = "password",
149
+ help = "This isn't saved πŸ™ˆ"
150
+ )
151
+ os.environ["OPENAI_API_KEY"] = api_key
152
 
153
  st.markdown("---")
154
 
 
171
 
172
  if ask:
173
 
174
+ if choice=="OpenAI" and api_key is "":
175
  st.write(f"**{BOOK_NAME}:** Whoops looks like you forgot your API key buddy")
176
  st.stop()
177
  else:
178
  with st.spinner("Um... excuse me but... this can take about a minute for your first question because some stuff have to be downloaded πŸ₯ΊπŸ‘‰πŸ»πŸ‘ˆπŸ»"):
179
  try:
180
+ answer, pages, extract = get_answer(question=user_input, model=choice)
181
  except:
182
+ if choice=="togethercomputer/GPT-NeoXT-Chat-Base-20B":
183
+ st.write("The model probably timed out :(")
184
+ st.stop()
185
+ else:
186
+ st.write(f"**{BOOK_NAME}:** What\'s going on? That's not the right API key")
187
+ st.stop()
188
 
189
  st.write(f"**{BOOK_NAME}:** {answer}")
190