mrcuddle committed
Commit 39d1b86 · verified · 1 Parent(s): c4df5b8

Update app.py

Files changed (1)
  1. app.py +8 -3
app.py CHANGED
@@ -1,13 +1,18 @@
  import gradio as gr
  from vllm import LLM
  from vllm.sampling_params import SamplingParams
+ import torch
  import spaces
+
  # Define the model and sampling parameters
  model_name = "mistralai/Ministral-8B-Instruct-2410"
  sampling_params = SamplingParams(max_tokens=8192)

- # Initialize the LLM model
- llm = LLM(model=model_name, tokenizer_mode="mistral", config_format="mistral", load_format="mistral")
+ # Check if GPU is available
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ # Initialize the LLM model with the specified device
+ llm = LLM(model=model_name, tokenizer_mode="mistral", config_format="mistral", load_format="mistral", device=device)

  @spaces.GPU
  # Define the chatbot function
@@ -46,4 +51,4 @@ with gr.Blocks() as demo:
  txt.submit(chatbot, [txt, chatbot], [chatbot, txt])

  # Launch the Gradio interface
- demo.launch()
+ demo.launch()
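
The body of the chatbot function falls outside both hunks, so only its @spaces.GPU decorator is visible above. For orientation, here is a minimal sketch of what such a function might look like, assuming vLLM's LLM.chat API and tuple-style gr.Chatbot history; both are assumptions, as the actual implementation is not shown in this commit. It reuses the llm and sampling_params names defined in the diff above.

@spaces.GPU
def chatbot(message, history):
    # Rebuild the running conversation as role/content messages
    # (assumed shape; the real function's body is not in this diff)
    messages = []
    for user_msg, bot_msg in history:
        messages.append({"role": "user", "content": user_msg})
        messages.append({"role": "assistant", "content": bot_msg})
    messages.append({"role": "user", "content": message})

    # Generate a completion with the sampling parameters defined above
    outputs = llm.chat(messages, sampling_params)
    reply = outputs[0].outputs[0].text

    # Append the new exchange and clear the textbox
    history.append((message, reply))
    return history, ""

This shape matches the txt.submit(chatbot, [txt, chatbot], [chatbot, txt]) wiring in the second hunk: the textbox value and chat history go in, and the updated history plus an empty string (to clear the textbox) come out.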