Commit 
							
							·
						
						cd47483
	
1
								Parent(s):
							
							0202688
								
Add support for custom BASE_URL, MODEL, API_KEY
Browse files
- README.md +7 -1
- app.py +5 -5
- pyproject.toml +12 -4
- src/distilabel_dataset_generator/__init__.py +0 -26
- src/distilabel_dataset_generator/apps/__init__.py +0 -0
- src/distilabel_dataset_generator/apps/base.py +1 -1
- src/distilabel_dataset_generator/apps/eval.py +5 -7
- src/distilabel_dataset_generator/apps/sft.py +169 -164
- src/distilabel_dataset_generator/apps/textcat.py +1 -3
- src/distilabel_dataset_generator/constants.py +55 -0
- src/distilabel_dataset_generator/pipelines/__init__.py +0 -0
- src/distilabel_dataset_generator/pipelines/base.py +2 -4
- src/distilabel_dataset_generator/pipelines/embeddings.py +1 -1
- src/distilabel_dataset_generator/pipelines/eval.py +15 -14
- src/distilabel_dataset_generator/pipelines/sft.py +15 -6
- src/distilabel_dataset_generator/pipelines/textcat.py +13 -14
- src/distilabel_dataset_generator/utils.py +1 -1
    	
        README.md
    CHANGED
    
    | @@ -80,7 +80,13 @@ pip install synthetic-dataset-generator | |
| 80 |  | 
| 81 | 
             
            ### Environment Variables
         | 
| 82 |  | 
| 83 | 
            -
            - `HF_TOKEN`: Your Hugging Face token | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 84 |  | 
| 85 | 
             
            Optionally, you can also push your datasets to Argilla for further curation by setting the following environment variables:
         | 
| 86 |  | 
|  | |
| 80 |  | 
| 81 | 
             
            ### Environment Variables
         | 
