spdin committed
Commit 333cd19 · 1 Parent(s): 0da7162

initial commit

Files changed (6)
  1. app.py +33 -0
  2. model.py +47 -0
  3. prediction.py +48 -0
  4. training.py +74 -0
  5. utils.py +12 -0
  6. validation.py +65 -0
app.py ADDED
@@ -0,0 +1,33 @@
+ import uuid
+ import streamlit as st
+
+ import training
+ import validation
+ import prediction
+
+
+ # Session initialization: a short per-visitor key that namespaces saved models
+ if "key" not in st.session_state:
+     st.session_state["key"] = str(uuid.uuid4()).split("-")[-1]
+
+
+ def training_page():
+     training.main()
+
+
+ def validation_page():
+     validation.main()
+
+
+ def prediction_page():
+     prediction.main()
+
+
+ page_names_to_funcs = {
+     "Training": training_page,
+     "Validation": validation_page,
+     "Prediction": prediction_page,
+ }
+
+ selected_page = st.sidebar.selectbox("Select a page", page_names_to_funcs.keys())
+ page_names_to_funcs[selected_page]()
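Each page module exposes a main() entry point, so the sidebar can be extended with one more module and one more dictionary entry. A minimal sketch, assuming a hypothetical about.py that follows the same convention:

import about  # hypothetical module exposing main(); not part of this commit

def about_page():
    about.main()

page_names_to_funcs["About"] = about_page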
model.py ADDED
@@ -0,0 +1,47 @@
+ from setfit import SetFitModel, SetFitTrainer
+ from sentence_transformers.losses import CosineSimilarityLoss
+
+
+ # Function to create a pipeline for text classification using the trained model
+ def create_classifier(model_path):
+     classifier = SetFitModel.from_pretrained(
+         model_path,
+         local_files_only=True,
+     )
+     return classifier
+
+
+ def run_setfit_training(
+     session_id, model_id, model_name, train_dataset, batch_size, num_iterations
+ ):
+
+     model = SetFitModel.from_pretrained(model_id)
+
+     # Create trainer (note: the train set doubles as the eval set)
+     trainer = SetFitTrainer(
+         model=model,
+         train_dataset=train_dataset,
+         eval_dataset=train_dataset,
+         loss_class=CosineSimilarityLoss,
+         metric="accuracy",
+         batch_size=batch_size,
+         num_iterations=num_iterations,  # The number of text pairs to generate for contrastive learning
+         num_epochs=1,  # The number of epochs to use for contrastive learning
+         column_mapping={"text": "text", "label": "label"},
+     )
+
+     trainer.train()
+     # metrics = trainer.evaluate()
+     # accuracy = metrics["accuracy"]
+
+     print(f"model used: {model_id}")
+     print(f"train dataset: {len(train_dataset)} samples")
+     # print(f"accuracy: {accuracy}")
+
+     save_model_path = f"./models/{session_id}/{model_id}_{model_name}"
+
+     trainer.model.save_pretrained(
+         save_directory=save_model_path
+     )
+
+     return save_model_path
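run_setfit_training can also be exercised outside Streamlit. A minimal sketch, assuming a toy in-memory dataset; the session id, texts, labels, and model name are made up, and the arguments mirror what training.py passes (batch_size=1, num_iterations=10):

from datasets import Dataset

import model

train_dataset = Dataset.from_dict(
    {
        "text": ["great product", "terrible service", "works as expected", "broke on day one"],
        "label": [1, 0, 1, 0],
    }
)

save_path = model.run_setfit_training(
    session_id="demo1234",       # made-up session id
    model_id="all-MiniLM-L6-v2", # one of the options offered in training.py
    model_name="sentiment",
    train_dataset=train_dataset,
    batch_size=1,
    num_iterations=10,
)
print(save_path)  # ./models/demo1234/all-MiniLM-L6-v2_sentiment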
prediction.py ADDED
@@ -0,0 +1,48 @@
+ import os
+ import pickle
+
+ import streamlit as st
+ import model
+
+
+ def main():
+     st.title("Model Prediction")
+
+     st.write(f"Session ID: {st.session_state.key}")
+     session_id = st.session_state.key
+
+     if not os.path.isdir(f"models/{session_id}"):
+         st.write("Model is not available")
+         st.stop()
+
+     model_options = os.listdir(f"models/{session_id}")
+
+     models = {
+         model_name: os.path.abspath(os.path.join(f"models/{session_id}", model_name))
+         for model_name in model_options
+     }
+
+     model_name = st.selectbox("Select a model", options=model_options)
+
+     # Text input
+     text = st.text_area("Enter some text here", height=200)
+
+     # Prediction button
+     if st.button("Predict"):
+         # Load the id -> label mapping saved at training time
+         with open(f"{models[model_name]}/label.pkl", "rb") as f:
+             label_map = pickle.load(f)
+
+         classifier = model.create_classifier(models[model_name])
+
+         prediction = classifier([text])
+
+         prediction_class = prediction[0].item()
+
+         confidence_score = classifier.predict_proba([text])[0][prediction_class].item()
+
+         st.write(
+             "The predicted label is:",
+             label_map[prediction_class],
+             f"{round(confidence_score * 100, 2)}%",
+         )
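The same path works without the UI. A minimal sketch, assuming a checkpoint saved by the training page; the checkpoint path and label names are illustrative:

import pickle

import model

model_path = "./models/demo1234/all-MiniLM-L6-v2_sentiment"  # made-up path

# id -> label mapping written by training.py, e.g. {0: "negative", 1: "positive"}
with open(f"{model_path}/label.pkl", "rb") as f:
    label_map = pickle.load(f)

classifier = model.create_classifier(model_path)

texts = ["fast shipping, great quality"]
pred_class = classifier(texts)[0].item()    # integer class id
proba = classifier.predict_proba(texts)[0]  # per-class probabilities

print(label_map[pred_class], f"{proba[pred_class].item() * 100:.2f}%")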
training.py ADDED
@@ -0,0 +1,74 @@
+ import pickle
+ import pandas as pd
+ import streamlit as st
+
+ from datasets import Dataset
+
+ import model
+ from utils import check_columns, count_labels
+
+ # Main function to run the Streamlit app
+ def main():
+     # Set app title
+     st.title("Few-Shot Learning Demo using SetFit")
+
+     # Display the session ID
+     st.write(f"Session ID: {st.session_state.key}")
+     session_id = st.session_state.key
+
+     # Create file uploader
+     uploaded_file = st.file_uploader("Choose a CSV file to upload", type="csv")
+
+     # Check if file was uploaded
+     if uploaded_file is not None:
+         # Read CSV file into pandas DataFrame
+         df = pd.read_csv(uploaded_file)
+
+         # Check if DataFrame has expected columns
+         if check_columns(df):
+             # Display DataFrame as a table
+             st.write(df)
+
+             # Calculate the number of instances of each label class
+             label_counts = count_labels(df)
+             st.write(f"Number of instances of each label class: {label_counts}")
+
+             labels = set(df["label"].tolist())
+             label_map = {label: idx for idx, label in enumerate(labels)}  # label -> integer id
+
+             df["label"] = df["label"].map(label_map)
+
+             dataset = Dataset.from_pandas(df)
+
+             model_name = st.text_input("Input the model name")
+
+             pretrained_model_options = ["all-MiniLM-L6-v2", "paraphrase-MiniLM-L3-v2"]
+
+             pretrained_model = st.selectbox(
+                 "Select a pretrained model", options=pretrained_model_options
+             )
+
+             # Add Train button
+             if st.button("Train"):
+                 # Train the model
+                 with st.spinner("Training model..."):
+                     model_path = model.run_setfit_training(
+                         session_id,
+                         pretrained_model,
+                         model_name,
+                         dataset,
+                         1,   # batch_size
+                         10,  # num_iterations
+                     )
+
+                 st.write(f"Model checkpoint saved: {model_path.split('/')[-1]}")
+
+                 label_map = {v: k for k, v in label_map.items()}  # invert to id -> label
+                 with open(f"{model_path}/label.pkl", "wb") as f:
+                     pickle.dump(label_map, f)
+
+                 st.write("Training Finished")
+                 st.write("Go to Validation Page")
+
+         else:
+             st.error("File must have 'text' and 'label' columns.")
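The uploader accepts only a CSV whose columns are exactly text and label; check_columns rejects anything else. A minimal sketch of producing such a file, with made-up rows and file name:

import pandas as pd

pd.DataFrame(
    {
        "text": ["loved it", "waste of money", "would buy again", "arrived broken"],
        "label": ["positive", "negative", "positive", "negative"],
    }
).to_csv("train.csv", index=False)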
utils.py ADDED
@@ -0,0 +1,12 @@
+ # Function to check if the uploaded file has the expected columns
+ def check_columns(df):
+     if set(df.columns) == {"text", "label"}:
+         return True
+     else:
+         return False
+
+
+ # Function to calculate the number of instances of each label class
+ def count_labels(df):
+     counts = df["label"].value_counts()
+     return counts.to_dict()
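A quick check of both helpers on a toy frame; the values are illustrative:

import pandas as pd

from utils import check_columns, count_labels

df = pd.DataFrame({"text": ["a", "b", "c"], "label": ["x", "x", "y"]})
print(check_columns(df))  # True
print(count_labels(df))   # {'x': 2, 'y': 1}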
validation.py ADDED
@@ -0,0 +1,65 @@
+ import os
+ import pickle
+ import pandas as pd
+ import streamlit as st
+
+ import model
+
+ from utils import check_columns
+
+
+ # Function to validate the trained model with a newly uploaded CSV file
+ def main():
+
+     st.title("Model Validation")
+
+     # Display the session ID
+     st.write(f"Session ID: {st.session_state.key}")
+     session_id = st.session_state.key
+
+     if not os.path.isdir(f"models/{session_id}"):
+         st.write("Model is not available")
+         st.stop()
+
+     model_options = os.listdir(f"models/{session_id}")
+
+     models = {
+         model_name: os.path.abspath(os.path.join(f"models/{session_id}", model_name))
+         for model_name in model_options
+     }
+
+     model_name = st.selectbox("Select a model", options=model_options)
+
+     # Create file uploader for the validation CSV file
+     validation_file = st.file_uploader(
+         "Choose a CSV file to validate the model", type="csv"
+     )
+
+     # Check if a validation file was uploaded
+     if validation_file is not None:
+         # Read CSV file into pandas DataFrame
+         validation_df = pd.read_csv(validation_file)
+
+         # Check if DataFrame has expected columns
+         if check_columns(validation_df):
+             # Display DataFrame as a table
+             st.write(validation_df)
+
+             # Create pipeline for text classification using the trained model
+             classifier = model.create_classifier(models[model_name])
+
+             # Load the id -> label mapping saved at training time
+             with open(f"{models[model_name]}/label.pkl", "rb") as f:
+                 label_map = pickle.load(f)
+
+             results = classifier(validation_df["text"].tolist())
+
+             # Predict labels for validation DataFrame
+             validation_df["predicted_label"] = [
+                 label_map[result.item()] for result in results
+             ]
+
+             # Display validation DataFrame with predicted labels
+             st.write("Validation results:")
+             st.write(validation_df)
+         else:
+             st.error("Validation file must have 'text' and 'label' columns.")
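The page shows predictions next to the true labels but does not score them; an accuracy figure can be derived from the same two columns. A minimal sketch on an illustrative frame shaped like the one validation.py displays:

import pandas as pd

validation_df = pd.DataFrame(
    {
        "text": ["great", "awful"],
        "label": ["positive", "negative"],
        "predicted_label": ["positive", "positive"],
    }
)

accuracy = (validation_df["label"] == validation_df["predicted_label"]).mean()
print(f"Accuracy: {accuracy:.2%}")  # Accuracy: 50.00%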