spdin committed · 333cd19
1 Parent(s): 0da7162

initial commit
- app.py +33 -0
- model.py +47 -0
- prediction.py +48 -0
- training.py +74 -0
- utils.py +12 -0
- validation.py +65 -0
app.py
ADDED
@@ -0,0 +1,33 @@
import uuid
import streamlit as st

import training
import validation
import prediction


# Session initialization
if "key" not in st.session_state:
    st.session_state["key"] = str(uuid.uuid4()).split("-")[-1]


def training_page():
    training.main()


def validation_page():
    validation.main()


def prediction_page():
    prediction.main()


page_names_to_funcs = {
    "Training": training_page,
    "Validation": validation_page,
    "Prediction": prediction_page,
}

selected_page = st.sidebar.selectbox("Select a page", page_names_to_funcs.keys())
page_names_to_funcs[selected_page]()
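
Note: the session key above is the last hyphen-separated segment of a UUID4, i.e. a 12-character hex string. A minimal sketch of what the initialization produces (plain Python, outside Streamlit):

import uuid

# A UUID4 renders as "xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx"; splitting on "-"
# and keeping the last segment yields the 12-hex-char suffix used as the key.
key = str(uuid.uuid4()).split("-")[-1]
print(key)       # random each run, e.g. a 12-char hex string
print(len(key))  # 12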
model.py
ADDED
@@ -0,0 +1,47 @@
from setfit import SetFitModel, SetFitTrainer
from sentence_transformers.losses import CosineSimilarityLoss


# Function to create a pipeline for text classification using the trained model
def create_classifier(model_path):
    classifier = SetFitModel.from_pretrained(
        model_path,
        local_files_only=True,
    )
    return classifier


def run_setfit_training(
    session_id, model_id, model_name, train_dataset, batch_size, num_iterations
):

    model = SetFitModel.from_pretrained(model_id)

    # Create trainer
    trainer = SetFitTrainer(
        model=model,
        train_dataset=train_dataset,
        eval_dataset=train_dataset,
        loss_class=CosineSimilarityLoss,
        metric="accuracy",
        batch_size=batch_size,
        num_iterations=num_iterations,  # The number of text pairs to generate for contrastive learning
        num_epochs=1,  # The number of epochs to use for contrastive learning
        column_mapping={"text": "text", "label": "label"},
    )

    trainer.train()
    # metrics = trainer.evaluate()
    # accuracy = metrics["accuracy"]

    print(f"model used: {model_id}")
    print(f"train dataset: {len(train_dataset)} samples")
    # print(f"accuracy: {accuracy}")

    save_model_path = f"./models/{session_id}/{model_id}_{model_name}"

    trainer.model._save_pretrained(
        save_directory=save_model_path
    )

    return save_model_path
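
For context, a hedged sketch of how this module is driven (the toy dataset and session id are invented for illustration; batch_size=1 and num_iterations=10 mirror the values hard-coded in training.py; depending on your setfit/hub versions, the bare model id may need to be fully qualified as "sentence-transformers/all-MiniLM-L6-v2"):

from datasets import Dataset

import model

# Hypothetical two-class toy dataset in the expected "text"/"label" schema
train_dataset = Dataset.from_dict({
    "text": ["great product", "terrible service"],
    "label": [1, 0],
})

# Trains, saves under ./models/<session_id>/<model_id>_<model_name>,
# and returns that path
path = model.run_setfit_training(
    "demo-session", "all-MiniLM-L6-v2", "my-model", train_dataset, 1, 10
)

# Reload the saved checkpoint for inference
classifier = model.create_classifier(path)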
prediction.py
ADDED
@@ -0,0 +1,48 @@
import os
import pickle

import streamlit as st
import model


def main():
    st.title("Model Prediction")

    st.write(f"Session ID: {st.session_state.key}")
    session_id = st.session_state.key

    if not os.path.isdir(f"models/{session_id}"):
        st.write("Model is not available")
        st.stop()

    model_options = [model_name for model_name in os.listdir(f"models/{session_id}")]

    models = {
        model_name: os.path.abspath(os.path.join(f"models/{session_id}", model_name))
        for model_name in model_options
    }

    model_name = st.selectbox("Select a model", options=model_options)

    # Text input
    text = st.text_area("Enter some text here", height=200)

    # Prediction button
    if st.button("Predict"):

        with open(f"{models[model_name]}/label.pkl", "rb") as f:
            label_map = pickle.load(f)

        classifier = model.create_classifier(models[model_name])

        prediction = classifier([text])

        prediction_class = prediction[0].item()

        confidence_score = classifier.predict_proba([text])[0][prediction_class].item()

        st.write(
            "The predicted label is:",
            label_map[prediction_class],
            f"{round(confidence_score * 100, 2)}%",
        )
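
The same prediction flow, sketched outside Streamlit (the model path is hypothetical; it assumes a checkpoint and label.pkl written by training.py):

import pickle

import model

# Hypothetical path produced by a previous training run
model_path = "./models/demo-session/all-MiniLM-L6-v2_my-model"

with open(f"{model_path}/label.pkl", "rb") as f:
    label_map = pickle.load(f)  # int class index -> original label string

classifier = model.create_classifier(model_path)

text = "some input text"
pred = classifier([text])[0].item()                      # predicted class index
conf = classifier.predict_proba([text])[0][pred].item()  # probability of that class
print(label_map[pred], f"{round(conf * 100, 2)}%")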
training.py
ADDED
@@ -0,0 +1,74 @@
import pickle
import pandas as pd
import streamlit as st

from datasets import Dataset

import model
from utils import check_columns, count_labels


# Main function to run the Streamlit app
def main():
    # Set app title
    st.title("Few Shot Learning Demo using SetFit")

    # Display the session ID
    st.write(f"Session ID: {st.session_state.key}")
    session_id = st.session_state.key

    # Create file uploader
    uploaded_file = st.file_uploader("Choose a CSV file to upload", type="csv")

    # Check if file was uploaded
    if uploaded_file is not None:
        # Read CSV file into pandas DataFrame
        df = pd.read_csv(uploaded_file)

        # Check if DataFrame has expected columns
        if check_columns(df):
            # Display DataFrame as a table
            st.write(df)

            # Calculate the number of instances of each label class
            label_counts = count_labels(df)
            st.write(f"Number of instances of each label class: {label_counts}")

            labels = set(df["label"].tolist())
            label_map = {label: idx for idx, label in enumerate(labels)}

            df["label"] = df["label"].map(label_map)

            dataset = Dataset.from_pandas(df)

            model_name = st.text_input("Input the model name")

            pretrained_model_options = ["all-MiniLM-L6-v2", "paraphrase-MiniLM-L3-v2"]

            pretrained_model = st.selectbox(
                "Select a pretrained model", options=pretrained_model_options
            )

            # Add Train button
            if st.button("Train"):
                # Train the model
                with st.spinner("Training model..."):
                    model_path = model.run_setfit_training(
                        session_id,
                        pretrained_model,
                        model_name,
                        dataset,
                        1,
                        10,
                    )

                st.write(f"Model checkpoint saved {model_path.split('/')[-1]}")

                label_map = {v: k for k, v in label_map.items()}
                with open(f"{model_path}/label.pkl", "wb") as f:
                    pickle.dump(label_map, f)

                st.write("Training Finished")
                st.write("Go to Validation Page")

        else:
            st.error("File must have 'text' and 'label' columns.")
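
For reference, the label handling above maps string labels to integer ids before training, then inverts the map (reusing the label_map name) before pickling, so prediction.py and validation.py can recover the original names. A minimal sketch of that round trip (toy labels invented for illustration):

# Toy labels invented for illustration; set order is arbitrary, hence "e.g."
labels = {"positive", "negative"}
label_map = {label: idx for idx, label in enumerate(labels)}  # e.g. {'negative': 0, 'positive': 1}
inverse = {v: k for k, v in label_map.items()}                # e.g. {0: 'negative', 1: 'positive'}
assert all(inverse[label_map[label]] == label for label in labels)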
utils.py
ADDED
@@ -0,0 +1,12 @@
# Function to check if the uploaded file has the expected columns
def check_columns(df):
    return set(df.columns) == {"text", "label"}


# Function to calculate the number of instances of each label class
def count_labels(df):
    counts = df["label"].value_counts()
    return counts.to_dict()
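
A quick usage sketch of both helpers (the frame is invented for illustration):

import pandas as pd

from utils import check_columns, count_labels

df = pd.DataFrame({"text": ["a", "b", "c"], "label": ["x", "x", "y"]})
print(check_columns(df))  # True: columns are exactly {"text", "label"}
print(count_labels(df))   # {'x': 2, 'y': 1}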
validation.py
ADDED
@@ -0,0 +1,65 @@
import os
import pickle
import pandas as pd
import streamlit as st

import model

from utils import check_columns


# Function to validate the trained model with a new uploaded CSV file
def main():

    st.title("Model Validation")

    # Display the session ID
    st.write(f"Session ID: {st.session_state.key}")
    session_id = st.session_state.key

    if not os.path.isdir(f"models/{session_id}"):
        st.write("Model is not available")
        st.stop()

    model_options = [model_name for model_name in os.listdir(f"models/{session_id}")]

    models = {
        model_name: os.path.abspath(os.path.join(f"models/{session_id}", model_name))
        for model_name in model_options
    }

    model_name = st.selectbox("Select a model", options=model_options)

    # Create file uploader for validation CSV file
    validation_file = st.file_uploader(
        "Choose a CSV file to validate the model", type="csv"
    )

    # Check if validation file was uploaded
    if validation_file is not None:
        # Read CSV file into pandas DataFrame
        validation_df = pd.read_csv(validation_file)

        # Check if DataFrame has expected columns
        if check_columns(validation_df):
            # Display DataFrame as a table
            st.write(validation_df)

            # Create pipeline for text classification using the trained model
            classifier = model.create_classifier(models[model_name])

            with open(f"{models[model_name]}/label.pkl", "rb") as f:
                label_map = pickle.load(f)

            results = classifier(validation_df["text"].tolist())

            # Predict labels for validation DataFrame
            validation_df["predicted_label"] = [
                label_map[result.item()] for result in results
            ]

            # Display validation DataFrame with predicted labels
            st.write("Validation results:")
            st.write(validation_df)
        else:
            st.error("Validation file must have 'text' and 'label' columns.")
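
Note that the page displays predictions next to the uploaded label column but does not score them. Since both columns hold the original label strings at this point, a hedged sketch of two lines that could be appended inside the success branch to report accuracy:

# Fraction of rows where the predicted label matches the uploaded one
accuracy = (validation_df["label"] == validation_df["predicted_label"]).mean()
st.write(f"Accuracy: {accuracy:.2%}")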