| 82 |  | 
| 83 | 
            +
            - `HF_TOKEN`: Your [Hugging Face token](https://huggingface.co/settings/tokens/new?ownUserPermissions=repo.content.read&ownUserPermissions=repo.write&globalPermissions=inference.serverless.write&tokenType=fineGrained) to push your datasets to the Hugging Face Hub and generate free completions from Hugging Face Inference Endpoints.
         | 
| 84 | 
            +
             | 
| 85 | 
            +
            Optionally, you can set the following environment variables to customize the generation process.
         | 
| 86 | 
            +
             | 
| 87 | 
            +
            - `BASE_URL`: The base URL for any OpenAI-compatible API, e.g. `https://api-inference.huggingface.co/v1/`.
         | 
| 88 | 
            +
            - `MODEL`: The model to use for generating the dataset, e.g. `meta-llama/Meta-Llama-3.1-8B-Instruct`.
         | 
| 89 | 
            +
            - `API_KEY`: The API key to use for the corresponding API, e.g. `hf_...`.
         | 
| 90 |  | 
| 91 | 
             
            Optionally, you can also push your datasets to Argilla for further curation by setting the following environment variables:
         | 
| 92 |  | 
    	
        app.py
    CHANGED
    
    | @@ -1,8 +1,8 @@ | |
| 1 | 
            -
            from  | 
| 2 | 
            -
            from  | 
| 3 | 
            -
            from  | 
| 4 | 
            -
            from  | 
| 5 | 
            -
            from  | 
| 6 |  | 
| 7 | 
             
            theme = "argilla/argilla-theme"
         | 
| 8 |  | 
|  | |
| 1 | 
            +
            from distilabel_dataset_generator._tabbedinterface import TabbedInterface
         | 
| 2 | 
            +
            from distilabel_dataset_generator.apps.eval import app as eval_app
         | 
| 3 | 
            +
            from distilabel_dataset_generator.apps.faq import app as faq_app
         | 
| 4 | 
            +
            from distilabel_dataset_generator.apps.sft import app as sft_app
         | 
| 5 | 
            +
            from distilabel_dataset_generator.apps.textcat import app as textcat_app
         | 
| 6 |  | 
| 7 | 
             
            theme = "argilla/argilla-theme"
         | 
| 8 |  | 
    	
        pyproject.toml
    CHANGED
    
    | @@ -5,6 +5,18 @@ description = "Build datasets using natural language" | |
| 5 | 
             
            authors = [
         | 
| 6 | 
             
                {name = "davidberenstein1957", email = "[email protected]"},
         | 
| 7 | 
             
            ]
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 8 | 
             
            dependencies = [
         | 
| 9 | 
             
                "distilabel[hf-inference-endpoints,argilla,outlines,instructor]>=1.4.1",
         | 
| 10 | 
             
                "gradio[oauth]<5.0.0",
         | 
| @@ -14,14 +26,10 @@ dependencies = [ | |
| 14 | 
             
                "gradio-huggingfacehub-search>=0.0.7",
         | 
| 15 | 
             
                "argilla>=2.4.0",
         | 
| 16 | 
             
            ]
         | 
| 17 | 
            -
            requires-python = "<3.13,>=3.10"
         | 
| 18 | 
            -
            readme = "README.md"
         | 
| 19 | 
            -
            license = {text = "apache 2"}
         | 
| 20 |  | 
| 21 | 
             
            [build-system]
         | 
| 22 | 
             
            requires = ["pdm-backend"]
         | 
| 23 | 
             
            build-backend = "pdm.backend"
         | 
| 24 |  | 
| 25 | 
            -
             | 
| 26 | 
             
            [tool.pdm]
         | 
| 27 | 
             
            distribution = true
         | 
|  | |
| 5 | 
             
            authors = [
         | 
| 6 | 
             
                {name = "davidberenstein1957", email = "[email protected]"},
         | 
| 7 | 
             
            ]
         | 
| 8 | 
            +
            tags = [
         | 
| 9 | 
            +
                "gradio",
         | 
| 10 | 
            +
                "synthetic-data",
         | 
| 11 | 
            +
                "huggingface",
         | 
| 12 | 
            +
                "argilla",
         | 
| 13 | 
            +
                "generative-ai",
         | 
| 14 | 
            +
                "ai",
         | 
| 15 | 
            +
            ]
         | 
| 16 | 
            +
            requires-python = "<3.13,>=3.10"
         | 
| 17 | 
            +
            readme = "README.md"
         | 
| 18 | 
            +
            license = {text = "Apache 2"}
         | 
| 19 | 
            +
             | 
| 20 | 
             
            dependencies = [
         | 
| 21 | 
             
                "distilabel[hf-inference-endpoints,argilla,outlines,instructor]>=1.4.1",
         | 
| 22 | 
             
                "gradio[oauth]<5.0.0",
         | 
|  | |
| 26 | 
             
                "gradio-huggingfacehub-search>=0.0.7",
         | 
| 27 | 
             
                "argilla>=2.4.0",
         | 
| 28 | 
             
            ]
         | 
|  | |
|  | |
|  | |
| 29 |  | 
| 30 | 
             
            [build-system]
         | 
| 31 | 
             
            requires = ["pdm-backend"]
         | 
| 32 | 
             
            build-backend = "pdm.backend"
         | 
| 33 |  | 
|  | |
| 34 | 
             
            [tool.pdm]
         | 
| 35 | 
             
            distribution = true
         | 
    	
        src/distilabel_dataset_generator/__init__.py
    CHANGED
    
    | @@ -1,8 +1,5 @@ | |
| 1 | 
            -
            import os
         | 
| 2 | 
            -
            import warnings
         | 
| 3 | 
             
            from typing import Optional
         | 
| 4 |  | 
| 5 | 
            -
            import argilla as rg
         | 
| 6 | 
             
            import distilabel
         | 
| 7 | 
             
            import distilabel.distiset
         | 
| 8 | 
             
            from distilabel.utils.card.dataset_card import (
         | 
| @@ -11,29 +8,6 @@ from distilabel.utils.card.dataset_card import ( | |
| 11 | 
             
            )
         | 
| 12 | 
             
            from huggingface_hub import DatasetCardData, HfApi
         | 
| 13 |  | 
| 14 | 
            -
            HF_TOKENS = [os.getenv("HF_TOKEN")] + [os.getenv(f"HF_TOKEN_{i}") for i in range(1, 10)]
         | 
| 15 | 
            -
            HF_TOKENS = [token for token in HF_TOKENS if token]
         | 
| 16 | 
            -
             | 
| 17 | 
            -
            if len(HF_TOKENS) == 0:
         | 
| 18 | 
            -
                raise ValueError(
         | 
| 19 | 
            -
                    "HF_TOKEN is not set. Ensure you have set the HF_TOKEN environment variable that has access to the Hugging Face Hub repositories and Inference Endpoints."
         | 
| 20 | 
            -
                )
         | 
| 21 | 
            -
             | 
| 22 | 
            -
            ARGILLA_API_URL = os.getenv("ARGILLA_API_URL")
         | 
| 23 | 
            -
            ARGILLA_API_KEY = os.getenv("ARGILLA_API_KEY")
         | 
| 24 | 
            -
            if ARGILLA_API_URL is None or ARGILLA_API_KEY is None:
         | 
| 25 | 
            -
                ARGILLA_API_URL = os.getenv("ARGILLA_API_URL_SDG_REVIEWER")
         | 
| 26 | 
            -
                ARGILLA_API_KEY = os.getenv("ARGILLA_API_KEY_SDG_REVIEWER")
         | 
| 27 | 
            -
             | 
| 28 | 
            -
            if ARGILLA_API_URL is None or ARGILLA_API_KEY is None:
         | 
| 29 | 
            -
                warnings.warn("ARGILLA_API_URL or ARGILLA_API_KEY is not set")
         | 
| 30 | 
            -
                argilla_client = None
         | 
| 31 | 
            -
            else:
         | 
| 32 | 
            -
                argilla_client = rg.Argilla(
         | 
| 33 | 
            -
                    api_url=ARGILLA_API_URL,
         | 
| 34 | 
            -
                    api_key=ARGILLA_API_KEY,
         | 
| 35 | 
            -
                )
         | 
| 36 | 
            -
             | 
| 37 |  | 
| 38 | 
             
            class CustomDistisetWithAdditionalTag(distilabel.distiset.Distiset):
         | 
| 39 | 
             
                def _generate_card(
         | 
|  | |
|  | |
|  | |
| 1 | 
             
            from typing import Optional
         | 
| 2 |  | 
|  | |
| 3 | 
             
            import distilabel
         | 
| 4 | 
             
            import distilabel.distiset
         | 
| 5 | 
             
            from distilabel.utils.card.dataset_card import (
         | 
|  | |
| 8 | 
             
            )
         | 
| 9 | 
             
            from huggingface_hub import DatasetCardData, HfApi
         | 
| 10 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 11 |  | 
| 12 | 
             
            class CustomDistisetWithAdditionalTag(distilabel.distiset.Distiset):
         | 
| 13 | 
             
                def _generate_card(
         | 
    	
        src/distilabel_dataset_generator/apps/__init__.py
    ADDED
    
    | 
            File without changes
         | 
    	
        src/distilabel_dataset_generator/apps/base.py
    CHANGED
    
    | @@ -10,7 +10,7 @@ from distilabel.distiset import Distiset | |
| 10 | 
             
            from gradio import OAuthToken
         | 
| 11 | 
             
            from huggingface_hub import HfApi, upload_file
         | 
| 12 |  | 
| 13 | 
            -
            from  | 
| 14 | 
             
                _LOGGED_OUT_CSS,
         | 
| 15 | 
             
                get_argilla_client,
         | 
| 16 | 
             
                get_login_button,
         | 
|  | |
| 10 | 
             
            from gradio import OAuthToken
         | 
| 11 | 
             
            from huggingface_hub import HfApi, upload_file
         | 
| 12 |  | 
| 13 | 
            +
            from distilabel_dataset_generator.utils import (
         | 
| 14 | 
             
                _LOGGED_OUT_CSS,
         | 
| 15 | 
             
                get_argilla_client,
         | 
| 16 | 
             
                get_login_button,
         | 
    	
        src/distilabel_dataset_generator/apps/eval.py
    CHANGED
    
    | @@ -16,25 +16,23 @@ from distilabel.distiset import Distiset | |
| 16 | 
             
            from gradio_huggingfacehub_search import HuggingfaceHubSearch
         | 
| 17 | 
             
            from huggingface_hub import HfApi
         | 
| 18 |  | 
| 19 | 
            -
            from  | 
| 20 | 
             
                hide_success_message,
         | 
| 21 | 
             
                show_success_message,
         | 
| 22 | 
             
                validate_argilla_user_workspace_dataset,
         | 
| 23 | 
             
                validate_push_to_hub,
         | 
| 24 | 
             
            )
         | 
| 25 | 
            -
            from  | 
| 26 | 
            -
             | 
| 27 | 
            -
            )
         | 
| 28 | 
            -
            from src.distilabel_dataset_generator.pipelines.embeddings import (
         | 
| 29 | 
             
                get_embeddings,
         | 
| 30 | 
             
                get_sentence_embedding_dimensions,
         | 
| 31 | 
             
            )
         | 
| 32 | 
            -
            from  | 
| 33 | 
             
                generate_pipeline_code,
         | 
| 34 | 
             
                get_custom_evaluator,
         | 
| 35 | 
             
                get_ultrafeedback_evaluator,
         | 
| 36 | 
             
            )
         | 
| 37 | 
            -
            from  | 
| 38 | 
             
                column_to_list,
         | 
| 39 | 
             
                extract_column_names,
         | 
| 40 | 
             
                get_argilla_client,
         | 
|  | |
| 16 | 
             
            from gradio_huggingfacehub_search import HuggingfaceHubSearch
         | 
| 17 | 
             
            from huggingface_hub import HfApi
         | 
| 18 |  | 
| 19 | 
            +
            from distilabel_dataset_generator.apps.base import (
         | 
| 20 | 
             
                hide_success_message,
         | 
| 21 | 
             
                show_success_message,
         | 
| 22 | 
             
                validate_argilla_user_workspace_dataset,
         | 
| 23 | 
             
                validate_push_to_hub,
         | 
| 24 | 
             
            )
         | 
| 25 | 
            +
            from distilabel_dataset_generator.constants import DEFAULT_BATCH_SIZE
         | 
| 26 | 
            +
            from distilabel_dataset_generator.pipelines.embeddings import (
         | 
|  | |
|  | |
| 27 | 
             
                get_embeddings,
         | 
| 28 | 
             
                get_sentence_embedding_dimensions,
         | 
| 29 | 
             
            )
         | 
| 30 | 
            +
            from distilabel_dataset_generator.pipelines.eval import (
         | 
| 31 | 
             
                generate_pipeline_code,
         | 
| 32 | 
             
                get_custom_evaluator,
         | 
| 33 | 
             
                get_ultrafeedback_evaluator,
         | 
| 34 | 
             
            )
         | 
| 35 | 
            +
            from distilabel_dataset_generator.utils import (
         | 
| 36 | 
             
                column_to_list,
         | 
| 37 | 
             
                extract_column_names,
         | 
| 38 | 
             
                get_argilla_client,
         | 
    	
        src/distilabel_dataset_generator/apps/sft.py
    CHANGED
    
    | @@ -9,27 +9,25 @@ from datasets import Dataset | |
| 9 | 
             
            from distilabel.distiset import Distiset
         | 
| 10 | 
             
            from huggingface_hub import HfApi
         | 
| 11 |  | 
| 12 | 
            -
            from  | 
| 13 | 
             
                hide_success_message,
         | 
| 14 | 
             
                show_success_message,
         | 
| 15 | 
             
                validate_argilla_user_workspace_dataset,
         | 
| 16 | 
             
                validate_push_to_hub,
         | 
| 17 | 
             
            )
         | 
| 18 | 
            -
            from  | 
| 19 | 
            -
             | 
| 20 | 
            -
            )
         | 
| 21 | 
            -
            from src.distilabel_dataset_generator.pipelines.embeddings import (
         | 
| 22 | 
             
                get_embeddings,
         | 
| 23 | 
             
                get_sentence_embedding_dimensions,
         | 
| 24 | 
             
            )
         | 
| 25 | 
            -
            from  | 
| 26 | 
             
                DEFAULT_DATASET_DESCRIPTIONS,
         | 
| 27 | 
             
                generate_pipeline_code,
         | 
| 28 | 
             
                get_magpie_generator,
         | 
| 29 | 
             
                get_prompt_generator,
         | 
| 30 | 
             
                get_response_generator,
         | 
| 31 | 
             
            )
         | 
| 32 | 
            -
            from  | 
| 33 | 
             
                _LOGGED_OUT_CSS,
         | 
| 34 | 
             
                get_argilla_client,
         | 
| 35 | 
             
                get_org_dropdown,
         | 
| @@ -354,168 +352,175 @@ def hide_pipeline_code_visibility(): | |
| 354 |  | 
| 355 | 
             
            with gr.Blocks(css=_LOGGED_OUT_CSS) as app:
         | 
| 356 | 
             
                with gr.Column() as main_ui:
         | 
| 357 | 
            -
                     | 
| 358 | 
            -
             | 
| 359 | 
            -
             | 
| 360 | 
            -
             | 
| 361 | 
            -
             | 
| 362 | 
            -
             | 
| 363 | 
            -
             | 
| 364 | 
            -
                            with gr. | 
| 365 | 
            -
                                 | 
| 366 | 
            -
                                     | 
| 367 | 
            -
                                     | 
| 368 | 
            -
             | 
| 369 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 370 | 
             
                                    interactive=True,
         | 
| 371 | 
            -
                                     | 
| 372 | 
             
                                )
         | 
| 373 | 
            -
             | 
| 374 | 
            -
             | 
| 375 | 
            -
                                variant="primary",
         | 
| 376 | 
            -
                            )
         | 
| 377 | 
            -
                        with gr.Column(scale=2):
         | 
| 378 | 
            -
                            examples = gr.Examples(
         | 
| 379 | 
            -
                                examples=DEFAULT_DATASET_DESCRIPTIONS,
         | 
| 380 | 
            -
                                inputs=[dataset_description],
         | 
| 381 | 
            -
                                cache_examples=False,
         | 
| 382 | 
            -
                                label="Examples",
         | 
| 383 | 
            -
                            )
         | 
| 384 | 
            -
                        with gr.Column(scale=1):
         | 
| 385 | 
            -
                            pass
         | 
| 386 | 
            -
             | 
| 387 | 
            -
                    gr.HTML(value="<hr>")
         | 
| 388 | 
            -
                    gr.Markdown(value="## 2. Configure your dataset")
         | 
| 389 | 
            -
                    with gr.Row(equal_height=False):
         | 
