Akjava committed
Commit 0ee8fa9 · verified · 1 Parent(s): 81cfac3

Update app.py

Files changed (1)
  1. app.py +35 -32
app.py CHANGED
@@ -6,38 +6,40 @@ import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
 import gradio as gr
 
-
-huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
-if not huggingface_token:
-    pass
-    print("no HUGGINGFACE_TOKEN if you need set secret ")
-    #raise ValueError("HUGGINGFACE_TOKEN environment variable is not set")
-
-model_id = "google/gemma-2-9b-it"
-
-device = "auto" # torch.device("cuda" if torch.cuda.is_available() else "cpu")
-dtype = torch.bfloat16
-
-tokenizer = AutoTokenizer.from_pretrained(model_id, token=huggingface_token)
-
-print(model_id,device,dtype)
-histories = []
-#model = None
-
-model = AutoModelForCausalLM.from_pretrained(
-    model_id, token=huggingface_token ,torch_dtype=dtype,device_map=device
-)
-text_generator = pipeline("text-generation", model=model, tokenizer=tokenizer,torch_dtype=dtype,device_map=device) #pipeline has not to(device)
-
-if next(model.parameters()).is_cuda:
-    print("The model is on a GPU")
-else:
-    print("The model is on a CPU")
-
-if text_generator.device == 'cuda':
-    print("The pipeline is using a GPU")
-else:
-    print("The pipeline is using a CPU")
+def init():
+    huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
+    if not huggingface_token:
+        pass
+        print("no HUGGINGFACE_TOKEN if you need set secret ")
+        #raise ValueError("HUGGINGFACE_TOKEN environment variable is not set")
+
+    model_id = "google/gemma-2-9b-it"
+
+    device = "auto" # torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    dtype = torch.bfloat16
+
+    tokenizer = AutoTokenizer.from_pretrained(model_id, token=huggingface_token)
+
+    print(model_id,device,dtype)
+    histories = []
+    #model = None
+
+    model = AutoModelForCausalLM.from_pretrained(
+        model_id, token=huggingface_token ,torch_dtype=dtype,device_map=device
+    )
+    text_generator = pipeline("text-generation", model=model, tokenizer=tokenizer,torch_dtype=dtype,device_map=device) #pipeline has not to(device)
+
+    if next(model.parameters()).is_cuda:
+        print("The model is on a GPU")
+    else:
+        print("The model is on a CPU")
+
+    if text_generator.device == 'cuda':
+        print("The pipeline is using a GPU")
+    else:
+        print("The pipeline is using a CPU")
+
+    print("initialized")
 
 @spaces.GPU(duration=120)
 def generate_text(messages):
@@ -78,4 +80,5 @@ def call_generate_text(message, history):
 demo = gr.ChatInterface(call_generate_text,type="messages")
 
 if __name__ == "__main__":
+    init()
     demo.launch(share=True)
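Note on scoping: as committed, init() assigns tokenizer, model, and text_generator as function locals, so the objects it builds are discarded when init() returns; generate_text (whose body is elided in this diff) would not see them unless they are published at module scope. A minimal sketch of that pattern, assuming module-level globals are intended (the global statement below is an assumption, not part of this commit):

import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

# Module-level slots so code outside init() can reach what it builds.
tokenizer = None
text_generator = None

def init():
    # Without this declaration the assignments below would bind
    # function-local names and the pipeline would be lost on return.
    global tokenizer, text_generator
    token = os.getenv("HUGGINGFACE_TOKEN")
    model_id = "google/gemma-2-9b-it"
    tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
    model = AutoModelForCausalLM.from_pretrained(
        model_id, token=token, torch_dtype=torch.bfloat16, device_map="auto"
    )
    text_generator = pipeline(
        "text-generation", model=model, tokenizer=tokenizer,
        torch_dtype=torch.bfloat16, device_map="auto"
    )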
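Note on the device check: text_generator.device is a torch.device, and in current PyTorch comparing a torch.device to the string 'cuda' evaluates to False, so the CPU branch prints even when the pipeline is on GPU. A sketch of the usual comparison, assuming a standard transformers pipeline object:

# pipeline.device is e.g. device(type='cuda', index=0); compare its .type
# attribute rather than the device object itself.
if text_generator.device.type == "cuda":
    print("The pipeline is using a GPU")
else:
    print("The pipeline is using a CPU")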