cheberle committed on
Commit
bf07e8f
·
1 Parent(s): 4d35d17
Files changed (2)
  1. app.py +27 -41
  2. requirements.txt +1 -0
app.py CHANGED
@@ -1,74 +1,60 @@
1
  import gradio as gr
2
- from datasets import Dataset
3
- from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer
4
  import pandas as pd
5
- from huggingface_hub import login
 
6
  import torch
7
 
 
 
 
8
  def train_model(file, hf_token):
9
  try:
10
- # Login to Hugging Face
11
- if not hf_token:
12
- return "Please provide a Hugging Face token"
13
- login(hf_token)
14
-
15
- # Load and prepare data
16
  df = pd.read_csv(file.name)
17
- dataset = Dataset.from_pandas(df)
18
 
19
- # Model setup - force CPU
20
  model_name = "facebook/opt-125m"
21
- device_map = "cpu" # Force CPU usage
22
  tokenizer = AutoTokenizer.from_pretrained(model_name)
23
  model = AutoModelForCausalLM.from_pretrained(
24
- model_name,
25
- device_map=device_map,
26
- torch_dtype=torch.float32 # Use float32 for CPU
27
  )
 
28
 
29
- # Training configuration
30
- training_args = TrainingArguments(
 
 
31
  output_dir="./results",
32
- num_train_epochs=3,
33
- per_device_train_batch_size=1, # Reduced for CPU
34
- learning_rate=3e-5,
35
- save_strategy="epoch",
36
- push_to_hub=True,
37
- hub_token=hf_token,
38
- no_cuda=True, # Force CPU usage
39
- report_to="none" # Disable wandb logging
40
  )
41
 
42
- # Initialize trainer
43
  trainer = Trainer(
44
  model=model,
45
- args=training_args,
46
  train_dataset=dataset,
47
  tokenizer=tokenizer
48
  )
49
 
50
- # Run training
51
- trainer.train()
52
-
53
- # Push to hub
54
- model.push_to_hub(f"cheberle/product-classifier-{pd.Timestamp.now().strftime('%Y%m%d')}")
55
-
56
- return "Training completed successfully!"
57
 
58
  except Exception as e:
59
- return f"Error occurred: {str(e)}"
60
 
61
- # Create Gradio interface
62
  demo = gr.Interface(
63
  fn=train_model,
64
  inputs=[
65
- gr.File(label="Upload your CSV file"),
66
- gr.Textbox(label="Hugging Face Token", type="password")
67
  ],
68
  outputs="text",
69
- title="Product Classifier Training",
70
- description="Upload your CSV data to train a product classifier model on CPU."
71
  )
72
 
73
  if __name__ == "__main__":
74
- demo.launch(share=False)
 
1
  import gradio as gr
 
 
2
  import pandas as pd
3
+ from datasets import Dataset
4
+ from transformers import AutoTokenizer, TrainingArguments, Trainer, AutoModelForCausalLM
5
  import torch
6
 
7
+ print("CUDA available:", torch.cuda.is_available())
8
+ print("Device:", torch.device('cpu'))
9
+
10
  def train_model(file, hf_token):
11
  try:
12
+ # Basic data loading test
 
 
 
 
 
13
  df = pd.read_csv(file.name)
14
+ print(f"Loaded CSV with {len(df)} rows")
15
 
16
+ # Load tokenizer and model
17
  model_name = "facebook/opt-125m"
 
18
  tokenizer = AutoTokenizer.from_pretrained(model_name)
19
  model = AutoModelForCausalLM.from_pretrained(
20
+ model_name,
21
+ device_map=None, # Force simple device mapping
22
+ low_cpu_mem_usage=True
23
  )
24
+ model = model.to('cpu') # Explicitly move to CPU
25
 
26
+ # Basic dataset creation
27
+ dataset = Dataset.from_pandas(df)
28
+
29
+ args = TrainingArguments(
30
  output_dir="./results",
31
+ per_device_train_batch_size=1,
32
+ num_train_epochs=1,
33
+ no_cuda=True,
34
+ local_rank=-1
 
 
 
 
35
  )
36
 
 
37
  trainer = Trainer(
38
  model=model,
39
+ args=args,
40
  train_dataset=dataset,
41
  tokenizer=tokenizer
42
  )
43
 
44
+ return f"Setup successful! Loaded {len(df)} rows"
 
 
 
 
 
 
45
 
46
  except Exception as e:
47
+ return f"Error: {str(e)}\nType: {type(e)}"
48
 
 
49
  demo = gr.Interface(
50
  fn=train_model,
51
  inputs=[
52
+ gr.File(label="Upload CSV file"),
53
+ gr.Textbox(label="HF Token", type="password")
54
  ],
55
  outputs="text",
56
+ title="Product Classifier Training (CPU)",
 
57
  )
58
 
59
  if __name__ == "__main__":
60
+ demo.launch(debug=True) # Enable debug mode
requirements.txt CHANGED
@@ -3,3 +3,4 @@ transformers==4.37.2
3
  torch==2.1.2
4
  datasets==2.16.1
5
  pandas==2.2.0
 
 
3
  torch==2.1.2
4
  datasets==2.16.1
5
  pandas==2.2.0
6
+ huggingface-hub==0.27.0