| 390 | 
            -
                        with gr.Column(scale=2):
         | 
| 391 | 
            -
                            system_prompt = gr.Textbox(
         | 
| 392 | 
            -
                                label="System prompt",
         | 
| 393 | 
            -
                                placeholder="You are a helpful assistant.",
         | 
| 394 | 
            -
                            )
         | 
| 395 | 
            -
                            num_turns = gr.Number(
         | 
| 396 | 
            -
                                value=1,
         | 
| 397 | 
            -
                                label="Number of turns in the conversation",
         | 
| 398 | 
            -
                                minimum=1,
         | 
| 399 | 
            -
                                maximum=4,
         | 
| 400 | 
            -
                                step=1,
         | 
| 401 | 
            -
                                interactive=True,
         | 
| 402 | 
            -
                                info="Choose between 1 (single turn with 'instruction-response' columns) and 2-4 (multi-turn conversation with a 'messages' column).",
         | 
| 403 | 
            -
                            )
         | 
| 404 | 
            -
                            btn_apply_to_sample_dataset = gr.Button(
         | 
| 405 | 
            -
                                "Refresh dataset", variant="secondary"
         | 
| 406 | 
            -
                            )
         | 
| 407 | 
            -
                        with gr.Column(scale=3):
         | 
| 408 | 
            -
                            dataframe = gr.Dataframe(
         | 
| 409 | 
            -
                                headers=["prompt", "completion"],
         | 
| 410 | 
            -
                                wrap=True,
         | 
| 411 | 
            -
                                height=500,
         | 
| 412 | 
            -
                                interactive=False,
         | 
| 413 | 
            -
                            )
         | 
| 414 | 
            -
             | 
| 415 | 
            -
                    gr.HTML(value="<hr>")
         | 
| 416 | 
            -
                    gr.Markdown(value="## 3. Generate your dataset")
         | 
| 417 | 
            -
                    with gr.Row(equal_height=False):
         | 
| 418 | 
            -
                        with gr.Column(scale=2):
         | 
| 419 | 
            -
                            org_name = get_org_dropdown()
         | 
| 420 | 
            -
                            repo_name = gr.Textbox(
         | 
| 421 | 
            -
                                label="Repo name",
         | 
| 422 | 
            -
                                placeholder="dataset_name",
         | 
| 423 | 
            -
                                value=f"my-distiset-{str(uuid.uuid4())[:8]}",
         | 
| 424 | 
            -
                                interactive=True,
         | 
| 425 | 
            -
                            )
         | 
| 426 | 
            -
                            num_rows = gr.Number(
         | 
| 427 | 
            -
                                label="Number of rows",
         | 
| 428 | 
            -
                                value=10,
         | 
| 429 | 
            -
                                interactive=True,
         | 
| 430 | 
            -
                                scale=1,
         | 
| 431 | 
            -
                            )
         | 
| 432 | 
            -
                            private = gr.Checkbox(
         | 
| 433 | 
            -
                                label="Private dataset",
         | 
| 434 | 
            -
                                value=False,
         | 
| 435 | 
            -
                                interactive=True,
         | 
| 436 | 
            -
                                scale=1,
         | 
| 437 | 
            -
                            )
         | 
| 438 | 
            -
                            btn_push_to_hub = gr.Button("Push to Hub", variant="primary", scale=2)
         | 
| 439 | 
            -
                        with gr.Column(scale=3):
         | 
| 440 | 
            -
                            success_message = gr.Markdown(visible=True)
         | 
| 441 | 
            -
                            with gr.Accordion(
         | 
| 442 | 
            -
                                "Do you want to go further? Customize and run with Distilabel",
         | 
| 443 | 
            -
                                open=False,
         | 
| 444 | 
            -
                                visible=False,
         | 
| 445 | 
            -
                            ) as pipeline_code_ui:
         | 
| 446 | 
            -
                                code = generate_pipeline_code(
         | 
| 447 | 
            -
                                    system_prompt=system_prompt.value,
         | 
| 448 | 
            -
                                    num_turns=num_turns.value,
         | 
| 449 | 
            -
                                    num_rows=num_rows.value,
         | 
| 450 | 
             
                                )
         | 
| 451 | 
            -
             | 
| 452 | 
            -
             | 
| 453 | 
            -
                                     | 
| 454 | 
            -
                                     | 
|  | |
|  | |
| 455 | 
             
                                )
         | 
| 456 |  | 
| 457 | 
            -
             | 
| 458 | 
            -
             | 
| 459 | 
            -
             | 
| 460 | 
            -
             | 
| 461 | 
            -
             | 
| 462 | 
            -
             | 
| 463 | 
            -
             | 
| 464 | 
            -
             | 
| 465 | 
            -
             | 
| 466 | 
            -
             | 
| 467 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 468 |  | 
| 469 | 
            -
             | 
| 470 | 
            -
             | 
| 471 | 
            -
             | 
| 472 | 
            -
             | 
| 473 | 
            -
             | 
| 474 | 
            -
             | 
| 475 |  | 
| 476 | 
            -
             | 
| 477 | 
            -
             | 
| 478 | 
            -
             | 
| 479 | 
            -
             | 
| 480 | 
            -
             | 
| 481 | 
            -
             | 
| 482 | 
            -
             | 
| 483 | 
            -
             | 
| 484 | 
            -
             | 
| 485 | 
            -
             | 
| 486 | 
            -
             | 
| 487 | 
            -
             | 
| 488 | 
            -
             | 
| 489 | 
            -
             | 
| 490 | 
            -
             | 
| 491 | 
            -
             | 
| 492 | 
            -
             | 
| 493 | 
            -
             | 
| 494 | 
            -
             | 
| 495 | 
            -
             | 
| 496 | 
            -
             | 
| 497 | 
            -
             | 
| 498 | 
            -
             | 
| 499 | 
            -
             | 
| 500 | 
            -
             | 
| 501 | 
            -
             | 
| 502 | 
            -
             | 
| 503 | 
            -
             | 
| 504 | 
            -
             | 
| 505 | 
            -
             | 
| 506 | 
            -
             | 
| 507 | 
            -
             | 
| 508 | 
            -
             | 
| 509 | 
            -
             | 
| 510 | 
            -
             | 
| 511 | 
            -
             | 
| 512 | 
            -
             | 
| 513 | 
            -
             | 
| 514 | 
            -
             | 
| 515 | 
            -
             | 
| 516 | 
            -
             | 
| 517 | 
            -
             | 
| 518 | 
            -
             | 
| 519 |  | 
| 520 | 
            -
             | 
| 521 | 
            -
             | 
|  | |
| 9 | 
             
            from distilabel.distiset import Distiset
         | 
| 10 | 
             
            from huggingface_hub import HfApi
         | 
| 11 |  | 
| 12 | 
            +
            from distilabel_dataset_generator.apps.base import (
         | 
| 13 | 
             
                hide_success_message,
         | 
| 14 | 
             
                show_success_message,
         | 
| 15 | 
             
                validate_argilla_user_workspace_dataset,
         | 
| 16 | 
             
                validate_push_to_hub,
         | 
| 17 | 
             
            )
         | 
| 18 | 
            +
            from distilabel_dataset_generator.constants import DEFAULT_BATCH_SIZE, SFT_AVAILABLE
         | 
| 19 | 
            +
            from distilabel_dataset_generator.pipelines.embeddings import (
         | 
|  | |
|  | |
| 20 | 
             
                get_embeddings,
         | 
| 21 | 
             
                get_sentence_embedding_dimensions,
         | 
| 22 | 
             
            )
         | 
| 23 | 
            +
            from distilabel_dataset_generator.pipelines.sft import (
         | 
| 24 | 
             
                DEFAULT_DATASET_DESCRIPTIONS,
         | 
| 25 | 
             
                generate_pipeline_code,
         | 
| 26 | 
             
                get_magpie_generator,
         | 
| 27 | 
             
                get_prompt_generator,
         | 
| 28 | 
             
                get_response_generator,
         | 
| 29 | 
             
            )
         | 
