# NOTE(review): removed a Hugging Face Spaces status banner ("Spaces: Running
# Running") that was accidentally pasted above the shebang — it is not code.
#!/usr/bin/env python3
# Weekly fine-tuning script: trains a sentence-transformers model with
# triplet loss on user feedback data and saves it locally.
import os

from torch.utils.data import DataLoader
from sentence_transformers import InputExample, SentenceTransformer, losses

# If you want to push to the HF Hub/Spaces programmatically:
#   pip install huggingface_hub
# from huggingface_hub import HfApi, HfFolder
def main():
    """Fine-tune a sentence-transformers model with triplet loss and save it.

    Pipeline: build (anchor, positive, negative) triplets from placeholder
    feedback tuples, fine-tune a multilingual MiniLM base model for one demo
    epoch, and save the result to ``my_finetuned_model``.
    """
    # --------------------------------------------------------------------------
    # 1. (Optional) Setup your Hugging Face auth
    # --------------------------------------------------------------------------
    # If you need to log into your HF account, you can do:
    #   hf_token = os.getenv("HF_TOKEN")  # or read from a config file
    #   HfFolder.save_token(hf_token)
    #   api = HfApi()
    #
    # Then set something like:
    #   repo_id = "KolumbusLindh/my-weekly-model"
    #
    # Alternatively, you can push manually later via huggingface-cli.

    # --------------------------------------------------------------------------
    # 2. Placeholder training data
    # --------------------------------------------------------------------------
    # Each tuple is (CV_text, liked_job_text, disliked_job_text).
    # In a real scenario, you'd gather user feedback from your database.
    train_data = [
        ("My CV #1", "Job #1 that user liked", "Job #1 that user disliked"),
        ("My CV #2", "Job #2 that user liked", "Job #2 that user disliked"),
        # ...
    ]

    # --------------------------------------------------------------------------
    # 3. Convert data into Sentence Transformers InputExamples
    # --------------------------------------------------------------------------
    # TripletLoss expects exactly 3 texts per example: anchor, positive, negative.
    train_examples = [
        InputExample(texts=[cv_text, liked_job_text, disliked_job_text])
        for cv_text, liked_job_text, disliked_job_text in train_data
    ]

    # --------------------------------------------------------------------------
    # 4. Load the base model
    # --------------------------------------------------------------------------
    model_name = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
    model = SentenceTransformer(model_name)

    # --------------------------------------------------------------------------
    # 5. Prepare DataLoader & define the Triplet Loss
    # --------------------------------------------------------------------------
    # A typical margin is 0.5–1.0. Feel free to adjust it.
    train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=8)
    # BUGFIX: TripletLoss takes its margin via the ``triplet_margin`` keyword,
    # not ``margin`` — passing ``margin=0.5`` raised a TypeError.
    train_loss = losses.TripletLoss(
        model=model,
        distance_metric=losses.TripletDistanceMetric.COSINE,
        triplet_margin=0.5,
    )

    # --------------------------------------------------------------------------
    # 6. Fine-tune (fit) the model
    # --------------------------------------------------------------------------
    # Just 1 epoch here for demo. In practice, tune #epochs/batch_size, etc.
    num_epochs = 1
    warmup_steps = int(len(train_dataloader) * num_epochs * 0.1)  # ~10% warmup
    model.fit(
        train_objectives=[(train_dataloader, train_loss)],
        epochs=num_epochs,
        warmup_steps=warmup_steps,
        show_progress_bar=True,
    )

    # --------------------------------------------------------------------------
    # 7. Save model locally
    # --------------------------------------------------------------------------
    local_output_path = "my_finetuned_model"
    model.save(local_output_path)
    print(f"Model fine-tuned and saved locally to: {local_output_path}")

    # --------------------------------------------------------------------------
    # 8. (Optional) Push to your Hugging Face Space
    # --------------------------------------------------------------------------
    # If you want to push automatically:
    #
    #   model.push_to_hub(repo_id=repo_id, commit_message="Weekly model update")
    #
    # Or if you have a Space at e.g. https://huggingface.co/spaces/KolumbusLindh/<some-name>,
    # you'd create a repo on HF, then push to that repo. Typically one uses
    # huggingface-cli or the huggingface_hub methods for that:
    #
    #   api.create_repo(repo_id=repo_id, repo_type="model", private=False)
    #   model.push_to_hub(repo_id=repo_id)
    #
    # If it's a Space, you might need to store your model in the "models" folder
    # or however your Gradio app is set up to load it.
# Script entry point: run the weekly fine-tuning job when executed directly.
if __name__ == "__main__":
    main()