Pringled commited on
Commit
2f9e086
·
1 Parent(s): 1744dee
Files changed (1) hide show
  1. app.py +254 -4
app.py CHANGED
@@ -199,7 +199,7 @@ with gr.Blocks(css="#status_output { height: 50px; overflow: auto; }") as demo:
199
  """)
200
 
201
  deduplication_type = gr.Radio(
202
- choices=["Single dataset", "Cross-dataset"],
203
  label="Deduplication Type",
204
  value="Cross-dataset", # Set "Cross-dataset" as the default value
205
  )
@@ -218,7 +218,10 @@ with gr.Blocks(css="#status_output { height: 50px; overflow: auto; }") as demo:
218
  dataset2_text_column = gr.Textbox(value=default_text_column, label="Text Column Name")
219
 
220
  threshold = gr.Slider(0.0, 1.0, value=default_threshold, label="Similarity Threshold")
221
- compute_button = gr.Button("Deduplicate")
 
 
 
222
  status_output = gr.Markdown(elem_id="status_output")
223
  result_output = gr.Markdown()
224
 
@@ -448,7 +451,7 @@ demo.launch()
448
  # deduplication_type = gr.Radio(
449
  # choices=["Single dataset", "Cross-dataset"],
450
  # label="Deduplication Type",
451
- # value="Single dataset",
452
  # )
453
 
454
  # with gr.Row():
@@ -456,7 +459,7 @@ demo.launch()
456
  # dataset1_split = gr.Textbox(value=default_dataset_split, label="Dataset 1 Split")
457
  # dataset1_text_column = gr.Textbox(value=default_text_column, label="Text Column Name")
458
 
459
- # dataset2_inputs = gr.Column(visible=False)
460
  # with dataset2_inputs:
461
  # gr.Markdown("### Dataset 2")
462
  # with gr.Row():
@@ -491,3 +494,250 @@ demo.launch()
491
 
492
 
493
  # demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
199
  """)
200
 
201
  deduplication_type = gr.Radio(
202
+ choices=["Cross-dataset", "Single dataset"], # Swapped "Cross-dataset" to the left
203
  label="Deduplication Type",
204
  value="Cross-dataset", # Set "Cross-dataset" as the default value
205
  )
 
218
  dataset2_text_column = gr.Textbox(value=default_text_column, label="Text Column Name")
219
 
220
  threshold = gr.Slider(0.0, 1.0, value=default_threshold, label="Similarity Threshold")
221
+
222
+ with gr.Row(): # Placing the button in the same row for better alignment
223
+ compute_button = gr.Button("Deduplicate")
224
+
225
  status_output = gr.Markdown(elem_id="status_output")
226
  result_output = gr.Markdown()
227
 
 
451
  # deduplication_type = gr.Radio(
452
  # choices=["Single dataset", "Cross-dataset"],
453
  # label="Deduplication Type",
454
+ # value="Cross-dataset", # Set "Cross-dataset" as the default value
455
  # )
456
 
457
  # with gr.Row():
 
459
  # dataset1_split = gr.Textbox(value=default_dataset_split, label="Dataset 1 Split")
460
  # dataset1_text_column = gr.Textbox(value=default_text_column, label="Text Column Name")
461
 
462
+ # dataset2_inputs = gr.Column(visible=True) # Make dataset2_inputs visible by default
463
  # with dataset2_inputs:
464
  # gr.Markdown("### Dataset 2")
465
  # with gr.Row():
 
494
 
495
 
496
  # demo.launch()
