metadata
language:
  - en
tags:
  - sentence-transformers
  - cross-encoder
  - text-classification
  - generated_from_trainer
  - dataset_size:404290
  - loss:BinaryCrossEntropyLoss
base_model: distilbert/distilroberta-base
datasets:
  - sentence-transformers/quora-duplicates
pipeline_tag: text-classification
library_name: sentence-transformers
metrics:
  - accuracy
  - accuracy_threshold
  - f1
  - f1_threshold
  - precision
  - recall
  - average_precision
co2_eq_emissions:
  emissions: 26.889480385249758
  energy_consumed: 0.06917762292257246
  source: codecarbon
  training_type: fine-tuning
  on_cloud: false
  cpu_model: 13th Gen Intel(R) Core(TM) i7-13700K
  ram_total_size: 31.777088165283203
  hours_used: 0.214
  hardware_used: 1 x NVIDIA GeForce RTX 3090
model-index:
  - name: CrossEncoder based on distilbert/distilroberta-base
    results:
      - task:
          type: cross-encoder-classification
          name: Cross Encoder Classification
        dataset:
          name: quora duplicates dev
          type: quora-duplicates-dev
        metrics:
          - type: accuracy
            value: 0.8938
            name: Accuracy
          - type: accuracy_threshold
            value: 0.5088549852371216
            name: Accuracy Threshold
          - type: f1
            value: 0.8612281373675477
            name: F1
          - type: f1_threshold
            value: 0.3856155276298523
            name: F1 Threshold
          - type: precision
            value: 0.8182920912178554
            name: Precision
          - type: recall
            value: 0.908919428725411
            name: Recall
          - type: average_precision
            value: 0.920292628179356
            name: Average Precision
      - task:
          type: cross-encoder-classification
          name: Cross Encoder Classification
        dataset:
          name: quora duplicates test
          type: quora-duplicates-test
        metrics:
          - type: accuracy
            value: 0.8938
            name: Accuracy
          - type: accuracy_threshold
            value: 0.5091445446014404
            name: Accuracy Threshold
          - type: f1
            value: 0.8612281373675477
            name: F1
          - type: f1_threshold
            value: 0.38580775260925293
            name: F1 Threshold
          - type: precision
            value: 0.8182920912178554
            name: Precision
          - type: recall
            value: 0.908919428725411
            name: Recall
          - type: average_precision
            value: 0.92029239602284
            name: Average Precision

CrossEncoder based on distilbert/distilroberta-base

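This is a Cross Encoder model fine-tuned from distilbert/distilroberta-base on the quora-duplicates dataset using the sentence-transformers library. It computes a score for each pair of texts. Since it was trained on Quora question pairs labeled as duplicate or not, the score is best read as how likely two questions are to be duplicates. Such pair scores can also be used for semantic textual similarity, semantic search, paraphrase mining, text classification, clustering, and more.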

Model Details

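Model Description

  • Model Type: Cross Encoder
  • Base model: distilbert/distilroberta-base
  • Training Dataset: quora-duplicates
  • Language: en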

Model Sources

Usage

Direct Usage (Sentence Transformers)

First install the Sentence Transformers library:

pip install -U sentence-transformers

Then you can load this model and run inference.

from sentence_transformers import CrossEncoder

# Download from the 🤗 Hub
model = CrossEncoder("sentence_transformers_model_id")
# Get scores for pairs...
pairs = [
    ['What is the step by step guide to invest in share market in india?', 'What is the step by step guide to invest in share market?'],
    ['What is the story of Kohinoor (Koh-i-Noor) Diamond?', 'What would happen if the Indian government stole the Kohinoor (Koh-i-Noor) diamond back?'],
    ['How can I increase the speed of my internet connection while using a VPN?', 'How can Internet speed be increased by hacking through DNS?'],
    ['Why am I mentally very lonely? How can I solve it?', 'Find the remainder when [math]23^{24}[/math] is divided by 24,23?'],
    ['Which one dissolve in water quikly sugar, salt, methane and carbon di oxide?', 'Which fish would survive in salt water?'],
]
scores = model.predict(pairs)
print(scores.shape)
# (5,)

# ... or rank different texts based on similarity to a single text
ranks = model.rank(
    'What is the step by step guide to invest in share market in india?',
    [
        'What is the step by step guide to invest in share market?',
        'What would happen if the Indian government stole the Kohinoor (Koh-i-Noor) diamond back?',
        'How can Internet speed be increased by hacking through DNS?',
        'Find the remainder when [math]23^{24}[/math] is divided by 24,23?',
        'Which fish would survive in salt water?',
    ]
)
# [{'corpus_id': ..., 'score': ...}, {'corpus_id': ..., 'score': ...}, ...]
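`predict` returns one score per pair. Assuming the model keeps sentence-transformers' default sigmoid activation for single-output cross-encoders (so scores fall between 0 and 1), a pair can be labeled a duplicate by comparing its score to a cutoff such as the quora-duplicates-dev `accuracy_threshold` (about 0.5089) reported in the Evaluation section. The sketch below uses made-up logits rather than real model output, and `rank_like` is a hypothetical helper that only mimics the shape of what `rank` returns; it is not part of the library.

```python
import math

# Cutoff taken from the quora-duplicates-dev accuracy_threshold in this card.
ACCURACY_THRESHOLD = 0.5089


def sigmoid(logit: float) -> float:
    """Map a raw logit to a score in (0, 1)."""
    return 1.0 / (1.0 + math.exp(-logit))


def label_pairs(scores, threshold=ACCURACY_THRESHOLD):
    """Return 1 (duplicate) when a score reaches the threshold, else 0."""
    return [1 if s >= threshold else 0 for s in scores]


def rank_like(scores):
    """Hypothetical helper: corpus indices sorted by descending score,
    as dicts shaped like the entries CrossEncoder.rank returns."""
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return [{"corpus_id": i, "score": scores[i]} for i in order]


# Made-up logits for five pairs, NOT real model output.
logits = [3.2, -1.5, 0.4, -4.0, -2.2]
scores = [sigmoid(x) for x in logits]
print(label_pairs(scores))                            # [1, 0, 1, 0, 0]
print([r["corpus_id"] for r in rank_like(scores)])    # [0, 2, 1, 4, 3]
```

For a different precision/recall trade-off, the `f1_threshold` (about 0.3856 on dev) could be used instead; a lower cutoff labels more pairs as duplicates.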

Evaluation

Metrics

Cross Encoder Classification

| Metric | quora-duplicates-dev | quora-duplicates-test |
|:--|:--|:--|
| accuracy | 0.8938 | 0.8938 |
| accuracy_threshold | 0.5089 | 0.5091 |
| f1 | 0.8612 | 0.8612 |
| f1_threshold | 0.3856 | 0.3858 |
| precision | 0.8183 | 0.8183 |
| recall | 0.9089 | 0.9089 |
| average_precision | 0.9203 | 0.9203 |
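Accuracy, precision, recall and F1 in this table depend on a cutoff that turns scores into 0/1 predictions, while average precision is threshold-free and summarizes how well positives are ranked above negatives. The pure-Python sketch below computes these on a tiny made-up set of scores and labels. `best_threshold` is a simple brute-force search over midpoints between sorted scores, an approximation of how a tuned cutoff like `accuracy_threshold` can be found, not the exact sentence-transformers procedure; `average_precision` assumes there are no tied scores.

```python
def threshold_metrics(scores, labels, threshold):
    """Accuracy, precision, recall and F1 when score >= threshold means 'duplicate'."""
    preds = [1 if s >= threshold else 0 for s in scores]
    tp = sum(1 for p, y in zip(preds, labels) if p == 1 and y == 1)
    fp = sum(1 for p, y in zip(preds, labels) if p == 1 and y == 0)
    fn = sum(1 for p, y in zip(preds, labels) if p == 0 and y == 1)
    tn = len(labels) - tp - fp - fn
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return {"accuracy": (tp + tn) / len(labels), "precision": precision,
            "recall": recall, "f1": f1}


def average_precision(scores, labels):
    """Mean of precision@k over the ranks k of the positives (assumes no tied scores)."""
    ranked = sorted(zip(scores, labels), key=lambda pair: pair[0], reverse=True)
    hits, total = 0, 0.0
    for k, (_, y) in enumerate(ranked, start=1):
        if y == 1:
            hits += 1
            total += hits / k
    return total / hits if hits else 0.0


def best_threshold(scores, labels, metric="accuracy"):
    """Brute-force search over midpoints between consecutive distinct scores."""
    s = sorted(set(scores))
    candidates = [(a + b) / 2 for a, b in zip(s, s[1:])]
    return max(candidates, key=lambda t: threshold_metrics(scores, labels, t)[metric])


scores = [0.9, 0.8, 0.7, 0.4, 0.3, 0.1]   # made-up model scores
labels = [1, 1, 0, 1, 0, 0]               # made-up gold labels
print(threshold_metrics(scores, labels, 0.5))
print(round(average_precision(scores, labels), 4))   # 0.9167
print(best_threshold(scores, labels, metric="f1"))
```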

Training Details

Training Dataset

quora-duplicates

  • Dataset: quora-duplicates at 451a485
  • Size: 404,290 training samples
  • Columns: sentence1, sentence2, and label
  • Approximate statistics based on the first 1000 samples:
    • sentence1 (string): min 1, mean 59.15, max 354 characters
    • sentence2 (string): min 6, mean 60.74, max 399 characters
    • label (int): 0: ~64.20%, 1: ~35.80%
  • Samples:
    • sentence1: "What are the features of the Indian caste system?", sentence2: "What triggers you the most when you play video games?", label: 0
    • sentence1: "What is the best place to learn Mandarin Chinese in Singapore?", sentence2: "What is the best place in Singapore for durian in December?", label: 0
    • sentence1: "What will be Hillary Clinton's India policy if she wins the election?", sentence2: "How would the bilateral relationship between India and the USA be under Hillary Clinton's presidency?", label: 1
  • Loss: BinaryCrossEntropyLoss
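BinaryCrossEntropyLoss trains the model's single output logit against the 0/1 duplicate label. The sketch below assumes the loss is the standard binary cross-entropy applied to the sigmoid of the logit and averaged over the batch (what PyTorch's `BCEWithLogitsLoss` computes with its default mean reduction); it is a pure-Python illustration of the formula, not the library's code. A logit of 0 (score 0.5) costs ln 2 ≈ 0.693 per pair, which gives a reference point for the early training-loss values in the Training Logs.

```python
import math


def bce_with_logits(logits, labels):
    """Mean binary cross-entropy computed from raw logits.

    Uses the numerically stable form max(x, 0) - x*y + log(1 + exp(-|x|)),
    which equals -[y*log(sigmoid(x)) + (1 - y)*log(1 - sigmoid(x))].
    """
    total = 0.0
    for x, y in zip(logits, labels):
        total += max(x, 0.0) - x * y + math.log1p(math.exp(-abs(x)))
    return total / len(logits)


# Made-up logits for a batch of three question pairs and their 0/1 labels.
batch_logits = [2.0, -1.0, 0.0]
batch_labels = [1, 0, 1]
print(round(bce_with_logits(batch_logits, batch_labels), 4))   # 0.3778
```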

Evaluation Dataset

quora-duplicates

  • Dataset: quora-duplicates at 451a485
  • Size: 404,290 evaluation samples
  • Columns: sentence1, sentence2, and label
  • Approximate statistics based on the first 1000 samples:
    • sentence1 (string): min 11, mean 57.9, max 244 characters
    • sentence2 (string): min 12, mean 59.33, max 221 characters
    • label (int): 0: ~62.00%, 1: ~38.00%
  • Samples:
    • sentence1: "What is the step by step guide to invest in share market in india?", sentence2: "What is the step by step guide to invest in share market?", label: 0
    • sentence1: "What is the story of Kohinoor (Koh-i-Noor) Diamond?", sentence2: "What would happen if the Indian government stole the Kohinoor (Koh-i-Noor) diamond back?", label: 0
    • sentence1: "How can I increase the speed of my internet connection while using a VPN?", sentence2: "How can Internet speed be increased by hacking through DNS?", label: 0
  • Loss: BinaryCrossEntropyLoss
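The "approximate statistics" for both datasets are character-length summaries of each text column plus the label distribution. A minimal sketch of how such numbers can be computed, run here on the three training samples listed in this card rather than on the real first 1000 rows; loading the actual dataset is left out, and the helper names are hypothetical.

```python
from statistics import mean

# The three training samples listed in this card: (sentence1, sentence2, label).
samples = [
    ("What are the features of the Indian caste system?",
     "What triggers you the most when you play video games?", 0),
    ("What is the best place to learn Mandarin Chinese in Singapore?",
     "What is the best place in Singapore for durian in December?", 0),
    ("What will be Hillary Clinton's India policy if she wins the election?",
     "How would the bilateral relationship between India and the USA be "
     "under Hillary Clinton's presidency?", 1),
]


def length_stats(texts):
    """Min / mean / max length in characters."""
    lengths = [len(t) for t in texts]
    return {"min": min(lengths), "mean": round(mean(lengths), 2), "max": max(lengths)}


def label_distribution(labels):
    """Share of each label value, in percent."""
    return {v: round(100 * labels.count(v) / len(labels), 2) for v in sorted(set(labels))}


print(length_stats([s1 for s1, _, _ in samples]))
print(length_stats([s2 for _, s2, _ in samples]))
print(label_distribution([y for _, _, y in samples]))   # {0: 66.67, 1: 33.33}
```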

Training Hyperparameters

Non-Default Hyperparameters

  • eval_strategy: steps
  • per_device_train_batch_size: 64
  • per_device_eval_batch_size: 64
  • num_train_epochs: 1
  • warmup_ratio: 0.1
  • bf16: True
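`warmup_ratio: 0.1` together with the default `lr_scheduler_type: linear` and `learning_rate: 5e-05` means the learning rate ramps up linearly from 0 over the first 10% of optimizer steps, then decays linearly back to 0. The sketch below assumes the Hugging Face Trainer's usual behavior of rounding `warmup_ratio` × total steps up to a whole number of warmup steps. The total of 6005 steps is an assumption inferred from the Training Logs (step 6000 is logged as epoch 0.9992), not a figure the card states.

```python
import math


def linear_warmup_lr(step, total_steps, base_lr=5e-5, warmup_ratio=0.1):
    """Learning rate at an optimizer step under linear warmup + linear decay."""
    warmup_steps = math.ceil(total_steps * warmup_ratio)
    if step < warmup_steps:
        # Warmup: grow linearly from 0 to base_lr.
        return base_lr * step / max(1, warmup_steps)
    # Decay: shrink linearly from base_lr to 0 at the final step.
    return base_lr * max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))


# Assumed total optimizer steps (inferred from the training log, see lead-in).
TOTAL_STEPS = 6005
for step in (0, 300, 601, 3000, 6005):
    print(step, linear_warmup_lr(step, TOTAL_STEPS))
```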

All Hyperparameters

  • overwrite_output_dir: False
  • do_predict: False
  • eval_strategy: steps
  • prediction_loss_only: True
  • per_device_train_batch_size: 64
  • per_device_eval_batch_size: 64
  • per_gpu_train_batch_size: None
  • per_gpu_eval_batch_size: None
  • gradient_accumulation_steps: 1
  • eval_accumulation_steps: None
  • torch_empty_cache_steps: None
  • learning_rate: 5e-05
  • weight_decay: 0.0
  • adam_beta1: 0.9
  • adam_beta2: 0.999
  • adam_epsilon: 1e-08
  • max_grad_norm: 1.0
  • num_train_epochs: 1
  • max_steps: -1
  • lr_scheduler_type: linear
  • lr_scheduler_kwargs: {}
  • warmup_ratio: 0.1
  • warmup_steps: 0
  • log_level: passive
  • log_level_replica: warning
  • log_on_each_node: True
  • logging_nan_inf_filter: True
  • save_safetensors: True
  • save_on_each_node: False
  • save_only_model: False
  • restore_callback_states_from_checkpoint: False
  • no_cuda: False
  • use_cpu: False
  • use_mps_device: False
  • seed: 42
  • data_seed: None
  • jit_mode_eval: False
  • use_ipex: False
  • bf16: True
  • fp16: False
  • fp16_opt_level: O1
  • half_precision_backend: auto
  • bf16_full_eval: False
  • fp16_full_eval: False
  • tf32: None
  • local_rank: 0
  • ddp_backend: None
  • tpu_num_cores: None
  • tpu_metrics_debug: False
  • debug: []
  • dataloader_drop_last: False
  • dataloader_num_workers: 0
  • dataloader_prefetch_factor: None
  • past_index: -1
  • disable_tqdm: False
  • remove_unused_columns: True
  • label_names: None
  • load_best_model_at_end: False
  • ignore_data_skip: False
  • fsdp: []
  • fsdp_min_num_params: 0
  • fsdp_config: {'min_num_params': 0, 'xla': False, 'xla_fsdp_v2': False, 'xla_fsdp_grad_ckpt': False}
  • fsdp_transformer_layer_cls_to_wrap: None
  • accelerator_config: {'split_batches': False, 'dispatch_batches': None, 'even_batches': True, 'use_seedable_sampler': True, 'non_blocking': False, 'gradient_accumulation_kwargs': None}
  • deepspeed: None
  • label_smoothing_factor: 0.0
  • optim: adamw_torch
  • optim_args: None
  • adafactor: False
  • group_by_length: False
  • length_column_name: length
  • ddp_find_unused_parameters: None
  • ddp_bucket_cap_mb: None
  • ddp_broadcast_buffers: False
  • dataloader_pin_memory: True
  • dataloader_persistent_workers: False
  • skip_memory_metrics: True
  • use_legacy_prediction_loop: False
  • push_to_hub: False
  • resume_from_checkpoint: None
  • hub_model_id: None
  • hub_strategy: every_save
  • hub_private_repo: None
  • hub_always_push: False
  • gradient_checkpointing: False
  • gradient_checkpointing_kwargs: None
  • include_inputs_for_metrics: False
  • include_for_metrics: []
  • eval_do_concat_batches: True
  • fp16_backend: auto
  • push_to_hub_model_id: None
  • push_to_hub_organization: None
  • mp_parameters:
  • auto_find_batch_size: False
  • full_determinism: False
  • torchdynamo: None
  • ray_scope: last
  • ddp_timeout: 1800
  • torch_compile: False
  • torch_compile_backend: None
  • torch_compile_mode: None
  • dispatch_batches: None
  • split_batches: None
  • include_tokens_per_second: False
  • include_num_input_tokens_seen: False
  • neftune_noise_alpha: None
  • optim_target_modules: None
  • batch_eval_metrics: False
  • eval_on_start: False
  • use_liger_kernel: False
  • eval_use_gather_object: False
  • average_tokens_across_devices: False
  • prompts: None
  • batch_sampler: batch_sampler
  • multi_dataset_batch_sampler: proportional
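Several of these settings (`optim: adamw_torch`, `learning_rate: 5e-05`, `adam_beta1: 0.9`, `adam_beta2: 0.999`, `adam_epsilon: 1e-08`, `weight_decay: 0.0`, `max_grad_norm: 1.0`) together define one optimizer step. The pure-Python sketch below shows global-norm gradient clipping followed by a standard AdamW update (decoupled weight decay, bias-corrected moments) on a flat list of parameters. It is an illustration of the textbook equations under those values, not a reproduction of PyTorch's kernels, and the parameter and gradient values are made up.

```python
import math


def clip_by_global_norm(grads, max_norm=1.0, eps=1e-6):
    """Scale all gradients down together if their global L2 norm exceeds max_norm."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    coef = min(1.0, max_norm / (total_norm + eps))
    return [g * coef for g in grads]


def adamw_step(params, grads, m, v, t, lr=5e-5, beta1=0.9, beta2=0.999,
               eps=1e-8, weight_decay=0.0):
    """One AdamW update (t counts steps from 1). Returns new params, m, v."""
    new_p, new_m, new_v = [], [], []
    for p, g, m_i, v_i in zip(params, grads, m, v):
        p = p - lr * weight_decay * p            # decoupled weight decay (no-op at 0.0)
        m_i = beta1 * m_i + (1 - beta1) * g      # first-moment estimate
        v_i = beta2 * v_i + (1 - beta2) * g * g  # second-moment estimate
        m_hat = m_i / (1 - beta1 ** t)           # bias corrections
        v_hat = v_i / (1 - beta2 ** t)
        p = p - lr * m_hat / (math.sqrt(v_hat) + eps)
        new_p.append(p)
        new_m.append(m_i)
        new_v.append(v_i)
    return new_p, new_m, new_v


params = [1.0, -0.5]
grads = clip_by_global_norm([3.0, 4.0])   # global norm 5 is rescaled to about 1
params, m, v = adamw_step(params, grads, [0.0, 0.0], [0.0, 0.0], t=1)
print(params)
```

On the first step the bias-corrected update is close to `lr` times the sign of each gradient, which is why both parameters move by roughly 5e-5.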

Training Logs

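| Epoch | Step | Training Loss | Validation Loss | quora-duplicates-dev_average_precision | quora-duplicates-test_average_precision |
|:--|:--|:--|:--|:--|:--|
| -1 | -1 | - | - | 0.3711 | - |
| 0.0167 | 100 | 0.6574 | - | - | - |
| 0.0333 | 200 | 0.4804 | - | - | - |
| 0.0500 | 300 | 0.4406 | - | - | - |
| 0.0666 | 400 | 0.4208 | - | - | - |
| 0.0833 | 500 | 0.3929 | 0.3958 | 0.8210 | - |
| 0.0999 | 600 | 0.3986 | - | - | - |
| 0.1166 | 700 | 0.3743 | - | - | - |
| 0.1332 | 800 | 0.3938 | - | - | - |
| 0.1499 | 900 | 0.3602 | - | - | - |
| 0.1665 | 1000 | 0.3714 | 0.3437 | 0.8565 | - |
| 0.1832 | 1100 | 0.3486 | - | - | - |
| 0.1998 | 1200 | 0.3479 | - | - | - |
| 0.2165 | 1300 | 0.3417 | - | - | - |
| 0.2331 | 1400 | 0.3425 | - | - | - |
| 0.2498 | 1500 | 0.3353 | 0.3264 | 0.8742 | - |
| 0.2664 | 1600 | 0.3335 | - | - | - |
| 0.2831 | 1700 | 0.3274 | - | - | - |
| 0.2998 | 1800 | 0.3284 | - | - | - |
| 0.3164 | 1900 | 0.3118 | - | - | - |
| 0.3331 | 2000 | 0.3073 | 0.3282 | 0.8826 | - |
| 0.3497 | 2100 | 0.3233 | - | - | - |
| 0.3664 | 2200 | 0.3072 | - | - | - |
| 0.3830 | 2300 | 0.314 | - | - | - |
| 0.3997 | 2400 | 0.3065 | - | - | - |
| 0.4163 | 2500 | 0.3046 | 0.2877 | 0.8930 | - |
| 0.4330 | 2600 | 0.2857 | - | - | - |
| 0.4496 | 2700 | 0.285 | - | - | - |
| 0.4663 | 2800 | 0.2957 | - | - | - |
| 0.4829 | 2900 | 0.2965 | - | - | - |
| 0.4996 | 3000 | 0.2824 | 0.2842 | 0.8998 | - |
| 0.5162 | 3100 | 0.3019 | - | - | - |
| 0.5329 | 3200 | 0.2841 | - | - | - |
| 0.5495 | 3300 | 0.2981 | - | - | - |
| 0.5662 | 3400 | 0.2878 | - | - | - |
| 0.5828 | 3500 | 0.278 | 0.2803 | 0.9061 | - |
| 0.5995 | 3600 | 0.2841 | - | - | - |
| 0.6162 | 3700 | 0.2794 | - | - | - |
| 0.6328 | 3800 | 0.2808 | - | - | - |
| 0.6495 | 3900 | 0.27 | - | - | - |
| 0.6661 | 4000 | 0.2719 | 0.2697 | 0.9091 | - |
| 0.6828 | 4100 | 0.2792 | - | - | - |
| 0.6994 | 4200 | 0.2669 | - | - | - |
| 0.7161 | 4300 | 0.2696 | - | - | - |
| 0.7327 | 4400 | 0.2642 | - | - | - |
| 0.7494 | 4500 | 0.2684 | 0.2591 | 0.9140 | - |
| 0.7660 | 4600 | 0.2593 | - | - | - |
| 0.7827 | 4700 | 0.2756 | - | - | - |
| 0.7993 | 4800 | 0.2584 | - | - | - |
| 0.8160 | 4900 | 0.2525 | - | - | - |
| 0.8326 | 5000 | 0.267 | 0.2540 | 0.9168 | - |
| 0.8493 | 5100 | 0.2612 | - | - | - |
| 0.8659 | 5200 | 0.2607 | - | - | - |
| 0.8826 | 5300 | 0.2565 | - | - | - |
| 0.8993 | 5400 | 0.2432 | - | - | - |
| 0.9159 | 5500 | 0.2568 | 0.2489 | 0.9198 | - |
| 0.9326 | 5600 | 0.2572 | - | - | - |
| 0.9492 | 5700 | 0.2658 | - | - | - |
| 0.9659 | 5800 | 0.2568 | - | - | - |
| 0.9825 | 5900 | 0.2539 | - | - | - |
| 0.9992 | 6000 | 0.2458 | 0.2503 | 0.9203 | - |
| -1 | -1 | - | - | - | 0.9203 |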
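The Epoch column is the fraction of the single training epoch completed at each logged step. The card does not state the number of optimizer steps per epoch; assuming about 6005 (an inference from step 6000 being logged as epoch 0.9992, not a stated figure), the logged epoch values can be reproduced as step / 6005 rounded to four decimals. The rows with Epoch and Step of -1 appear to be evaluations run outside the step loop, before training (dev) and after training (test).

```python
# Assumed steps per epoch, inferred from the log (step 6000 -> epoch 0.9992).
STEPS_PER_EPOCH = 6005

# A few (step, logged epoch) pairs copied from the training log.
logged = [(100, 0.0167), (500, 0.0833), (2500, 0.4163), (5500, 0.9159), (6000, 0.9992)]

for step, epoch in logged:
    derived = round(step / STEPS_PER_EPOCH, 4)
    print(step, epoch, derived, derived == epoch)
```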

Environmental Impact

Carbon emissions were measured using CodeCarbon.

  • Energy Consumed: 0.069 kWh
  • Carbon Emitted: 0.027 kg of CO2
  • Hours Used: 0.214 hours
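The three figures above can be cross-checked from the `co2_eq_emissions` metadata, where `emissions` is given in grams (consistent with the 0.027 kg reported here). Dividing emissions by energy gives the implied carbon intensity of the electricity used, and dividing energy by time gives the average power draw; since `hours_used` is rounded to 0.214, the power figure is approximate.

```python
# Values copied from this card's co2_eq_emissions metadata.
emissions_g = 26.889480385249758   # grams CO2-eq
energy_kwh = 0.06917762292257246   # kWh
hours = 0.214                      # rounded in the card

carbon_intensity = emissions_g / energy_kwh   # grams CO2-eq per kWh
avg_power_w = energy_kwh / hours * 1000       # average draw in watts

print(f"{emissions_g / 1000:.3f} kg CO2-eq")  # 0.027 kg CO2-eq
print(f"{carbon_intensity:.0f} g/kWh, {avg_power_w:.0f} W")
```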

Training Hardware

  • On Cloud: No
  • GPU Model: 1 x NVIDIA GeForce RTX 3090
  • CPU Model: 13th Gen Intel(R) Core(TM) i7-13700K
  • RAM Size: 31.78 GB

Framework Versions

  • Python: 3.11.6
  • Sentence Transformers: 3.5.0.dev0
  • Transformers: 4.49.0.dev0
  • PyTorch: 2.5.0+cu121
  • Accelerate: 1.3.0
  • Datasets: 2.20.0
  • Tokenizers: 0.21.0

Citation

BibTeX

Sentence Transformers

@inproceedings{reimers-2019-sentence-bert,
    title = "Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks",
    author = "Reimers, Nils and Gurevych, Iryna",
    booktitle = "Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing",
    month = "11",
    year = "2019",
    publisher = "Association for Computational Linguistics",
    url = "https://arxiv.org/abs/1908.10084",
}