File size: 1,679 Bytes
b9cf68a
 
bf07e8f
 
4d35d17
b7b3996
b9cf68a
bf07e8f
b9cf68a
 
974dd09
b9cf68a
bf07e8f
b9cf68a
bf07e8f
b9cf68a
 
4d35d17
bf07e8f
d3fdb20
974dd09
4d35d17
b9cf68a
974dd09
bf07e8f
 
 
b9cf68a
bf07e8f
 
d3fdb20
b9cf68a
 
 
 
bf07e8f
b9cf68a
 
 
 
974dd09
b9cf68a
 
bf07e8f
b9cf68a
974dd09
b9cf68a
 
 
bf07e8f
 
b9cf68a
 
bf07e8f
b9cf68a
 
 
d3fdb20
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import gradio as gr
import pandas as pd
from datasets import Dataset
from transformers import AutoTokenizer, TrainingArguments, Trainer, AutoModelForCausalLM
import torch
import os


def train_model(file, hf_token):
    """Validate an uploaded CSV and set up (but do not run) a CPU training job.

    Reads the CSV into a DataFrame, loads the ``facebook/opt-125m`` tokenizer
    and causal-LM weights, wraps the rows in a ``datasets.Dataset`` and builds
    a ``Trainer``. ``trainer.train()`` is never called here, so a success
    message only means that every piece could be constructed.

    Args:
        file: Value supplied for the file input. May be ``None`` when nothing
            was uploaded, a filesystem path (``str`` / ``os.PathLike``), or an
            object whose ``.name`` attribute is the path of the uploaded file.
        hf_token: Hugging Face token entered in the UI. Currently unused by
            this function; kept so the signature matches the interface inputs.

    Returns:
        A human-readable status string: the success message with the row
        count, or a message starting with ``"Error: "``. Any ``Exception``
        raised during setup is caught and reported in the returned string
        rather than propagated.
    """
    # Nothing uploaded: say so plainly instead of surfacing an AttributeError.
    if file is None:
        return "Error: no CSV file was uploaded."

    try:
        # Accept a plain path as well as a file-like object exposing ``.name``.
        if isinstance(file, (str, os.PathLike)):
            csv_path = file
        else:
            csv_path = file.name

        # Basic data loading
        df = pd.read_csv(csv_path)
        row_count = len(df)
        print(f"Loaded CSV with {row_count} rows")

        # Fail fast on a header-only CSV, before paying for the model load.
        if row_count == 0:
            return "Error: the uploaded CSV contains no rows."

        # Load tokenizer and model
        model_name = "facebook/opt-125m"
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            low_cpu_mem_usage=True,  # Lower memory usage
            torch_dtype=torch.float32,  # Ensure compatibility with CPU
        )

        # Prepare dataset
        # NOTE(review): the rows are passed to the Trainer untokenized; a
        # tokenization step is presumably needed before training can run.
        dataset = Dataset.from_pandas(df)

        args = TrainingArguments(
            output_dir="./results",
            per_device_train_batch_size=1,
            num_train_epochs=1,
            no_cuda=True,  # Explicitly disable GPU
        )

        # Built only to prove the setup is valid; training is not started.
        trainer = Trainer(
            model=model,
            args=args,
            train_dataset=dataset,
            tokenizer=tokenizer,
        )

        return f"Setup successful! Loaded {row_count} rows for training."

    except Exception as e:
        # UI boundary: report the failure as text instead of crashing the app.
        return f"Error: {str(e)}\nType: {type(e)}"

# Web UI: a CSV upload plus a masked token field, wired to train_model;
# the status string it returns is shown as plain text.
demo = gr.Interface(
    title="Product Classifier Training (CPU)",
    fn=train_model,
    inputs=[
        gr.File(label="Upload CSV file"),
        gr.Textbox(type="password", label="HF Token"),
    ],
    outputs="text",
)

if __name__ == "__main__":
    # share=True requests a public link, which makes testing easier.
    demo.launch(share=True, debug=True)