CogwiseAI committed
Commit f4044e9 · 1 Parent(s): 0efe80a

Update app.py

Files changed (1)
  1. app.py +21 -16
app.py CHANGED
@@ -1,9 +1,7 @@
-
 import streamlit as st
 import uuid
 import sys
 import requests
-from peft import *
 import bitsandbytes as bnb
 import pandas as pd
 import torch
@@ -11,12 +9,6 @@ import torch.nn as nn
 import transformers
 from datasets import load_dataset
 from huggingface_hub import notebook_login
-from peft import (
-    LoraConfig,
-    PeftConfig,
-    get_peft_model,
-    prepare_model_for_kbit_training,
-)
 from transformers import (
     AutoConfig,
     AutoModelForCausalLM,
@@ -74,11 +66,13 @@ st.markdown("""
 </style>
 """, unsafe_allow_html=True)
 
-# Load the model outside the handle_input() function
-with open('model_saved.pkl', 'rb') as f:
-    model = pickle.load(f)
-if not isinstance(model, str):
-    st.error("The loaded model is not valid.")
+# Load the model and tokenizer from Hugging Face Hub
+model_name = "tiiuae/falcon-7b-instruct"
+model = AutoModelForCausalLM.from_pretrained(model_name)
+tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+# Load the dataset
+dataset = load_dataset("nisaar/Lawyer_GPT_India")
 
 def write_top_bar():
     col1, col2, col3 = st.columns([1,10,2])
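Note on this hunk: `bitsandbytes` stays imported but unused, and `from_pretrained` as written loads Falcon-7B in full precision (roughly 28 GB of weights in fp32) and re-runs on every Streamlit rerun. A minimal sketch of cached 4-bit loading follows; the `load_model` helper and all quantization settings are illustrative suggestions, not part of this commit, and it assumes `accelerate` is installed and Streamlit >= 1.18:

```python
import streamlit as st
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

@st.cache_resource  # run once per process instead of on every Streamlit rerun
def load_model(model_name: str = "tiiuae/falcon-7b-instruct"):
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,                      # quantize weights with bitsandbytes
        bnb_4bit_compute_dtype=torch.bfloat16,  # do the matmuls in bf16
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        quantization_config=bnb_config,
        device_map="auto",       # needs `accelerate`
        trust_remote_code=True,  # Falcon shipped custom modeling code
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return model, tokenizer

model, tokenizer = load_model()
```

At 4-bit the weights drop to roughly 4 GB, which is the practical difference between fitting on a typical Space GPU and not.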
@@ -111,8 +105,13 @@ def handle_input():
     if len(chat_history) == MAX_HISTORY_LENGTH:
         chat_history = chat_history[:-1]
 
-    prompt = input
-    answer = model  # Replace the predict() method with the model itself
+    # Find the most similar example in the dataset
+    closest_example = find_closest_example(input, dataset)  # Implement your own logic to find the closest example
+
+    # Generate response using the model
+    inputs = tokenizer.encode(closest_example, return_tensors="pt")
+    outputs = model.generate(inputs)
+    answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
 
     chat_history.append((input, answer))
 
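The new generation call conditions only on the retrieved example and relies on `generate()` defaults (greedy decoding, no `max_new_tokens`, no attention mask), so the reply ignores the user's actual question and echoes the prompt. A hedged sketch of a fuller call; the prompt template and every parameter value here are assumptions, not part of the commit:

```python
# Build the prompt from the user's question plus the retrieved example,
# rather than generating from the example alone. Template and generation
# parameters are illustrative choices.
prompt = f"Context:\n{closest_example}\n\nQuestion: {input}\nAnswer:"
encoded = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(
    **encoded,                            # passes input_ids and attention_mask
    max_new_tokens=256,                   # bound the reply length
    do_sample=True,
    temperature=0.7,
    pad_token_id=tokenizer.eos_token_id,  # avoids the missing-pad warning
)
# Decode only the newly generated tokens, not the echoed prompt
answer = tokenizer.decode(
    outputs[0][encoded["input_ids"].shape[1]:],
    skip_special_tokens=True,
)
```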
@@ -122,6 +121,12 @@ def handle_input():
     })
     st.session_state.input = ""
 
+def find_closest_example(input, dataset):
+    # Implement your own logic to find the closest example in the dataset based on the user input
+    # You can use techniques like cosine similarity, semantic similarity, or any other approach that fits your dataset and requirements
+    # Return the closest example as a string
+    pass
+
 def write_user_message(md):
     col1, col2 = st.columns([1,12])
 
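`find_closest_example` ships as a stub. One minimal way to fill it in, as the stub's comments suggest, is TF-IDF cosine similarity from scikit-learn; the `question` and `answer` column names below are guesses at the nisaar/Lawyer_GPT_India schema, so check `dataset["train"].column_names` first:

```python
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

def find_closest_example(user_input, dataset):
    train = dataset["train"]
    questions = train["question"]  # assumed column name; verify against the schema
    # NB: fitting on every call is wasteful; in a real app, fit once at startup
    # (e.g. behind st.cache_resource) and reuse the fitted matrix.
    vectorizer = TfidfVectorizer()
    question_vecs = vectorizer.fit_transform(questions)
    input_vec = vectorizer.transform([user_input])
    best = int(cosine_similarity(input_vec, question_vecs).argmax())
    # Return the matched Q/A pair as one context string (the format is a choice)
    return f"Q: {questions[best]}\nA: {train[best]['answer']}"  # assumed column
```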
@@ -148,4 +153,4 @@ with st.container():
         write_chat_message(a, q)
 
     st.markdown('---')
-input = st.text_input("You are talking to an AI, ask any question.", key="input", on_change=handle_input)
+input = st.text_input("You are talking to an AI, ask any question.", key="input", on_change=handle_input)