497
+
498
+ # # import gradio as gr
499
+ # # from datasets import load_dataset
500
+ # # import numpy as np
501
+ # # from model2vec import StaticModel
502
+ # # from reach import Reach
503
+ # # from difflib import ndiff
504
+
505
+ # # # Load the model
506
+ # # model = StaticModel.from_pretrained("minishlab/M2V_base_output")
507
+
508
+ # # # Default parameters
509
+ # # default_dataset_name = "sst2"
510
+ # # default_dataset_split = "train"
511
+ # # default_text_column = "sentence"
512
+ # # default_threshold = 0.9
513
+
514
+ # # def deduplicate_embeddings(
515
+ # # embeddings_a: np.ndarray,
516
+ # # embeddings_b: np.ndarray = None,
517
+ # # threshold: float = 0.9,
518
+ # # batch_size: int = 1024,
519
+ # # progress=None
520
+ # # ) -> tuple[np.ndarray, dict[int, int]]:
521
+ # # """
522
+ # # Deduplicate embeddings within one dataset or across two datasets.
523
+
524
+ # # :param embeddings_a: Embeddings of Dataset 1.
525
+ # # :param embeddings_b: Optional, embeddings of Dataset 2.
526
+ # # :param threshold: Similarity threshold for deduplication.
527
+ # # :param batch_size: Batch size for similarity computation.
528
+ # # :param progress: Gradio progress tracker for feedback.
529
+ # # :return: Deduplicated indices and a mapping of removed indices to their original counterparts.
530
+ # # """
531
+ # # if embeddings_b is None:
532
+ # # reach = Reach(vectors=embeddings_a, items=[str(i) for i in range(len(embeddings_a))])
533
+ # # duplicate_to_original = {}
534
+ # # results = reach.nearest_neighbor_threshold(
535
+ # # embeddings_a, threshold=threshold, batch_size=batch_size, show_progressbar=False
536
+ # # )
537
+ # # for i, similar_items in enumerate(progress.tqdm(results, desc="Processing duplicates", total=len(embeddings_a))):
538
+ # # for sim_idx, _ in similar_items:
539
+ # # sim_idx = int(sim_idx)
540
+ # # if sim_idx != i and sim_idx not in duplicate_to_original:
541
+ # # duplicate_to_original[sim_idx] = i
542
+ # # deduplicated_indices = set(range(len(embeddings_a))) - set(duplicate_to_original.keys())
543
+ # # return deduplicated_indices, duplicate_to_original
544
+ # # else:
545
+ # # reach = Reach(vectors=embeddings_a, items=[str(i) for i in range(len(embeddings_a))])
546
+ # # duplicate_indices_in_b = []
547
+ # # duplicate_to_original = {}
548
+ # # results = reach.nearest_neighbor_threshold(
549
+ # # embeddings_b, threshold=threshold, batch_size=batch_size, show_progressbar=False
550
+ # # )
551
+ # # for i, similar_items in enumerate(progress.tqdm(results, desc="Processing duplicates", total=len(embeddings_b))):
552
+ # # if similar_items:
553
+ # # duplicate_indices_in_b.append(i)
554
+ # # duplicate_to_original[i] = int(similar_items[0][0])
555
+ # # return duplicate_indices_in_b, duplicate_to_original
556
+
557
+ # # def display_word_differences(x: str, y: str) -> str:
558
+ # # """
559
+ # # Display the word-level differences between two texts, formatted to avoid
560
+ # # misinterpretation of Markdown syntax.
561
+
562
+ # # :param x: First text.
563
+ # # :param y: Second text.
564
+ # # :return: A string showing word-level differences, wrapped in a code block.
565
+ # # """
566
+ # # diff = ndiff(x.split(), y.split())
567
+ # # formatted_diff = "\n".join(word for word in diff if word.startswith(("+", "-")))
568
+ # # return f"```\n{formatted_diff}\n```"
569
+
570
+ # # def load_dataset_texts(dataset_name: str, dataset_split: str, text_column: str) -> list[str]:
571
+ # # """
572
+ # # Load texts from a specified dataset and split.
573
+
574
+ # # :param dataset_name: Name of the dataset.
575
+ # # :param dataset_split: Split of the dataset (e.g., 'train', 'validation').
576
+ # # :param text_column: Name of the text column.
577
+ # # :return: A list of texts from the dataset.
578
+ # # """
579
+ # # ds = load_dataset(dataset_name, split=dataset_split)
580
+ # # return [example[text_column] for example in ds]
581
+
582
+ # # def perform_deduplication(
583
+ # # deduplication_type: str,
584
+ # # dataset1_name: str,
585
+ # # dataset1_split: str,
586
+ # # dataset1_text_column: str,
587
+ # # dataset2_name: str = "",
588
+ # # dataset2_split: str = "",
589
+ # # dataset2_text_column: str = "",
590
+ # # threshold: float = default_threshold,
591
+ # # progress: gr.Progress = gr.Progress(track_tqdm=True)
592
+ # # ):
593
+ # # """
594
+ # # Perform deduplication on one or two datasets based on the deduplication type.
595
+
596
+ # # :param deduplication_type: 'Single dataset' or 'Cross-dataset'.
597
+ # # :param dataset1_name: Name of the first dataset.
598
+ # # :param dataset1_split: Split of the first dataset.
599
+ # # :param dataset1_text_column: Text column of the first dataset.
600
+ # # :param dataset2_name: Optional, name of the second dataset (for cross-dataset deduplication).
601
+ # # :param dataset2_split: Optional, split of the second dataset.
602
+ # # :param dataset2_text_column: Optional, text column of the second dataset.
603
+ # # :param threshold: Similarity threshold for deduplication.
604
+ # # :param progress: Gradio progress tracker.
605
+ # # :return: Status updates and result text for the Gradio interface.
606
+ # # """
607
+ # # try:
608
+ # # threshold = float(threshold)
609
+
610
+ # # # Load and process Dataset 1
611
+ # # yield "Loading Dataset 1...", ""
612
+ # # texts1 = load_dataset_texts(dataset1_name, dataset1_split, dataset1_text_column)
613
+ # # yield "Computing embeddings for Dataset 1...", ""
614
+ # # embeddings1 = model.encode(texts1, show_progressbar=True)
615
+
616
+ # # if deduplication_type == "Single dataset":
617
+ # # # Deduplicate within Dataset 1
618
+ # # yield "Deduplicating within Dataset 1...", ""
619
+ # # deduplicated_indices, duplicate_mapping = deduplicate_embeddings(
620
+ # # embeddings1, threshold=threshold, progress=progress
621
+ # # )
622
+
623
+ # # num_duplicates = len(duplicate_mapping)
624
+ # # result_text = (
625
+ # # f"**Total documents:** {len(texts1)}\n\n"
626
+ # # f"**Duplicates found:** {num_duplicates}\n\n"
627
+ # # f"**Unique documents after deduplication:** {len(deduplicated_indices)}\n\n"
628
+ # # )
629
+
630
+ # # if num_duplicates > 0:
631
+ # # result_text += "**Sample duplicates:**\n\n"
632
+ # # for dup_idx, orig_idx in list(duplicate_mapping.items())[:5]:
633
+ # # orig_text = texts1[orig_idx]
634
+ # # dup_text = texts1[dup_idx]
635
+ # # differences = display_word_differences(orig_text, dup_text)
636
+ # # result_text += (
637
+ # # f"**Original:**\n{orig_text}\n\n"
638
+ # # f"**Duplicate:**\n{dup_text}\n\n"
639
+ # # f"**Differences:**\n{differences}\n"
640
+ # # + "-" * 50 + "\n\n"
641
+ # # )
642
+ # # else:
643
+ # # result_text += "No duplicates found."
644
+
645
+ # # yield "Deduplication completed.", result_text
646
+
647
+ # # else:
648
+ # # # Load and process Dataset 2
649
+ # # yield "Loading Dataset 2...", ""
650
+ # # texts2 = load_dataset_texts(dataset2_name, dataset2_split, dataset2_text_column)
651
+ # # yield "Computing embeddings for Dataset 2...", ""
652
+ # # embeddings2 = model.encode(texts2, show_progressbar=True)
653
+
654
+ # # # Deduplicate Dataset 2 against Dataset 1
655
+ # # yield "Deduplicating Dataset 2 against Dataset 1...", ""
656
+ # # duplicate_indices, duplicate_mapping = deduplicate_embeddings(
657
+ # # embeddings1, embeddings_b=embeddings2, threshold=threshold, progress=progress
658
+ # # )
659
+
660
+ # # num_duplicates = len(duplicate_indices)
661
+ # # result_text = (
662
+ # # f"**Total documents in {dataset2_name}/{dataset2_split}:** {len(texts2)}\n\n"
663
+ # # f"**Duplicates found in Dataset 2:** {num_duplicates}\n\n"
664
+ # # f"**Unique documents after deduplication:** {len(texts2) - num_duplicates}\n\n"
665
+ # # )
666
+
667
+ # # if num_duplicates > 0:
668
+ # # result_text += "**Sample duplicates from Dataset 2:**\n\n"
669
+ # # for idx in duplicate_indices[:5]:
670
+ # # orig_text = texts1[duplicate_mapping[idx]]
671
+ # # dup_text = texts2[idx]
672
+ # # differences = display_word_differences(orig_text, dup_text)
673
+ # # result_text += (
674
+ # # f"**Original (Dataset 1):**\n{orig_text}\n\n"
675
+ # # f"**Duplicate (Dataset 2):**\n{dup_text}\n\n"
676
+ # # f"**Differences:**\n{differences}\n"
677
+ # # + "-" * 50 + "\n\n"
678
+ # # )
679
+ # # else:
680
+ # # result_text += "No duplicates found."
681
+
682
+ # # yield "Deduplication completed.", result_text
683
+
684
+ # # except Exception as e:
685
+ # # yield f"An error occurred: {e}", ""
686
+ # # raise e
687
+
688
+ # # # Gradio app with stop button support
689
+ # # with gr.Blocks(css="#status_output { height: 50px; overflow: auto; }") as demo:
690
+ # # gr.Markdown("# Semantic Deduplication")
691
+ # # gr.Markdown("""
692
+ # # This demo showcases semantic deduplication using Model2Vec for HuggingFace datasets.
693
+ # # It can be used to identify duplicate texts within a single dataset or across two datasets.
694
+ # # You can adjust the similarity threshold to control the strictness of the deduplication.\n
695
+ # # NOTE: this demo runs on a free CPU backend, so it may be slow for large datasets. For faster results, please run the code locally.
696
+ # # """)
697
+
698
+ # # deduplication_type = gr.Radio(
699
+ # # choices=["Single dataset", "Cross-dataset"],
700
+ # # label="Deduplication Type",
701
+ # # value="Single dataset",
702
+ # # )
703
+
704
+ # # with gr.Row():
705
+ # # dataset1_name = gr.Textbox(value=default_dataset_name, label="Dataset 1 Name")
706
+ # # dataset1_split = gr.Textbox(value=default_dataset_split, label="Dataset 1 Split")
707
+ # # dataset1_text_column = gr.Textbox(value=default_text_column, label="Text Column Name")
708
+
709
+ # # dataset2_inputs = gr.Column(visible=False)
710
+ # # with dataset2_inputs:
711
+ # # gr.Markdown("### Dataset 2")
712
+ # # with gr.Row():
713
+ # # dataset2_name = gr.Textbox(value=default_dataset_name, label="Dataset 2 Name")
714
+ # # dataset2_split = gr.Textbox(value=default_dataset_split, label="Dataset 2 Split")
715
+ # # dataset2_text_column = gr.Textbox(value=default_text_column, label="Text Column Name")
716
+
717
+ # # threshold = gr.Slider(0.0, 1.0, value=default_threshold, label="Similarity Threshold")
718
+ # # compute_button = gr.Button("Deduplicate")
719
+ # # status_output = gr.Markdown(elem_id="status_output")
720
+ # # result_output = gr.Markdown()
721
+
722
+ # # def update_visibility(choice: str):
723
+ # # return gr.update(visible=choice == "Cross-dataset")
724
+
725
+ # # deduplication_type.change(update_visibility, inputs=deduplication_type, outputs=dataset2_inputs)
726
+
727
+ # # compute_button.click(
728
+ # # fn=perform_deduplication,
729
+ # # inputs=[
730
+ # # deduplication_type,
731
+ # # dataset1_name,
732
+ # # dataset1_split,
733
+ # # dataset1_text_column,
734
+ # # dataset2_name,
735
+ # # dataset2_split,
736
+ # # dataset2_text_column,
737
+ # # threshold,
738
+ # # ],
739
+ # # outputs=[status_output, result_output],
740
+ # # )
741
+
742
+
743
+ # # demo.launch()