cheberle commited on
Commit
b9cf68a
·
1 Parent(s): 02cd362
Files changed (2) hide show
  1. app.py +64 -0
  2. requirements.txt +5 -0
app.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from datasets import Dataset
4
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer
5
+ import pandas as pd
6
+ from huggingface_hub import login
7
+
8
+ def train_model(file, hf_token):
9
+ try:
10
+ # Login to Hugging Face
11
+ if not hf_token:
12
+ return "Please provide a Hugging Face token"
13
+ login(hf_token)
14
+
15
+ # Load and prepare data
16
+ df = pd.read_csv(file.name)
17
+ dataset = Dataset.from_pandas(df)
18
+
19
+ # Model setup
20
+ model_name = "facebook/opt-125m"
21
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
22
+ model = AutoModelForCausalLM.from_pretrained(model_name)
23
+
24
+ # Training configuration
25
+ training_args = TrainingArguments(
26
+ output_dir="./results",
27
+ num_train_epochs=3,
28
+ per_device_train_batch_size=2,
29
+ learning_rate=3e-5,
30
+ save_strategy="epoch",
31
+ push_to_hub=True,
32
+ hub_token=hf_token
33
+ )
34
+
35
+ # Initialize trainer
36
+ trainer = Trainer(
37
+ model=model,
38
+ args=training_args,
39
+ train_dataset=dataset,
40
+ tokenizer=tokenizer
41
+ )
42
+
43
+ # Run training
44
+ trainer.train()
45
+
46
+ return "Training completed successfully!"
47
+
48
+ except Exception as e:
49
+ return f"Error occurred: {str(e)}"
50
+
51
+ # Create Gradio interface
52
+ demo = gr.Interface(
53
+ fn=train_model,
54
+ inputs=[
55
+ gr.File(label="Upload your CSV file"),
56
+ gr.Textbox(label="Hugging Face Token", type="password")
57
+ ],
58
+ outputs="text",
59
+ title="Product Classifier Training",
60
+ description="Upload your CSV data to train a product classifier model."
61
+ )
62
+
63
+ if __name__ == "__main__":
64
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ gradio==4.19.2
2
+ transformers==4.37.2
3
+ torch==2.1.2
4
+ datasets==2.16.1
5
+ pandas==2.2.0