CogwiseAI commited on
Commit
cff08e0
·
1 Parent(s): f4044e9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -21
app.py CHANGED
@@ -1,7 +1,9 @@
 
1
  import streamlit as st
2
  import uuid
3
  import sys
4
  import requests
 
5
  import bitsandbytes as bnb
6
  import pandas as pd
7
  import torch
@@ -9,6 +11,12 @@ import torch.nn as nn
9
  import transformers
10
  from datasets import load_dataset
11
  from huggingface_hub import notebook_login
 
 
 
 
 
 
12
  from transformers import (
13
  AutoConfig,
14
  AutoModelForCausalLM,
@@ -66,13 +74,11 @@ st.markdown("""
66
  </style>
67
  """, unsafe_allow_html=True)
68
 
69
- # Load the model and tokenizer from Hugging Face Hub
70
- model_name = "tiiuae/falcon-7b-instruct"
71
- model = AutoModelForCausalLM.from_pretrained(model_name)
72
- tokenizer = AutoTokenizer.from_pretrained(model_name)
73
-
74
- # Load the dataset
75
- dataset = load_dataset("nisaar/Lawyer_GPT_India")
76
 
77
  def write_top_bar():
78
  col1, col2, col3 = st.columns([1,10,2])
@@ -105,13 +111,8 @@ def handle_input():
105
  if len(chat_history) == MAX_HISTORY_LENGTH:
106
  chat_history = chat_history[:-1]
107
 
108
- # Find the most similar example in the dataset
109
- closest_example = find_closest_example(input, dataset) # Implement your own logic to find the closest example
110
-
111
- # Generate response using the model
112
- inputs = tokenizer.encode(closest_example, return_tensors="pt")
113
- outputs = model.generate(inputs)
114
- answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
115
 
116
  chat_history.append((input, answer))
117
 
@@ -121,12 +122,6 @@ def handle_input():
121
  })
122
  st.session_state.input = ""
123
 
124
- def find_closest_example(input, dataset):
125
- # Implement your own logic to find the closest example in the dataset based on the user input
126
- # You can use techniques like cosine similarity, semantic similarity, or any other approach that fits your dataset and requirements
127
- # Return the closest example as a string
128
- pass
129
-
130
  def write_user_message(md):
131
  col1, col2 = st.columns([1,12])
132
 
@@ -153,4 +148,4 @@ with st.container():
153
  write_chat_message(a, q)
154
 
155
  st.markdown('---')
156
- input = st.text_input("You are talking to an AI, ask any question.", key="input", on_change=handle_input)
 
1
+
2
  import streamlit as st
3
  import uuid
4
  import sys
5
  import requests
6
+ from peft import *
7
  import bitsandbytes as bnb
8
  import pandas as pd
9
  import torch
 
11
  import transformers
12
  from datasets import load_dataset
13
  from huggingface_hub import notebook_login
14
+ from peft import (
15
+ LoraConfig,
16
+ PeftConfig,
17
+ get_peft_model,
18
+ prepare_model_for_kbit_training,
19
+ )
20
  from transformers import (
21
  AutoConfig,
22
  AutoModelForCausalLM,
 
74
  </style>
75
  """, unsafe_allow_html=True)
76
 
77
+ # Load the model outside the handle_input() function
78
+ with open('model_saved.pkl', 'rb') as f:
79
+ model = pickle.load(f)
80
+ if not isinstance(model, str):
81
+ st.error("The loaded model is not valid.")
 
 
82
 
83
  def write_top_bar():
84
  col1, col2, col3 = st.columns([1,10,2])
 
111
  if len(chat_history) == MAX_HISTORY_LENGTH:
112
  chat_history = chat_history[:-1]
113
 
114
+ prompt = input
115
+ answer = model # Replace the predict() method with the model itself
 
 
 
 
 
116
 
117
  chat_history.append((input, answer))
118
 
 
122
  })
123
  st.session_state.input = ""
124
 
 
 
 
 
 
 
125
  def write_user_message(md):
126
  col1, col2 = st.columns([1,12])
127
 
 
148
  write_chat_message(a, q)
149
 
150
  st.markdown('---')
151
+ input = st.text_input("You are talking to an AI, ask any question.", key="input", on_change=handle_input)