qgallouedec HF Staff committed on
Commit
04db925
·
1 Parent(s): 0acfbe1

try another home

Browse files
Files changed (2) hide show
  1. Dockerfile +4 -3
  2. sft.py +11 -2
Dockerfile CHANGED
@@ -2,7 +2,8 @@
2
  FROM pytorch/pytorch:2.7.0-cuda12.8-cudnn9-devel
3
 
4
  # Set working directory
5
- WORKDIR /usr/src/app
 
6
 
7
  # Install system dependencies, then clean up
8
  RUN apt-get update && \
@@ -21,10 +22,10 @@ RUN pip install --upgrade pip && \
21
  RUN install -m 755 sft.py /usr/local/bin/sft
22
 
23
  # Set environment variable for Hugging Face cache
24
- ENV HF_HOME=/usr/src/app/.cache/huggingface
25
 
26
  # Make /usr/src/app/ a writable directory
27
- RUN chmod -R 777 /usr/src/app
28
 
29
  # Expose the Gradio port
30
  EXPOSE 7860
 
2
  FROM pytorch/pytorch:2.7.0-cuda12.8-cudnn9-devel
3
 
4
  # Set working directory
5
+ ENV HOME=/usr/src/app
6
+ WORKDIR $HOME
7
 
8
  # Install system dependencies, then clean up
9
  RUN apt-get update && \
 
22
  RUN install -m 755 sft.py /usr/local/bin/sft
23
 
24
  # Set environment variable for Hugging Face cache
25
+ # ENV HF_HOME=/usr/src/app/.cache/huggingface
26
 
27
  # Make /usr/src/app/ a writable directory
28
+ RUN chmod -R 777 $HOME
29
 
30
  # Expose the Gradio port
31
  EXPOSE 7860
sft.py CHANGED
@@ -1,15 +1,24 @@
1
  #!/usr/bin/env python3
2
  import argparse
3
  import subprocess
 
4
 
5
  parser = argparse.ArgumentParser(description="Demo script for the model.")
6
 
7
  parser.add_argument("--model", type=str)
8
  parser.add_argument("--dataset", type=str)
9
  parser.add_argument("--flavor", type=str)
 
10
 
 
11
  args = parser.parse_args()
12
- print(f"Model: {args.model}, Dataset: {args.dataset}, Flavor: {args.flavor}")
 
 
 
 
 
 
13
 
14
  # Run another Python script
15
- subprocess.run(["trl", "sft", "--config", "configs/Qwen3-4B-Base_a10g-small.yaml"])
 
1
  #!/usr/bin/env python3
2
  import argparse
3
  import subprocess
4
+ import os
5
 
6
  parser = argparse.ArgumentParser(description="Demo script for the model.")
7
 
8
  parser.add_argument("--model", type=str)
9
  parser.add_argument("--dataset", type=str)
10
  parser.add_argument("--flavor", type=str)
11
+ parser.add_argument("--token", type=str)
12
 
13
+ # Parse the command line arguments
14
  args = parser.parse_args()
15
+
16
+ # Get the config file based on the model and flavor
17
+ config_file = f"configs/{args.model}_{args.flavor}.yaml"
18
+
19
+ # Check if the config file exists
20
+ if not os.path.exists(config_file):
21
+     raise RuntimeError(f"Training model {args.model} with flavor {args.flavor} is not supported.")
22
 
23
  # Run another Python script
24
+ subprocess.run(["trl", "sft", "--config", config_file, "--dataset_name", args.dataset])