wuhp committed
Commit eabbd4b · verified · 1 Parent(s): 5a9af80

Update app.py

Files changed (1)
  1. app.py +31 -4
app.py CHANGED
@@ -2,34 +2,53 @@ import gradio as gr
 import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM

+# ----------------------------------------------------------------
+# 1) Points to your Hugging Face repo and subfolder
+#    (where config.json, tokenizer.json, model safetensors, etc. reside).
+# ----------------------------------------------------------------
 MODEL_REPO = "wuhp/myr1"
 SUBFOLDER = "myr1"

+# ----------------------------------------------------------------
+# 2) Load the tokenizer
+#    trust_remote_code=True allows custom code (e.g., DeepSeek config/classes).
+# ----------------------------------------------------------------
 tokenizer = AutoTokenizer.from_pretrained(
     MODEL_REPO,
     subfolder=SUBFOLDER,
     trust_remote_code=True
 )

-# If your GPU has <24GB VRAM, consider 8-bit or CPU offloading
+# ----------------------------------------------------------------
+# 3) Load the model
+#    - device_map="auto" tries to place layers on GPU and offload remainder to CPU if needed
+#    - torch_dtype can be float16, float32, bfloat16, etc., depending on GPU support
+# ----------------------------------------------------------------
 model = AutoModelForCausalLM.from_pretrained(
     MODEL_REPO,
     subfolder=SUBFOLDER,
     trust_remote_code=True,
-    device_map="auto",          # tries to place layers on GPU, then CPU if needed
-    torch_dtype=torch.float16,  # or bfloat16 or float32
+    device_map="auto",
+    torch_dtype=torch.float16,
     low_cpu_mem_usage=True
 )

+# Put model in evaluation mode
 model.eval()

+# ----------------------------------------------------------------
+# 4) Define the generation function
+# ----------------------------------------------------------------
 def generate_text(prompt, max_length=64, temperature=0.7, top_p=0.9):
     print("=== Starting generation ===")
+    # Move input tokens to the same device as the model
     inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+
     try:
+        # Generate tokens
         output_ids = model.generate(
             **inputs,
-            max_new_tokens=max_length,  # alternative to max_length
+            max_new_tokens=max_length,  # This controls how many tokens beyond the prompt are generated
             temperature=temperature,
             top_p=top_p,
             do_sample=True,
@@ -39,8 +58,13 @@ def generate_text(prompt, max_length=64, temperature=0.7, top_p=0.9):
     except Exception as e:
         print(f"Error during generation: {e}")
         return str(e)
+
+    # Decode back to text (skipping special tokens)
     return tokenizer.decode(output_ids[0], skip_special_tokens=True)

+# ----------------------------------------------------------------
+# 5) Build a Gradio UI
+# ----------------------------------------------------------------
 demo = gr.Interface(
     fn=generate_text,
     inputs=[
@@ -58,5 +82,8 @@ demo = gr.Interface(
     description="Generates text using the large DeepSeek model."
 )

+# ----------------------------------------------------------------
+# 6) Run the Gradio app
+# ----------------------------------------------------------------
 if __name__ == "__main__":
     demo.launch()
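
The comment removed in this commit ("If your GPU has <24GB VRAM, consider 8-bit or CPU offloading") points at an alternative load path. Below is a minimal sketch of that idea; it assumes the optional bitsandbytes package is installed, and the quantization settings are illustrative rather than part of this commit:

from transformers import AutoModelForCausalLM, BitsAndBytesConfig

MODEL_REPO = "wuhp/myr1"
SUBFOLDER = "myr1"

# Illustrative: load weights in 8-bit to roughly halve VRAM use vs. float16
bnb_config = BitsAndBytesConfig(load_in_8bit=True)

model = AutoModelForCausalLM.from_pretrained(
    MODEL_REPO,
    subfolder=SUBFOLDER,
    trust_remote_code=True,
    device_map="auto",               # place what fits on the GPU, offload the rest to CPU
    quantization_config=bnb_config,  # requires bitsandbytes; assumption, not part of this commit
    low_cpu_mem_usage=True
)

With device_map="auto" alone, as committed here, any layers that do not fit in VRAM are offloaded to CPU RAM by accelerate; generation still works but is noticeably slower than an all-GPU or 8-bit setup.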