tushar-r-pawar committed on
Commit
54ac477
·
verified ·
1 Parent(s): 2971ea5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -8
app.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from transformers import AutoTokenizer, AutoModelForCausalLM
3
  import streamlit as st
4
  import os
5
  from dotenv import load_dotenv
6
- from airllm import AutoModel
 
7
 
8
  # Load environment variables
9
  load_dotenv()
@@ -11,13 +11,15 @@ load_dotenv()
11
  # Retrieve the API token from the environment variables
12
  api_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
13
 
14
- # Initialize model and tokenizer using the AutoModel from AirLLM
15
  MAX_LENGTH = 128
16
- model = AutoModel.from_pretrained("internlm/internlm2_5-7b")
 
 
17
 
18
  # Streamlit app configuration
19
  st.set_page_config(
20
- page_title="Conversational Chatbot with internlm2_5-7b-chat and AirLLM",
21
  page_icon="🤖",
22
  layout="wide",
23
  initial_sidebar_state="expanded",
@@ -67,22 +69,27 @@ user_input = st.text_input("You: ", "")
67
  if st.button("Send"):
68
  if user_input:
69
  # Tokenize user input
70
- input_tokens = model.tokenizer(user_input,
71
  return_tensors="pt",
72
  return_attention_mask=False,
73
  truncation=True,
74
  max_length=MAX_LENGTH,
75
  padding=False)
76
 
 
 
 
 
 
77
  # Generate response
78
  generation_output = model.generate(
79
- input_tokens['input_ids'].cuda(),
80
  max_new_tokens=20,
81
  use_cache=True,
82
  return_dict_in_generate=True)
83
 
84
  # Decode response
85
- response = model.tokenizer.decode(generation_output.sequences[0])
86
  st.text_area("Bot:", value=response, height=200, max_chars=None)
87
  else:
88
  st.warning("Please enter a message.")
 
1
  import torch
 
2
  import streamlit as st
3
  import os
4
  from dotenv import load_dotenv
5
+ from airllm import AirLLMInternLM
6
+ from transformers import AutoTokenizer, GenerationConfig
7
 
8
  # Load environment variables
9
  load_dotenv()
 
11
  # Retrieve the API token from the environment variables
12
  api_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
13
 
14
+ # Initialize model and tokenizer
15
  MAX_LENGTH = 128
16
+ model_name = "internlm/internlm2_5-7b"
17
+ model = AirLLMInternLM.from_pretrained(model_name)
18
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
19
 
20
  # Streamlit app configuration
21
  st.set_page_config(
22
+ page_title="Conversational Chatbot with internlm2_5-7b-chat",
23
  page_icon="🤖",
24
  layout="wide",
25
  initial_sidebar_state="expanded",
 
69
  if st.button("Send"):
70
  if user_input:
71
  # Tokenize user input
72
+ input_tokens = tokenizer(user_input,
73
  return_tensors="pt",
74
  return_attention_mask=False,
75
  truncation=True,
76
  max_length=MAX_LENGTH,
77
  padding=False)
78
 
79
+ # Check if CUDA is available and use it if possible
80
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
81
+ model.to(device)
82
+ input_tokens = input_tokens.to(device)
83
+
84
  # Generate response
85
  generation_output = model.generate(
86
+ input_ids=input_tokens['input_ids'],
87
  max_new_tokens=20,
88
  use_cache=True,
89
  return_dict_in_generate=True)
90
 
91
  # Decode response
92
+ response = tokenizer.decode(generation_output.sequences[0], skip_special_tokens=True)
93
  st.text_area("Bot:", value=response, height=200, max_chars=None)
94
  else:
95
  st.warning("Please enter a message.")