Spaces:
Build error
Build error
Commit
·
d655f51
1
Parent(s):
281995c
add train code and requirements text file...
Browse files- app.py +119 -3
- requirements.txt +3 -0
app.py
CHANGED
|
@@ -1,7 +1,123 @@
|
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
|
| 3 |
-
|
| 4 |
-
|
|
|
|
| 5 |
|
| 6 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
demo.launch()
|
|
|
|
| 1 |
+
import spaces
|
| 2 |
import gradio as gr
|
| 3 |
|
| 4 |
+
# code
|
| 5 |
+
import pandas as pd
|
| 6 |
+
from datasets import load_dataset
|
| 7 |
|
| 8 |
+
# from sentence_transformers import (
|
| 9 |
+
# SentenceTransformer,
|
| 10 |
+
# SentenceTransformerTrainer,
|
| 11 |
+
# SentenceTransformerTrainingArguments,
|
| 12 |
+
# SentenceTransformerModelCardData
|
| 13 |
+
# ) ### we can import everything from the main class...
|
| 14 |
+
|
| 15 |
+
from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer
|
| 16 |
+
from sentence_transformers.losses import CachedMultipleNegativesRankingLoss
|
| 17 |
+
from sentence_transformers.evaluation import InformationRetrievalEvaluator
|
| 18 |
+
from sentence_transformers.training_args import SentenceTransformerTrainingArguments, BatchSamplers
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def get_ir_evaluator(eval_ds):
    """Build an InformationRetrievalEvaluator from an (anchor, positive) pair dataset.

    Each row's 'anchor' becomes a query and its 'positive' becomes the single
    relevant corpus document; both are keyed by the row index.

    Args:
        eval_ds: iterable of dicts with 'anchor' and 'positive' string fields
            (e.g. a Hugging Face ``datasets`` split).

    Returns:
        InformationRetrievalEvaluator wired with the queries, the corpus and
        the query -> relevant-document mapping.
    """
    queries = {}        # qid -> query text
    corpus = {}         # cid -> document text
    relevant_docs = {}  # qid -> set[cid]
    # Single pass: one query/document pair per row, sharing the row index as id.
    for idx, example in enumerate(eval_ds):
        queries[idx] = example['anchor']
        corpus[idx] = example['positive']
        # note: only the paired positive is marked relevant; a richer mapping
        # (several relevant docs per query) would give a fairer IR score.
        relevant_docs[idx] = {idx}

    return InformationRetrievalEvaluator(
        queries=queries,
        corpus=corpus,
        relevant_docs=relevant_docs,
        name="ir-evaluator",
    )
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
@spaces.GPU(duration=3600)  # Spaces ZeroGPU decorator: allow up to 1h of GPU time for this call
def train(hf_token, dataset_id, model_id, num_epochs, dev):
    """Fine-tune a sentence-embedding model on an (anchor, positive) pair dataset.

    Loads ``dataset_id`` from the Hub, fine-tunes ``model_id`` with a cached
    multiple-negatives ranking loss, and reports information-retrieval metrics
    before and after training.

    Args:
        hf_token: Hub access token used to load a (possibly private) dataset.
        dataset_id: Hub dataset id; rows must have 'anchor' and 'positive'
            columns (consumed by get_ir_evaluator and the pair loss).
        model_id: SentenceTransformer checkpoint to fine-tune.
        num_epochs: number of training epochs.
        dev: when truthy, cap the shuffled dataset for a quick dev run.

    Returns:
        str: printable pandas table comparing base and fine-tuned metrics.
    """

    # Dataset: load train split, shuffle deterministically, optionally
    # subsample in dev mode, then make a 75/25 train/eval split.
    ds = load_dataset(dataset_id, split="train", token=hf_token)
    ds = ds.shuffle(seed=42)
    # NOTE(review): range(0, 999) keeps 999 rows, not 1000 — confirm intent.
    if len(ds) > 1000 and dev: ds = ds.select(range(0, 999))
    ds = ds.train_test_split(train_size=0.75)
    train_ds, eval_ds = ds['train'], ds['test']
    print('train: ', len(train_ds), 'eval: ', len(eval_ds))

    # model
    model = SentenceTransformer(model_id)

    # loss: in-batch-negatives ranking loss with gradient caching
    loss = CachedMultipleNegativesRankingLoss(model)

    # training args
    args = SentenceTransformerTrainingArguments(
        output_dir="outputs", # required
        num_train_epochs=num_epochs, # optional...
        per_device_train_batch_size=16,
        warmup_ratio=0.1,
        #fp16=True, # Set to False if your GPU can't handle FP16
        #bf16=False, # Set to True if your GPU supports BF16
        batch_sampler=BatchSamplers.NO_DUPLICATES, # Losses using "in-batch negatives" benefit from no duplicates
        save_total_limit=2
        # per_device_eval_batch_size=1,
        # eval_strategy="epoch",
        # save_strategy="epoch",
        # logging_steps=100,
        # Optional tracking/debugging parameters:
        # eval_strategy="steps",
        # eval_steps=100,
        # save_strategy="steps",
        # save_steps=100,
        # logging_steps=100,
        # run_name="jina-code-vechain-pair", # Used in W&B if `wandb` is installed
    )

    # ir evaluator built from the held-out split
    ir_evaluator = get_ir_evaluator(eval_ds)

    # base model metrics (before fine-tuning), for comparison below
    base_metrics = ir_evaluator(model)
    print(ir_evaluator.primary_metric)
    print(base_metrics[ir_evaluator.primary_metric])


    # train
    trainer = SentenceTransformerTrainer(
        model=model,
        args=args,
        train_dataset=train_ds,
        # eval_dataset=eval_ds,
        loss=loss,
        # evaluator=ir_evaluator,
    )
    trainer.train()

    # fine tuned model metrics (same evaluator, after training)
    ft_metrics = ir_evaluator(model)
    print(ir_evaluator.primary_metric)
    print(ft_metrics[ir_evaluator.primary_metric])


    # Side-by-side table: one column per run, one row per metric.
    metrics = pd.DataFrame([base_metrics, ft_metrics]).T
    print(metrics)
    return str(metrics)
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
## logs to UI
# https://github.com/gradio-app/gradio/issues/2362#issuecomment-1424446778

# Fix for the Space build error: the original passed `fn=greet`, but no
# `greet` is defined in this file — the app's entry point is `train`.
# Also, "bool" is not a valid Gradio component shortcut; the boolean `dev`
# flag maps to a "checkbox" component.
# Inputs line up with train(hf_token, dataset_id, model_id, num_epochs, dev).
demo = gr.Interface(
    fn=train,
    inputs=["text", "text", "text", "number", "checkbox"],
    outputs=["text"],  # "dataframe"
)
demo.launch()
|
requirements.txt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
datasets
|
| 2 |
+
accelerate
|
| 3 |
+
sentence-transformers
|