sanjeevbora commited on
Commit
618013c
·
verified ·
1 Parent(s): 8142dcf

updated app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -17
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import gradio as gr
2
  from langchain.embeddings import HuggingFaceEmbeddings
3
  from langchain.vectorstores import Chroma
@@ -9,6 +10,7 @@ import torch
9
  import re
10
  import transformers
11
  import spaces
 
12
 
13
  # Initialize embeddings and ChromaDB
14
  model_name = "sentence-transformers/all-mpnet-base-v2"
@@ -24,23 +26,13 @@ books_db_client = books_db.as_retriever()
24
 
25
  # Initialize the model and tokenizer
26
  model_name = "stabilityai/stablelm-zephyr-3b"
27
-
28
- # bnb_config = transformers.BitsAndBytesConfig(
29
- # load_in_4bit=True,
30
- # bnb_4bit_quant_type='nf4',
31
- # bnb_4bit_use_double_quant=True,
32
- # bnb_4bit_compute_dtype=torch.bfloat16
33
- # )
34
-
35
  model_config = transformers.AutoConfig.from_pretrained(model_name, max_new_tokens=1024)
36
  model = transformers.AutoModelForCausalLM.from_pretrained(
37
  model_name,
38
  trust_remote_code=True,
39
  config=model_config,
40
- # quantization_config=bnb_config,
41
  device_map=device,
42
  )
43
-
44
  tokenizer = AutoTokenizer.from_pretrained(model_name)
45
 
46
  query_pipeline = transformers.pipeline(
@@ -70,8 +62,6 @@ books_db_client_retriever = RetrievalQA.from_chain_type(
70
  @spaces.GPU(duration=60)
71
  def test_rag(query):
72
  books_retriever = books_db_client_retriever.run(query)
73
-
74
- # Extract the relevant answer using regex
75
  corrected_text_match = re.search(r"Helpful Answer:(.*)", books_retriever, re.DOTALL)
76
 
77
  if corrected_text_match:
@@ -81,6 +71,14 @@ def test_rag(query):
81
 
82
  return corrected_text_books
83
 
 
 
 
 
 
 
 
 
84
  # Define the Gradio interface
85
  def chat(query, history=None):
86
  if history is None:
@@ -101,11 +99,12 @@ with gr.Blocks() as interface:
101
 
102
  input_box = gr.Textbox(label="Enter your question", placeholder="Type your question here...")
103
  submit_btn = gr.Button("Submit")
104
- # clear_btn = gr.Button("Clear")
105
  chat_history = gr.Chatbot(label="Chat History")
106
-
107
- submit_btn.click(chat, inputs=[input_box, chat_history], outputs=[chat_history, input_box])
108
- # clear_btn.click(clear_input, outputs=input_box)
109
 
110
- interface.launch()
 
 
 
 
111
 
 
 
1
+ import os
2
  import gradio as gr
3
  from langchain.embeddings import HuggingFaceEmbeddings
4
  from langchain.vectorstores import Chroma
 
10
  import re
11
  import transformers
12
  import spaces
13
+ import requests
14
 
15
  # Initialize embeddings and ChromaDB
16
  model_name = "sentence-transformers/all-mpnet-base-v2"
 
26
 
27
  # Initialize the model and tokenizer
28
  model_name = "stabilityai/stablelm-zephyr-3b"
 
 
 
 
 
 
 
 
29
  model_config = transformers.AutoConfig.from_pretrained(model_name, max_new_tokens=1024)
30
  model = transformers.AutoModelForCausalLM.from_pretrained(
31
  model_name,
32
  trust_remote_code=True,
33
  config=model_config,
 
34
  device_map=device,
35
  )
 
36
  tokenizer = AutoTokenizer.from_pretrained(model_name)
37
 
38
  query_pipeline = transformers.pipeline(
 
62
  @spaces.GPU(duration=60)
63
  def test_rag(query):
64
  books_retriever = books_db_client_retriever.run(query)
 
 
65
  corrected_text_match = re.search(r"Helpful Answer:(.*)", books_retriever, re.DOTALL)
66
 
67
  if corrected_text_match:
 
71
 
72
  return corrected_text_books
73
 
74
+ # OAuth Login Functionality
75
+ def oauth_login():
76
+ client_id = os.getenv("OAUTH_CLIENT_ID")
77
+ redirect_uri = f"https://{os.getenv('SPACE_HOST')}/login/callback"
78
+ state = "random_string" # Ideally generate a secure random string
79
+ login_url = f"https://huggingface.co/oauth/authorize?redirect_uri={redirect_uri}&scope=openid%20profile&client_id={client_id}&state={state}"
80
+ return login_url
81
+
82
  # Define the Gradio interface
83
  def chat(query, history=None):
84
  if history is None:
 
99
 
100
  input_box = gr.Textbox(label="Enter your question", placeholder="Type your question here...")
101
  submit_btn = gr.Button("Submit")
 
102
  chat_history = gr.Chatbot(label="Chat History")
 
 
 
103
 
104
+ # Sign-In Button
105
+ login_btn = gr.Button("Sign In with HF")
106
+ login_btn.click(lambda: oauth_login(), outputs=None) # Redirect user for OAuth login
107
+
108
+ submit_btn.click(chat, inputs=[input_box, chat_history], outputs=[chat_history, input_box])
109
 
110
+ interface.launch()