tushar-r-pawar commited on
Commit
e8254f1
·
verified ·
1 Parent(s): d841c1c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -18
app.py CHANGED
@@ -1,37 +1,30 @@
1
  import torch
2
- from transformers import AutoModelForCausalLM, AutoTokenizer
3
  import streamlit as st
4
- import airllm
5
  import os
6
  from dotenv import load_dotenv
 
7
 
 
8
  load_dotenv()
9
 
10
  # Retrieve the API token from the environment variables
11
  api_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
12
 
13
- # Load GEMMA 27B model and tokenizer using the API token
14
- tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b-it", use_auth_token=api_token)
15
- model = AutoModelForCausalLM.from_pretrained(
16
- "google/gemma-2-9b-it",
17
- device_map="auto",
18
- torch_dtype=torch.bfloat16,
19
- use_auth_token=api_token
20
- )
21
-
22
- # Initialize AirLLM
23
- air_llm = airllm.AutoModel()
24
 
25
  # Streamlit app configuration
26
  st.set_page_config(
27
- page_title="Chatbot with GEMMA 27B and AirLLM",
28
  page_icon="🤖",
29
  layout="wide",
30
  initial_sidebar_state="expanded",
31
  )
32
 
33
  # App title
34
- st.title("Conversational Chatbot with GEMMA 27B and AirLLM")
35
 
36
  # Sidebar configuration
37
  st.sidebar.header("Chatbot Configuration")
@@ -73,8 +66,23 @@ elif theme == "Light":
73
  user_input = st.text_input("You: ", "")
74
  if st.button("Send"):
75
  if user_input:
76
- # Generate response using AirLLM
77
- response = air_llm.generate_response(model, tokenizer, user_input)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  st.text_area("Bot:", value=response, height=200, max_chars=None)
79
  else:
80
  st.warning("Please enter a message.")
@@ -83,6 +91,6 @@ if st.button("Send"):
83
  st.sidebar.markdown(
84
  """
85
  ### About
86
- This is a conversational chatbot built using the base version of the GEMMA 27B model and AirLLM.
87
  """
88
  )
 
# --- Imports ----------------------------------------------------------------
# Standard library
import os

# Third-party
import torch
import streamlit as st
from dotenv import load_dotenv
# NOTE(review): AutoTokenizer / AutoModelForCausalLM are imported but unused —
# AirLLM's AutoModel handles loading below. Confirm nothing else in the file
# needs them before removing.
from transformers import AutoTokenizer, AutoModelForCausalLM
from airllm import AutoModel

# Load environment variables from a local .env file, if one exists.
load_dotenv()

# Hugging Face API token pulled from the environment.
# NOTE(review): api_token is read here but never passed to
# AutoModel.from_pretrained — gated/private checkpoints would fail to
# download. Verify whether AirLLM picks the token up from the environment
# on its own.
api_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")

# Upper bound on prompt tokens kept after truncation.
MAX_LENGTH = 128

# Load the model through AirLLM, which streams layers from disk so a 70B
# checkpoint can run with limited GPU memory.
model = AutoModel.from_pretrained("garage-bAInd/Platypus2-70B-instruct")
 
 
 
 
 
 
 
 
# Single source of truth for the app's display name — used for both the
# browser tab title and the on-page heading.
APP_TITLE = "Conversational Chatbot with Platypus2-70B and AirLLM"

# Streamlit page-level configuration (must run before other st.* calls).
st.set_page_config(
    page_title=APP_TITLE,
    page_icon="🤖",
    layout="wide",
    initial_sidebar_state="expanded",
)

# Main page heading.
st.title(APP_TITLE)

# Sidebar configuration section header.
st.sidebar.header("Chatbot Configuration")
 
# --- Chat input and response generation -------------------------------------
user_input = st.text_input("You: ", "")
if st.button("Send"):
    if user_input:
        # Tokenize the user's message; truncate to MAX_LENGTH tokens so an
        # oversized prompt cannot blow up generation time/memory.
        input_tokens = model.tokenizer(
            user_input,
            return_tensors="pt",
            return_attention_mask=False,
            truncation=True,
            max_length=MAX_LENGTH,
            padding=False,
        )

        # BUG FIX: the original called .cuda() unconditionally, which raises
        # on CPU-only hosts. Move tensors to the GPU only when one exists.
        input_ids = input_tokens["input_ids"]
        if torch.cuda.is_available():
            input_ids = input_ids.cuda()

        # Generate the model's reply. return_dict_in_generate gives us a
        # structured output with a .sequences field.
        generation_output = model.generate(
            input_ids,
            max_new_tokens=20,
            use_cache=True,
            return_dict_in_generate=True,
        )

        # Decode the first (only) generated sequence. skip_special_tokens
        # keeps <s>/</s>-style markers out of the displayed reply.
        # NOTE(review): sequences[0] still includes the prompt tokens, so the
        # user's message is echoed at the start of the reply — consider
        # slicing off input_ids.shape[-1] tokens before decoding.
        response = model.tokenizer.decode(
            generation_output.sequences[0], skip_special_tokens=True
        )
        st.text_area("Bot:", value=response, height=200, max_chars=None)
    else:
        st.warning("Please enter a message.")
 
# Sidebar "About" blurb describing the app's model and serving stack.
st.sidebar.markdown(
    """
    ### About
    This is a conversational chatbot built using the Platypus2-70B model and AirLLM.
    """
)