louiismiro committed
Commit d84e01b · verified · 1 Parent(s): 325d93b

Update app.py

Files changed (1)
  app.py  +22 -17
app.py CHANGED
@@ -1,31 +1,36 @@
-# Import required libraries
-import gradio as gr
 from transformers import AutoModelForCausalLM, AutoTokenizer
+import gradio as gr
+import torch
 
-# Load the model and tokenizer
+# Define model name
 MODEL_NAME = "SeaLLMs/SeaLLM-7B-v2.5"
 
-# Download model and tokenizer from Hugging Face
+# Load the model and tokenizer with optimized settings
 tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
-model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype="auto", device_map="auto")
+model = AutoModelForCausalLM.from_pretrained(
+    MODEL_NAME,
+    torch_dtype=torch.float16,  # Use float16 for GPU optimization
+    device_map="auto"  # Automatically assign to available GPUs
+)
 
-# Define the chatbot function
-def chatbot(user_input):
-    inputs = tokenizer(user_input, return_tensors="pt").to("cuda")
-    outputs = model.generate(inputs["input_ids"], max_length=150, num_return_sequences=1, temperature=0.7)
+# Chatbot function
+def chatbot(prompt):
+    # Tokenize input
+    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
+    # Generate response
+    outputs = model.generate(inputs.input_ids, max_new_tokens=150, temperature=0.7)
+    # Decode and return response
     response = tokenizer.decode(outputs[0], skip_special_tokens=True)
     return response
 
-# Create a Gradio interface
-interface = gr.Interface(
+# Gradio Interface
+iface = gr.Interface(
     fn=chatbot,
-    inputs="text",
-    outputs="text",
+    inputs=gr.Textbox(label="Ask me anything:", lines=3, placeholder="Type your message here..."),
+    outputs=gr.Textbox(label="Response"),
     title="SeaLLM Chatbot",
-    description="A chatbot powered by SeaLLM-7B-v2.5.",
-    examples=["Hello!", "What's the weather today?", "Tell me a joke!"],
+    description="A chatbot powered by SeaLLM-7B-v2.5 for text generation.",
 )
 
-# Launch the interface
 if __name__ == "__main__":
-    interface.launch()
+    iface.launch()
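
Reviewer note on the updated chatbot(): it still hard-codes .to("cuda") even though device_map="auto" may place the weights on a different device, it passes temperature=0.7 without do_sample=True (so greedy decoding silently ignores the temperature), and it decodes the echoed prompt back into the response. Below is a minimal sketch of a more defensive variant, reusing the commit's model and tokenizer objects; this is a suggestion only, not part of the commit.

def chatbot(prompt):
    # Follow the model's actual placement instead of hard-coding "cuda";
    # with device_map="auto" the weights may not be on cuda:0.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,           # forwards attention_mask along with input_ids
        max_new_tokens=150,
        do_sample=True,     # required for temperature to take effect
        temperature=0.7,
    )
    # Decode only the newly generated tokens, not the echoed prompt.
    new_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)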