huninEye commited on
Commit
b9eaad3
·
verified ·
1 Parent(s): 67cae53

Create Train_n_save.py

Browse files
Files changed (1) hide show
  1. Train_n_save.py +50 -0
Train_n_save.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow as tf
2
+ from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
3
+ import pandas as pd
4
+ from neural_network import create_neural_network_model # Import your custom model function
5
+
6
+ # Load training data from CSV using pandas
7
+ training_data = pd.read_csv('training_data.csv')
8
+
9
+ # Assuming your training data has columns like 'features', 'labels', and 'rewards'
10
+ train_inputs = training_data['features'] # Replace 'features' with the actual column name
11
+ train_labels = training_data['labels'] # Replace 'labels' with the actual column name
12
+ train_rewards = training_data['rewards'] # Replace 'rewards' with the actual column name
13
+
14
+ # Define your model architecture
15
+ seq_length = 128 # Example sequence length
16
+ d_model = 512 # Example dimension
17
+ action_space_size = 10 # Example action space size
18
+
19
+ model = create_neural_network_model(seq_length, d_model, action_space_size)
20
+
21
+ # Define loss functions and metrics
22
+ losses = {'Output': 'categorical_crossentropy', 'Reward': 'mean_squared_error'}
23
+ metrics = {'Output': 'accuracy'}
24
+
25
+ # Compile the model
26
+ opt = tf.keras.optimizers.Adam(learning_rate=0.001)
27
+ model.compile(optimizer=opt, loss=losses, metrics=metrics)
28
+
29
+ # Define callbacks (e.g., ModelCheckpoint, EarlyStopping) as needed
30
+ callbacks = [
31
+ ModelCheckpoint(filepath='model_weights.h5', save_best_only=True),
32
+ EarlyStopping(patience=5, restore_best_weights=True)
33
+ ]
34
+
35
+ # Train the model
36
+ history = model.fit(
37
+ x=train_inputs, # Your training data
38
+ y={'Output': train_labels, 'Reward': train_rewards}, # Your training labels and rewards
39
+ batch_size=32,
40
+ epochs=50,
41
+ callbacks=callbacks
42
+ )
43
+
44
+ # Save the trained model
45
+ model.save('custom_model.h5')
46
+
47
+ # You can also save training history for analysis and plotting
48
+ import pickle
49
+ with open('training_history.pickle', 'wb') as history_file:
50
+ pickle.dump(history.history, history_file)