| 30 | 
            +
            from distilabel_dataset_generator.utils import (
         | 
| 31 | 
             
                _LOGGED_OUT_CSS,
         | 
| 32 | 
             
                get_argilla_client,
         | 
| 33 | 
             
                get_org_dropdown,
         | 
|  | |
| 352 |  | 
# Supervised fine-tuning (SFT) tab.
#
# The UI is assembled at import time into the module-level ``app`` Blocks
# object. When ``SFT_AVAILABLE`` is False only a notice is rendered and none of
# the input widgets exist, so every event listener that references those
# widgets is registered inside the ``else`` branch. Registering them
# unconditionally would raise ``NameError`` during import.
with gr.Blocks(css=_LOGGED_OUT_CSS) as app:
    with gr.Column() as main_ui:
        if not SFT_AVAILABLE:
            gr.Markdown(
                value=f"## Supervised Fine-Tuning is not available for the {MODEL} model. Use Hugging Face Llama3 or Qwen2 models."
            )
        else:
            # --- Step 1: free-text description of the desired dataset -------
            gr.Markdown(value="## 1. Describe the dataset you want")
            with gr.Row():
                with gr.Column(scale=2):
                    dataset_description = gr.Textbox(
                        label="Dataset description",
                        placeholder="Give a precise description of your desired dataset.",
                    )
                    with gr.Accordion("Temperature", open=False):
                        temperature = gr.Slider(
                            minimum=0.1,
                            maximum=1,
                            value=0.8,
                            step=0.1,
                            interactive=True,
                            show_label=False,
                        )
                    load_btn = gr.Button(
                        "Create dataset",
                        variant="primary",
                    )
                with gr.Column(scale=2):
                    examples = gr.Examples(
                        examples=DEFAULT_DATASET_DESCRIPTIONS,
                        inputs=[dataset_description],
                        cache_examples=False,
                        label="Examples",
                    )
                with gr.Column(scale=1):
                    # Intentionally empty column; it only takes up row width.
                    pass

            # --- Step 2: system prompt / turns and a sample preview ---------
            gr.HTML(value="<hr>")
            gr.Markdown(value="## 2. Configure your dataset")
            with gr.Row(equal_height=False):
                with gr.Column(scale=2):
                    system_prompt = gr.Textbox(
                        label="System prompt",
                        placeholder="You are a helpful assistant.",
                    )
                    num_turns = gr.Number(
                        value=1,
                        label="Number of turns in the conversation",
                        minimum=1,
                        maximum=4,
                        step=1,
                        interactive=True,
                        info="Choose between 1 (single turn with 'instruction-response' columns) and 2-4 (multi-turn conversation with a 'messages' column).",
                    )
                    btn_apply_to_sample_dataset = gr.Button(
                        "Refresh dataset", variant="secondary"
                    )
                with gr.Column(scale=3):
                    dataframe = gr.Dataframe(
                        headers=["prompt", "completion"],
                        wrap=True,
                        height=500,
                        interactive=False,
                    )

            # --- Step 3: full generation and push to the Hub ----------------
            gr.HTML(value="<hr>")
            gr.Markdown(value="## 3. Generate your dataset")
            with gr.Row(equal_height=False):
                with gr.Column(scale=2):
                    org_name = get_org_dropdown()
                    repo_name = gr.Textbox(
                        label="Repo name",
                        placeholder="dataset_name",
                        # Random suffix so the default repo name differs
                        # between app starts.
                        value=f"my-distiset-{str(uuid.uuid4())[:8]}",
                        interactive=True,
                    )
                    num_rows = gr.Number(
                        label="Number of rows",
                        value=10,
                        interactive=True,
                        scale=1,
                    )
                    private = gr.Checkbox(
                        label="Private dataset",
                        value=False,
                        interactive=True,
                        scale=1,
                    )
                    btn_push_to_hub = gr.Button(
                        "Push to Hub", variant="primary", scale=2
                    )
                with gr.Column(scale=3):
                    success_message = gr.Markdown(visible=True)
                    # Created hidden (visible=False); the push chain below
                    # passes it as output to the hide/show visibility helpers.
                    with gr.Accordion(
                        "Do you want to go further? Customize and run with Distilabel",
                        open=False,
                        visible=False,
                    ) as pipeline_code_ui:
                        # Initial code is rendered from the widgets' default
                        # values; it is regenerated after a successful push.
                        code = generate_pipeline_code(
                            system_prompt=system_prompt.value,
                            num_turns=num_turns.value,
                            num_rows=num_rows.value,
                        )
                        pipeline_code = gr.Code(
                            value=code,
                            language="python",
                            label="Distilabel Pipeline Code",
                        )

            # --- Event wiring (requires the widgets created above) ----------
            # "Create dataset": generate a system prompt, then a sample.
            load_btn.click(
                fn=generate_system_prompt,
                inputs=[dataset_description, temperature],
                outputs=[system_prompt],
                show_progress=True,
            ).then(
                fn=generate_sample_dataset,
                inputs=[system_prompt, num_turns],
                outputs=[dataframe],
                show_progress=True,
            )

            # "Refresh dataset": regenerate the sample from the current
            # system prompt and number of turns.
            btn_apply_to_sample_dataset.click(
                fn=generate_sample_dataset,
                inputs=[system_prompt, num_turns],
                outputs=[dataframe],
                show_progress=True,
            )

            # "Push to Hub": run both validators, then (each step only if the
            # previous one succeeded) push the dataset and refresh the
            # success message and pipeline code.
            btn_push_to_hub.click(
                fn=validate_argilla_user_workspace_dataset,
                inputs=[repo_name],
                outputs=[success_message],
                show_progress=True,
            ).then(
                fn=validate_push_to_hub,
                inputs=[org_name, repo_name],
                outputs=[success_message],
                show_progress=True,
            ).success(
                fn=hide_success_message,
                outputs=[success_message],
                show_progress=True,
            ).success(
                fn=hide_pipeline_code_visibility,
                inputs=[],
                outputs=[pipeline_code_ui],
            ).success(
                fn=push_dataset,
                inputs=[
                    org_name,
                    repo_name,
                    system_prompt,
                    num_turns,
                    num_rows,
                    private,
                ],
                outputs=[success_message],
                show_progress=True,
            ).success(
                fn=show_success_message,
                inputs=[org_name, repo_name],
                outputs=[success_message],
            ).success(
                fn=generate_pipeline_code,
                inputs=[system_prompt, num_turns, num_rows],
                outputs=[pipeline_code],
            ).success(
                fn=show_pipeline_code_visibility,
                inputs=[],
                outputs=[pipeline_code_ui],
            )

        # Page-load hooks. The visibility swap only needs ``main_ui`` and is
        # always registered; the org dropdown refresh needs ``org_name``,
        # which exists only when the SFT widgets were built.
        app.load(fn=swap_visibility, outputs=main_ui)
        if SFT_AVAILABLE:
            app.load(fn=get_org_dropdown, outputs=[org_name])
    	
        src/distilabel_dataset_generator/apps/textcat.py
    CHANGED
    
    | @@ -9,15 +9,13 @@ from datasets import ClassLabel, Dataset, Features, Sequence, Value | |
| 9 | 
             
            from distilabel.distiset import Distiset
         | 
| 10 | 
             
            from huggingface_hub import HfApi
         | 
| 11 |  | 
|  | |
| 12 | 
             
            from src.distilabel_dataset_generator.apps.base import (
         | 
| 13 | 
             
                hide_success_message,
         | 
| 14 | 
             
                show_success_message,
         | 
| 15 | 
             
                validate_argilla_user_workspace_dataset,
         | 
| 16 | 
             
                validate_push_to_hub,
         | 
| 17 | 
             
            )
         | 
| 18 | 
            -
            from src.distilabel_dataset_generator.pipelines.base import (
         | 
| 19 | 
            -
                DEFAULT_BATCH_SIZE,
         | 
| 20 | 
            -
            )
         | 
| 21 | 
             
            from src.distilabel_dataset_generator.pipelines.embeddings import (
         | 
| 22 | 
             
                get_embeddings,
         | 
| 23 | 
             
                get_sentence_embedding_dimensions,
         | 
|  | |
| 9 | 
             
            from distilabel.distiset import Distiset
         | 
| 10 | 
             
            from huggingface_hub import HfApi
         | 
| 11 |  | 
| 12 | 
            +
            from distilabel_dataset_generator.constants import DEFAULT_BATCH_SIZE
         | 
| 13 | 
             
            from src.distilabel_dataset_generator.apps.base import (
         | 
| 14 | 
             
                hide_success_message,
         | 
| 15 | 
             
                show_success_message,
         | 
| 16 | 
             
                validate_argilla_user_workspace_dataset,
         | 
| 17 | 
             
                validate_push_to_hub,
         | 
| 18 | 
             
            )
         | 
|  | |
|  | |
|  | |
| 19 | 
             
            from src.distilabel_dataset_generator.pipelines.embeddings import (
         | 
| 20 | 
             
                get_embeddings,
         | 
| 21 | 
             
                get_sentence_embedding_dimensions,
         | 
    	
        src/distilabel_dataset_generator/constants.py
    ADDED
    
    | @@ -0,0 +1,55 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import os
         | 
| 2 | 
            +
            import warnings
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            import argilla as rg
         | 
| 5 | 
            +
             | 
# Hugging Face
HF_TOKEN = os.getenv("HF_TOKEN")
# An empty value is treated like a missing one: it would otherwise pass an
# ``is None`` check and then be filtered out of API_KEYS, leaving no usable key.
if not HF_TOKEN:
    raise ValueError(
        "HF_TOKEN is not set. Ensure you have set the HF_TOKEN environment variable that has access to the Hugging Face Hub repositories and Inference Endpoints."
    )

# Inference
DEFAULT_BATCH_SIZE = 5
MODEL = os.getenv("MODEL", "meta-llama/Meta-Llama-3.1-8B-Instruct")
_DEFAULT_BASE_URL = "https://api-inference.huggingface.co/v1/"
BASE_URL = os.getenv("BASE_URL", _DEFAULT_BASE_URL)

# Keys handed out in rotation for inference calls.
_CUSTOM_API_KEY = os.getenv("API_KEY")
if BASE_URL != _DEFAULT_BASE_URL and _CUSTOM_API_KEY:
    # A custom endpoint with its own key: use only that key so Hugging Face
    # tokens are never sent to (and rejected by) a different provider.
    API_KEYS = [_CUSTOM_API_KEY]
else:
    # HF_TOKEN, the optional HF_TOKEN_1 .. HF_TOKEN_9, and API_KEY if present.
    # HF_TOKEN is guaranteed non-empty above, so this list is never empty.
    _candidate_keys = (
        [HF_TOKEN]
        + [os.getenv(f"HF_TOKEN_{i}") for i in range(1, 10)]
        + [_CUSTOM_API_KEY]
    )
    API_KEYS = [token for token in _candidate_keys if token]

# SFT is only enabled for model ids containing "Qwen2" or "Llama-3"; the
# matching template name is exposed as MAGPIE_PRE_QUERY_TEMPLATE.
if "Qwen2" in MODEL:
    SFT_AVAILABLE = True
    MAGPIE_PRE_QUERY_TEMPLATE = "qwen2"
elif "Llama-3" in MODEL:
    SFT_AVAILABLE = True
    MAGPIE_PRE_QUERY_TEMPLATE = "llama3"
else:
    SFT_AVAILABLE = False
    warnings.warn(
        "SFT_AVAILABLE is set to False because the model is not a Qwen or Llama model."
    )
    MAGPIE_PRE_QUERY_TEMPLATE = None

# Argilla
ARGILLA_API_URL = os.getenv("ARGILLA_API_URL")
ARGILLA_API_KEY = os.getenv("ARGILLA_API_KEY")
if ARGILLA_API_URL is None or ARGILLA_API_KEY is None:
    # Fall back to the *_SDG_REVIEWER variables when the primary pair is
    # incomplete.
    ARGILLA_API_URL = os.getenv("ARGILLA_API_URL_SDG_REVIEWER")
    ARGILLA_API_KEY = os.getenv("ARGILLA_API_KEY_SDG_REVIEWER")

if ARGILLA_API_URL is None or ARGILLA_API_KEY is None:
    # Argilla is optional: warn and expose ``None`` instead of a client.
    warnings.warn("ARGILLA_API_URL or ARGILLA_API_KEY is not set")
    argilla_client = None
else:
    argilla_client = rg.Argilla(
        api_url=ARGILLA_API_URL,
        api_key=ARGILLA_API_KEY,
    )
    	
        src/distilabel_dataset_generator/pipelines/__init__.py
    ADDED
    
    | 
            File without changes
         | 
    	
        src/distilabel_dataset_generator/pipelines/base.py
    CHANGED
    
    | @@ -1,12 +1,10 @@ | |
| 1 | 
            -
            from  | 
| 2 |  | 
| 3 | 
            -
            DEFAULT_BATCH_SIZE = 5
         | 
| 4 | 
             
            TOKEN_INDEX = 0
         | 
| 5 | 
            -
            MODEL = "meta-llama/Meta-Llama-3.1-8B-Instruct"
         | 
| 6 |  | 
| 7 |  | 
| 8 | 
             
            def _get_next_api_key():
         | 
| 9 | 
             
                global TOKEN_INDEX
         | 
| 10 | 
            -
                api_key =  | 
| 11 | 
             
                TOKEN_INDEX += 1
         | 
| 12 | 
             
                return api_key
         | 
|  | |
| 1 | 
            +
            from distilabel_dataset_generator.constants import API_KEYS
         | 
| 2 |  | 
|  | |
# Number of keys handed out so far; wrapped with a modulo at lookup time.
TOKEN_INDEX = 0


def _get_next_api_key():
    """Return the next API key from ``API_KEYS`` in round-robin order.

    Each call advances the module-level ``TOKEN_INDEX`` by one, so successive
    calls cycle through the configured keys.

    Returns:
        The key at position ``TOKEN_INDEX % len(API_KEYS)``.

    Raises:
        ValueError: If ``API_KEYS`` is empty. Without this check the modulo
            below would fail with an uninformative ``ZeroDivisionError``.
    """
    global TOKEN_INDEX
    if not API_KEYS:
        raise ValueError(
            "No API key is configured. Set the HF_TOKEN or API_KEY environment variable."
        )
    api_key = API_KEYS[TOKEN_INDEX % len(API_KEYS)]
    TOKEN_INDEX += 1
    return api_key
    	
        src/distilabel_dataset_generator/pipelines/embeddings.py
    CHANGED
    
    | @@ -4,7 +4,7 @@ from sentence_transformers import SentenceTransformer | |
| 4 | 
             
            from sentence_transformers.models import StaticEmbedding
         | 
| 5 |  | 
| 6 | 
             
            # Initialize a StaticEmbedding module
         | 
| 7 | 
            -
            static_embedding = StaticEmbedding.from_model2vec("minishlab/ | 
| 8 | 
             
            model = SentenceTransformer(modules=[static_embedding])
         | 
| 9 |  | 
| 10 |  | 
|  | |
| 4 | 
             
            from sentence_transformers.models import StaticEmbedding
         | 
| 5 |  | 
| 6 | 
             
            # Initialize a StaticEmbedding module
         | 
| 7 | 
            +
            static_embedding = StaticEmbedding.from_model2vec("minishlab/potion-base-8M")
         | 
| 8 | 
             
            model = SentenceTransformer(modules=[static_embedding])
         | 
| 9 |  | 
| 10 |  | 
    	
        src/distilabel_dataset_generator/pipelines/eval.py
    CHANGED
    
    | @@ -5,18 +5,16 @@ from distilabel.steps.tasks import ( | |
| 5 | 
             
                UltraFeedback,
         | 
| 6 | 
             
            )
         | 
| 7 |  | 
| 8 | 
            -
            from  | 
| 9 | 
            -
             | 
| 10 | 
            -
             | 
| 11 | 
            -
            )
         | 
| 12 | 
            -
            from src.distilabel_dataset_generator.utils import extract_column_names
         | 
| 13 |  | 
| 14 |  | 
| 15 | 
             
            def get_ultrafeedback_evaluator(aspect, is_sample):
         | 
| 16 | 
             
                ultrafeedback_evaluator = UltraFeedback(
         | 
| 17 | 
             
                    llm=InferenceEndpointsLLM(
         | 
| 18 | 
             
                        model_id=MODEL,
         | 
| 19 | 
            -
                         | 
| 20 | 
             
                        api_key=_get_next_api_key(),
         | 
| 21 | 
             
                        generation_kwargs={
         | 
| 22 | 
             
                            "temperature": 0,
         | 
| @@ -33,7 +31,7 @@ def get_custom_evaluator(prompt_template, structured_output, columns, is_sample) | |
| 33 | 
             
                custom_evaluator = TextGeneration(
         | 
| 34 | 
             
                    llm=InferenceEndpointsLLM(
         | 
| 35 | 
             
                        model_id=MODEL,
         | 
| 36 | 
            -
                         | 
| 37 | 
             
                        api_key=_get_next_api_key(),
         | 
| 38 | 
             
                        structured_output={"format": "json", "schema": structured_output},
         | 
| 39 | 
             
                        generation_kwargs={
         | 
| @@ -62,7 +60,8 @@ from distilabel.steps.tasks import UltraFeedback | |
| 62 | 
             
            from distilabel.llms import InferenceEndpointsLLM
         | 
| 63 |  | 
| 64 | 
             
            MODEL = "{MODEL}"
         | 
| 65 | 
            -
             | 
|  | |
| 66 |  | 
| 67 | 
             
            hf_ds = load_dataset("{repo_id}", "{subset}", split="{split}[:{num_rows}]")
         | 
| 68 | 
             
            data = preprocess_data(hf_ds, "{instruction_column}", "{response_columns}") # to get a list of dictionaries
         | 
| @@ -76,8 +75,8 @@ with Pipeline(name="ultrafeedback") as pipeline: | |
| 76 | 
             
                ultrafeedback_evaluator = UltraFeedback(
         | 
| 77 | 
             
                    llm=InferenceEndpointsLLM(
         | 
| 78 | 
             
                        model_id=MODEL,
         | 
| 79 | 
            -
                         | 
| 80 | 
            -
                        api_key=os.environ[" | 
| 81 | 
             
                        generation_kwargs={{
         | 
| 82 | 
             
                            "temperature": 0,
         | 
| 83 | 
             
                            "max_new_tokens": 2048,
         | 
| @@ -101,7 +100,8 @@ from distilabel.steps.tasks import UltraFeedback | |
| 101 | 
             
            from distilabel.llms import InferenceEndpointsLLM
         | 
| 102 |  | 
| 103 | 
             
            MODEL = "{MODEL}"
         | 
| 104 | 
            -
             | 
|  | |
| 105 |  | 
| 106 | 
             
            hf_ds = load_dataset("{repo_id}", "{subset}", split="{split}")
         | 
| 107 | 
             
            data = preprocess_data(hf_ds, "{instruction_column}", "{response_columns}") # to get a list of dictionaries
         | 
| @@ -119,8 +119,8 @@ with Pipeline(name="ultrafeedback") as pipeline: | |
| 119 | 
             
                        aspect=aspect,
         | 
| 120 | 
             
                        llm=InferenceEndpointsLLM(
         | 
| 121 | 
             
                            model_id=MODEL,
         | 
| 122 | 
            -
                             | 
| 123 | 
            -
                            api_key=os.environ[" | 
| 124 | 
             
                            generation_kwargs={{
         | 
| 125 | 
             
                                "temperature": 0,
         | 
| 126 | 
             
                                "max_new_tokens": 2048,
         | 
| @@ -157,6 +157,7 @@ from distilabel.steps.tasks import TextGeneration | |
| 157 | 
             
            from distilabel.llms import InferenceEndpointsLLM
         | 
| 158 |  | 
| 159 | 
             
            MODEL = "{MODEL}"
         | 
|  | |
| 160 | 
             
            CUSTOM_TEMPLATE = "{prompt_template}"
         | 
| 161 | 
             
            os.environ["HF_TOKEN"] = "hf_xxx" # https://huggingface.co/settings/tokens/new?ownUserPermissions=repo.content.read&ownUserPermissions=repo.write&globalPermissions=inference.serverless.write&canReadGatedRepos=true&tokenType=fineGrained
         | 
| 162 |  | 
| @@ -171,7 +172,7 @@ with Pipeline(name="custom-evaluation") as pipeline: | |
| 171 | 
             
                custom_evaluator = TextGeneration(
         | 
| 172 | 
             
                    llm=InferenceEndpointsLLM(
         | 
| 173 | 
             
                        model_id=MODEL,
         | 
| 174 | 
            -
                         | 
| 175 | 
             
                        api_key=os.environ["HF_TOKEN"],
         | 
| 176 | 
             
                        structured_output={{"format": "json", "schema": {structured_output}}},
         | 
| 177 | 
             
                        generation_kwargs={{
         | 
|  | |
| 5 | 
             
                UltraFeedback,
         | 
| 6 | 
             
            )
         | 
| 7 |  | 
| 8 | 
            +
            from distilabel_dataset_generator.constants import BASE_URL, MODEL
         | 
| 9 | 
            +
            from distilabel_dataset_generator.pipelines.base import _get_next_api_key
         | 
| 10 | 
            +
            from distilabel_dataset_generator.utils import extract_column_names
         | 
|  | |
|  | |
| 11 |  | 
| 12 |  | 
| 13 | 
             
            def get_ultrafeedback_evaluator(aspect, is_sample):
         | 
| 14 | 
             
                ultrafeedback_evaluator = UltraFeedback(
         | 
| 15 | 
             
                    llm=InferenceEndpointsLLM(
         | 
| 16 | 
             
                        model_id=MODEL,
         | 
| 17 | 
            +
                        base_url=BASE_URL,
         | 
| 18 | 
             
                        api_key=_get_next_api_key(),
         | 
| 19 | 
             
                        generation_kwargs={
         | 
| 20 | 
             
                            "temperature": 0,
         | 
|  | |
| 31 | 
             
                custom_evaluator = TextGeneration(
         | 
| 32 | 
             
                    llm=InferenceEndpointsLLM(
         | 
| 33 | 
             
                        model_id=MODEL,
         | 
| 34 | 
            +
                        base_url=BASE_URL,
         | 
| 35 | 
             
                        api_key=_get_next_api_key(),
         | 
| 36 | 
             
                        structured_output={"format": "json", "schema": structured_output},
         | 
| 37 | 
             
                        generation_kwargs={
         | 
|  | |
| 60 | 
             
            from distilabel.llms import InferenceEndpointsLLM
         | 
| 61 |  | 
| 62 | 
             
            MODEL = "{MODEL}"
         | 
| 63 | 
            +
            BASE_URL = "{BASE_URL}"
         | 
| 64 | 
            +
            os.environ["API_KEY"] = "hf_xxx" # https://huggingface.co/settings/tokens/new?ownUserPermissions=repo.content.read&ownUserPermissions=repo.write&globalPermissions=inference.serverless.write&canReadGatedRepos=true&tokenType=fineGrained
         | 
| 65 |  | 
| 66 | 
             
            hf_ds = load_dataset("{repo_id}", "{subset}", split="{split}[:{num_rows}]")
         | 
| 67 | 
             
            data = preprocess_data(hf_ds, "{instruction_column}", "{response_columns}") # to get a list of dictionaries
         | 
|  | |
| 75 | 
             
                ultrafeedback_evaluator = UltraFeedback(
         | 
| 76 | 
             
                    llm=InferenceEndpointsLLM(
         | 
| 77 | 
             
                        model_id=MODEL,
         | 
| 78 | 
            +
                        base_url=BASE_URL,
         | 
| 79 | 
            +
                        api_key=os.environ["API_KEY"],
         | 
| 80 | 
             
                        generation_kwargs={{
         | 
| 81 | 
             
                            "temperature": 0,
         | 
| 82 | 
             
                            "max_new_tokens": 2048,
         | 
|  | |
| 100 | 
             
            from distilabel.llms import InferenceEndpointsLLM
         | 
| 101 |  | 
| 102 | 
             
            MODEL = "{MODEL}"
         | 
| 103 | 
            +
            BASE_URL = "{BASE_URL}"
         | 
| 104 | 
            +
            os.environ["API_KEY"] = "hf_xxx" # https://huggingface.co/settings/tokens/new?ownUserPermissions=repo.content.read&ownUserPermissions=repo.write&globalPermissions=inference.serverless.write&canReadGatedRepos=true&tokenType=fineGrained
         | 
| 105 |  | 
| 106 | 
             
            hf_ds = load_dataset("{repo_id}", "{subset}", split="{split}")
         | 
| 107 | 
             
            data = preprocess_data(hf_ds, "{instruction_column}", "{response_columns}") # to get a list of dictionaries
         | 
|  | |
| 119 | 
             
                        aspect=aspect,
         | 
| 120 | 
             
                        llm=InferenceEndpointsLLM(
         | 
| 121 | 
             
                            model_id=MODEL,
         | 
| 122 | 
            +
                            base_url=BASE_URL,
         | 
| 123 | 
            +
                            api_key=os.environ["API_KEY"],
         | 
| 124 | 
             
                            generation_kwargs={{
         | 
| 125 | 
             
                                "temperature": 0,
         | 
| 126 | 
             
                                "max_new_tokens": 2048,
         | 
|  | |
| 157 | 
             
            from distilabel.llms import InferenceEndpointsLLM
         | 
| 158 |  | 
| 159 | 
             
            MODEL = "{MODEL}"
         | 
| 160 | 
            +
            BASE_URL = "{BASE_URL}"
         | 
| 161 | 
             
            CUSTOM_TEMPLATE = "{prompt_template}"
         | 
| 162 | 
             
            os.environ["HF_TOKEN"] = "hf_xxx" # https://huggingface.co/settings/tokens/new?ownUserPermissions=repo.content.read&ownUserPermissions=repo.write&globalPermissions=inference.serverless.write&canReadGatedRepos=true&tokenType=fineGrained
         | 
| 163 |  | 
|  | |
| 172 | 
             
                custom_evaluator = TextGeneration(
         | 
| 173 | 
             
                    llm=InferenceEndpointsLLM(
         | 
| 174 | 
             
                        model_id=MODEL,
         | 
| 175 | 
            +
                        base_url=BASE_URL,
         | 
| 176 | 
             
                        api_key=os.environ["HF_TOKEN"],
         | 
| 177 | 
             
                        structured_output={{"format": "json", "schema": {structured_output}}},
         | 
| 178 | 
             
                        generation_kwargs={{
         | 
    	
        src/distilabel_dataset_generator/pipelines/sft.py
    CHANGED
    
    | @@ -1,10 +1,12 @@ | |
| 1 | 
             
            from distilabel.llms import InferenceEndpointsLLM
         | 
| 2 | 
             
            from distilabel.steps.tasks import ChatGeneration, Magpie, TextGeneration
         | 
| 3 |  | 
| 4 | 
            -
            from  | 
|  | |
|  | |
| 5 | 
             
                MODEL,
         | 
| 6 | 
            -
                _get_next_api_key,
         | 
| 7 | 
             
            )
         | 
|  | |
| 8 |  | 
| 9 | 
             
            INFORMATION_SEEKING_PROMPT = (
         | 
| 10 | 
             
                "You are an AI assistant designed to provide accurate and concise information on a wide"
         | 
| @@ -144,6 +146,7 @@ def get_prompt_generator(temperature): | |
| 144 | 
             
                        api_key=_get_next_api_key(),
         | 
| 145 | 
             
                        model_id=MODEL,
         | 
| 146 | 
             
                        tokenizer_id=MODEL,
         | 
|  | |
| 147 | 
             
                        generation_kwargs={
         | 
| 148 | 
             
                            "temperature": temperature,
         | 
| 149 | 
             
                            "max_new_tokens": 2048,
         | 
| @@ -165,8 +168,9 @@ def get_magpie_generator(system_prompt, num_turns, is_sample): | |
| 165 | 
             
                        llm=InferenceEndpointsLLM(
         | 
| 166 | 
             
                            model_id=MODEL,
         | 
| 167 | 
             
                            tokenizer_id=MODEL,
         | 
|  | |
| 168 | 
             
                            api_key=_get_next_api_key(),
         | 
| 169 | 
            -
                            magpie_pre_query_template= | 
| 170 | 
             
                            generation_kwargs={
         | 
| 171 | 
             
                                "temperature": 0.9,
         | 
| 172 | 
             
                                "do_sample": True,
         | 
| @@ -184,8 +188,9 @@ def get_magpie_generator(system_prompt, num_turns, is_sample): | |
| 184 | 
             
                        llm=InferenceEndpointsLLM(
         | 
| 185 | 
             
                            model_id=MODEL,
         | 
| 186 | 
             
                            tokenizer_id=MODEL,
         | 
|  | |
| 187 | 
             
                            api_key=_get_next_api_key(),
         | 
| 188 | 
            -
                            magpie_pre_query_template= | 
| 189 | 
             
                            generation_kwargs={
         | 
| 190 | 
             
                                "temperature": 0.9,
         | 
| 191 | 
             
                                "do_sample": True,
         | 
| @@ -208,6 +213,7 @@ def get_response_generator(system_prompt, num_turns, is_sample): | |
| 208 | 
             
                        llm=InferenceEndpointsLLM(
         | 
| 209 | 
             
                            model_id=MODEL,
         | 
| 210 | 
             
                            tokenizer_id=MODEL,
         | 
|  | |
| 211 | 
             
                            api_key=_get_next_api_key(),
         | 
| 212 | 
             
                            generation_kwargs={
         | 
| 213 | 
             
                                "temperature": 0.8,
         | 
| @@ -223,6 +229,7 @@ def get_response_generator(system_prompt, num_turns, is_sample): | |
| 223 | 
             
                        llm=InferenceEndpointsLLM(
         | 
| 224 | 
             
                            model_id=MODEL,
         | 
| 225 | 
             
                            tokenizer_id=MODEL,
         | 
|  | |
| 226 | 
             
                            api_key=_get_next_api_key(),
         | 
| 227 | 
             
                            generation_kwargs={
         | 
| 228 | 
             
                                "temperature": 0.8,
         | 
| @@ -247,14 +254,16 @@ from distilabel.steps.tasks import MagpieGenerator | |
| 247 | 
             
            from distilabel.llms import InferenceEndpointsLLM
         | 
| 248 |  | 
| 249 | 
             
            MODEL = "{MODEL}"
         | 
|  | |
| 250 | 
             
            SYSTEM_PROMPT = "{system_prompt}"
         | 
| 251 | 
            -
            os.environ[" | 
| 252 |  | 
| 253 | 
             
            with Pipeline(name="sft") as pipeline:
         | 
| 254 | 
             
                magpie = MagpieGenerator(
         | 
| 255 | 
             
                    llm=InferenceEndpointsLLM(
         | 
| 256 | 
             
                        model_id=MODEL,
         | 
| 257 | 
             
                        tokenizer_id=MODEL,
         | 
|  | |
| 258 | 
             
                        magpie_pre_query_template="llama3",
         | 
| 259 | 
             
                        generation_kwargs={{
         | 
| 260 | 
             
                            "temperature": 0.9,
         | 
| @@ -262,7 +271,7 @@ with Pipeline(name="sft") as pipeline: | |
| 262 | 
             
                            "max_new_tokens": 2048,
         | 
| 263 | 
             
                            "stop_sequences": {_STOP_SEQUENCES}
         | 
| 264 | 
             
                        }},
         | 
| 265 | 
            -
                        api_key=os.environ[" | 
| 266 | 
             
                    ),
         | 
| 267 | 
             
                    n_turns={num_turns},
         | 
| 268 | 
             
                    num_rows={num_rows},
         | 
|  | |
| 1 | 
             
            from distilabel.llms import InferenceEndpointsLLM
         | 
| 2 | 
             
            from distilabel.steps.tasks import ChatGeneration, Magpie, TextGeneration
         | 
| 3 |  | 
| 4 | 
            +
            from distilabel_dataset_generator.constants import (
         | 
| 5 | 
            +
                BASE_URL,
         | 
| 6 | 
            +
                MAGPIE_PRE_QUERY_TEMPLATE,
         | 
| 7 | 
             
                MODEL,
         | 
|  | |
| 8 | 
             
            )
         | 
| 9 | 
            +
            from distilabel_dataset_generator.pipelines.base import _get_next_api_key
         | 
| 10 |  | 
| 11 | 
             
            INFORMATION_SEEKING_PROMPT = (
         | 
| 12 | 
             
                "You are an AI assistant designed to provide accurate and concise information on a wide"
         | 
|  | |
| 146 | 
             
                        api_key=_get_next_api_key(),
         | 
| 147 | 
             
                        model_id=MODEL,
         | 
| 148 | 
             
                        tokenizer_id=MODEL,
         | 
| 149 | 
            +
                        base_url=BASE_URL,
         | 
| 150 | 
             
                        generation_kwargs={
         | 
| 151 | 
             
                            "temperature": temperature,
         | 
| 152 | 
             
                            "max_new_tokens": 2048,
         | 
|  | |
| 168 | 
             
                        llm=InferenceEndpointsLLM(
         | 
| 169 | 
             
                            model_id=MODEL,
         | 
| 170 | 
             
                            tokenizer_id=MODEL,
         | 
| 171 | 
            +
                            base_url=BASE_URL,
         | 
| 172 | 
             
                            api_key=_get_next_api_key(),
         | 
| 173 | 
            +
                            magpie_pre_query_template=MAGPIE_PRE_QUERY_TEMPLATE,
         | 
| 174 | 
             
                            generation_kwargs={
         | 
| 175 | 
             
                                "temperature": 0.9,
         | 
| 176 | 
             
                                "do_sample": True,
         | 
|  | |
| 188 | 
             
                        llm=InferenceEndpointsLLM(
         | 
| 189 | 
             
                            model_id=MODEL,
         | 
| 190 | 
             
                            tokenizer_id=MODEL,
         | 
| 191 | 
            +
                            base_url=BASE_URL,
         | 
| 192 | 
             
                            api_key=_get_next_api_key(),
         | 
| 193 | 
            +
                            magpie_pre_query_template=MAGPIE_PRE_QUERY_TEMPLATE,
         | 
| 194 | 
             
                            generation_kwargs={
         | 
| 195 | 
             
                                "temperature": 0.9,
         | 
| 196 | 
             
                                "do_sample": True,
         | 
|  | |
| 213 | 
             
                        llm=InferenceEndpointsLLM(
         | 
| 214 | 
             
                            model_id=MODEL,
         | 
| 215 | 
             
                            tokenizer_id=MODEL,
         | 
| 216 | 
            +
                            base_url=BASE_URL,
         | 
| 217 | 
             
                            api_key=_get_next_api_key(),
         | 
| 218 | 
             
                            generation_kwargs={
         | 
| 219 | 
             
                                "temperature": 0.8,
         | 
|  | |
| 229 | 
             
                        llm=InferenceEndpointsLLM(
         | 
| 230 | 
             
                            model_id=MODEL,
         | 
| 231 | 
             
                            tokenizer_id=MODEL,
         | 
| 232 | 
            +
                            base_url=BASE_URL,
         | 
| 233 | 
             
                            api_key=_get_next_api_key(),
         | 
| 234 | 
             
                            generation_kwargs={
         | 
| 235 | 
             
                                "temperature": 0.8,
         | 
|  | |
| 254 | 
             
            from distilabel.llms import InferenceEndpointsLLM
         | 
| 255 |  | 
| 256 | 
             
            MODEL = "{MODEL}"
         | 
| 257 | 
            +
            BASE_URL = "{BASE_URL}"
         | 
| 258 | 
             
            SYSTEM_PROMPT = "{system_prompt}"
         | 
| 259 | 
            +
            os.environ["API_KEY"] = "hf_xxx" # https://huggingface.co/settings/tokens/new?ownUserPermissions=repo.content.read&ownUserPermissions=repo.write&globalPermissions=inference.serverless.write&canReadGatedRepos=true&tokenType=fineGrained
         | 
| 260 |  | 
| 261 | 
             
            with Pipeline(name="sft") as pipeline:
         | 
| 262 | 
             
                magpie = MagpieGenerator(
         | 
| 263 | 
             
                    llm=InferenceEndpointsLLM(
         | 
| 264 | 
             
                        model_id=MODEL,
         | 
| 265 | 
             
                        tokenizer_id=MODEL,
         | 
| 266 | 
            +
                        base_url=BASE_URL,
         | 
| 267 | 
             
                        magpie_pre_query_template="llama3",
         | 
| 268 | 
             
                        generation_kwargs={{
         | 
| 269 | 
             
                            "temperature": 0.9,
         | 
|  | |
| 271 | 
             
                            "max_new_tokens": 2048,
         | 
| 272 | 
             
                            "stop_sequences": {_STOP_SEQUENCES}
         | 
| 273 | 
             
                        }},
         | 
| 274 | 
            +
                        api_key=os.environ["API_KEY"],
         | 
| 275 | 
             
                    ),
         | 
| 276 | 
             
                    n_turns={num_turns},
         | 
| 277 | 
             
                    num_rows={num_rows},
         | 
    	
        src/distilabel_dataset_generator/pipelines/textcat.py
    CHANGED
    
    | @@ -1,5 +1,4 @@ | |
| 1 | 
             
            import random
         | 
| 2 | 
            -
            from pydantic import BaseModel, Field
         | 
| 3 | 
             
            from typing import List
         | 
| 4 |  | 
| 5 | 
             
            from distilabel.llms import InferenceEndpointsLLM
         | 
| @@ -8,12 +7,11 @@ from distilabel.steps.tasks import ( | |
| 8 | 
             
                TextClassification,
         | 
| 9 | 
             
                TextGeneration,
         | 
| 10 | 
             
            )
         | 
|  | |
| 11 |  | 
| 12 | 
            -
            from  | 
| 13 | 
            -
             | 
| 14 | 
            -
             | 
| 15 | 
            -
            )
         | 
| 16 | 
            -
            from src.distilabel_dataset_generator.utils import get_preprocess_labels
         | 
| 17 |  | 
| 18 | 
             
            PROMPT_CREATION_PROMPT = """You are an AI assistant specialized in generating very precise text classification tasks for dataset creation.
         | 
| 19 |  | 
| @@ -73,7 +71,7 @@ def get_prompt_generator(temperature): | |
| 73 | 
             
                    llm=InferenceEndpointsLLM(
         | 
| 74 | 
             
                        api_key=_get_next_api_key(),
         | 
| 75 | 
             
                        model_id=MODEL,
         | 
| 76 | 
            -
                         | 
| 77 | 
             
                        structured_output={"format": "json", "schema": TextClassificationTask},
         | 
| 78 | 
             
                        generation_kwargs={
         | 
| 79 | 
             
                            "temperature": temperature,
         | 
| @@ -92,7 +90,7 @@ def get_textcat_generator(difficulty, clarity, is_sample): | |
| 92 | 
             
                textcat_generator = GenerateTextClassificationData(
         | 
| 93 | 
             
                    llm=InferenceEndpointsLLM(
         | 
| 94 | 
             
                        model_id=MODEL,
         | 
| 95 | 
            -
                         | 
| 96 | 
             
                        api_key=_get_next_api_key(),
         | 
| 97 | 
             
                        generation_kwargs={
         | 
| 98 | 
             
                            "temperature": 0.9,
         | 
| @@ -114,7 +112,7 @@ def get_labeller_generator(system_prompt, labels, num_labels): | |
| 114 | 
             
                labeller_generator = TextClassification(
         | 
| 115 | 
             
                    llm=InferenceEndpointsLLM(
         | 
| 116 | 
             
                        model_id=MODEL,
         | 
| 117 | 
            -
                         | 
| 118 | 
             
                        api_key=_get_next_api_key(),
         | 
| 119 | 
             
                        generation_kwargs={
         | 
| 120 | 
             
                            "temperature": 0.7,
         | 
| @@ -149,8 +147,9 @@ from distilabel.steps import LoadDataFromDicts, KeepColumns | |
| 149 | 
             
            from distilabel.steps.tasks import {"GenerateTextClassificationData" if num_labels == 1 else "GenerateTextClassificationData, TextClassification"}
         | 
| 150 |  | 
| 151 | 
             
            MODEL = "{MODEL}"
         | 
|  | |
| 152 | 
             
            TEXT_CLASSIFICATION_TASK = "{system_prompt}"
         | 
| 153 | 
            -
            os.environ[" | 
| 154 | 
             
                "hf_xxx"  # https://huggingface.co/settings/tokens/new?ownUserPermissions=repo.content.read&ownUserPermissions=repo.write&globalPermissions=inference.serverless.write&canReadGatedRepos=true&tokenType=fineGrained
         | 
| 155 | 
             
            )
         | 
| 156 |  | 
| @@ -161,8 +160,8 @@ with Pipeline(name="textcat") as pipeline: | |
| 161 | 
             
                textcat_generation = GenerateTextClassificationData(
         | 
| 162 | 
             
                    llm=InferenceEndpointsLLM(
         | 
| 163 | 
             
                        model_id=MODEL,
         | 
| 164 | 
            -
                         | 
| 165 | 
            -
                        api_key=os.environ[" | 
| 166 | 
             
                        generation_kwargs={{
         | 
| 167 | 
             
                            "temperature": 0.8,
         | 
| 168 | 
             
                            "max_new_tokens": 2048,
         | 
| @@ -205,8 +204,8 @@ with Pipeline(name="textcat") as pipeline: | |
| 205 | 
             
                textcat_labeller = TextClassification(
         | 
| 206 | 
             
                    llm=InferenceEndpointsLLM(
         | 
| 207 | 
             
                        model_id=MODEL,
         | 
| 208 | 
            -
                         | 
| 209 | 
            -
                        api_key=os.environ[" | 
| 210 | 
             
                        generation_kwargs={{
         | 
| 211 | 
             
                            "temperature": 0.8,
         | 
| 212 | 
             
                            "max_new_tokens": 2048,
         | 
|  | |
| 1 | 
             
            import random
         | 
|  | |
| 2 | 
             
            from typing import List
         | 
| 3 |  | 
| 4 | 
             
            from distilabel.llms import InferenceEndpointsLLM
         | 
|  | |
| 7 | 
             
                TextClassification,
         | 
| 8 | 
             
                TextGeneration,
         | 
| 9 | 
             
            )
         | 
| 10 | 
            +
            from pydantic import BaseModel, Field
         | 
| 11 |  | 
| 12 | 
            +
            from distilabel_dataset_generator.constants import BASE_URL, MODEL
         | 
| 13 | 
            +
            from distilabel_dataset_generator.pipelines.base import _get_next_api_key
         | 
| 14 | 
            +
            from distilabel_dataset_generator.utils import get_preprocess_labels
         | 
|  | |
|  | |
| 15 |  | 
| 16 | 
             
            PROMPT_CREATION_PROMPT = """You are an AI assistant specialized in generating very precise text classification tasks for dataset creation.
         | 
| 17 |  | 
|  | |
| 71 | 
             
                    llm=InferenceEndpointsLLM(
         | 
| 72 | 
             
                        api_key=_get_next_api_key(),
         | 
| 73 | 
             
                        model_id=MODEL,
         | 
| 74 | 
            +
                        base_url=BASE_URL,
         | 
| 75 | 
             
                        structured_output={"format": "json", "schema": TextClassificationTask},
         | 
| 76 | 
             
                        generation_kwargs={
         | 
| 77 | 
             
                            "temperature": temperature,
         | 
|  | |
| 90 | 
             
                textcat_generator = GenerateTextClassificationData(
         | 
| 91 | 
             
                    llm=InferenceEndpointsLLM(
         | 
| 92 | 
             
                        model_id=MODEL,
         | 
| 93 | 
            +
                        base_url=BASE_URL,
         | 
| 94 | 
             
                        api_key=_get_next_api_key(),
         | 
| 95 | 
             
                        generation_kwargs={
         | 
| 96 | 
             
                            "temperature": 0.9,
         | 
|  | |
| 112 | 
             
                labeller_generator = TextClassification(
         | 
| 113 | 
             
                    llm=InferenceEndpointsLLM(
         | 
| 114 | 
             
                        model_id=MODEL,
         | 
| 115 | 
            +
                        base_url=BASE_URL,
         | 
| 116 | 
             
                        api_key=_get_next_api_key(),
         | 
| 117 | 
             
                        generation_kwargs={
         | 
| 118 | 
             
                            "temperature": 0.7,
         | 
|  | |
| 147 | 
             
            from distilabel.steps.tasks import {"GenerateTextClassificationData" if num_labels == 1 else "GenerateTextClassificationData, TextClassification"}
         | 
| 148 |  | 
| 149 | 
             
            MODEL = "{MODEL}"
         | 
| 150 | 
            +
            BASE_URL = "{BASE_URL}"
         | 
| 151 | 
             
            TEXT_CLASSIFICATION_TASK = "{system_prompt}"
         | 
| 152 | 
            +
            os.environ["API_KEY"] = (
         | 
| 153 | 
             
                "hf_xxx"  # https://huggingface.co/settings/tokens/new?ownUserPermissions=repo.content.read&ownUserPermissions=repo.write&globalPermissions=inference.serverless.write&canReadGatedRepos=true&tokenType=fineGrained
         | 
| 154 | 
             
            )
         | 
| 155 |  | 
|  | |
| 160 | 
             
                textcat_generation = GenerateTextClassificationData(
         | 
| 161 | 
             
                    llm=InferenceEndpointsLLM(
         | 
| 162 | 
             
                        model_id=MODEL,
         | 
| 163 | 
            +
                        base_url=BASE_URL,
         | 
| 164 | 
            +
                        api_key=os.environ["API_KEY"],
         | 
| 165 | 
             
                        generation_kwargs={{
         | 
| 166 | 
             
                            "temperature": 0.8,
         | 
| 167 | 
             
                            "max_new_tokens": 2048,
         | 
|  | |
| 204 | 
             
                textcat_labeller = TextClassification(
         | 
| 205 | 
             
                    llm=InferenceEndpointsLLM(
         | 
| 206 | 
             
                        model_id=MODEL,
         | 
| 207 | 
            +
                        base_url=BASE_URL,
         | 
| 208 | 
            +
                        api_key=os.environ["API_KEY"],
         | 
| 209 | 
             
                        generation_kwargs={{
         | 
| 210 | 
             
                            "temperature": 0.8,
         | 
| 211 | 
             
                            "max_new_tokens": 2048,
         | 
    	
        src/distilabel_dataset_generator/utils.py
    CHANGED
    
    | @@ -15,7 +15,7 @@ from gradio.oauth import ( | |
| 15 | 
             
            from huggingface_hub import whoami
         | 
| 16 | 
             
            from jinja2 import Environment, meta
         | 
| 17 |  | 
| 18 | 
            -
            from  | 
| 19 |  | 
| 20 | 
             
            _LOGGED_OUT_CSS = ".main_ui_logged_out{opacity: 0.3; pointer-events: none}"
         | 
| 21 |  | 
|  | |
| 15 | 
             
            from huggingface_hub import whoami
         | 
| 16 | 
             
            from jinja2 import Environment, meta
         | 
| 17 |  | 
| 18 | 
            +
            from distilabel_dataset_generator.constants import argilla_client
         | 
| 19 |  | 
| 20 | 
             
            _LOGGED_OUT_CSS = ".main_ui_logged_out{opacity: 0.3; pointer-events: none}"
         | 
| 21 |  | 
 
			
