typing textcat
Browse files
src/synthetic_dataset_generator/pipelines/textcat.py
CHANGED
@@ -85,7 +85,9 @@ def get_prompt_generator():
|
|
85 |
return prompt_generator
|
86 |
|
87 |
|
88 |
-
def get_textcat_generator(
|
|
|
|
|
89 |
generation_kwargs = {
|
90 |
"temperature": temperature,
|
91 |
"max_new_tokens": 256 if is_sample else MAX_NUM_TOKENS,
|
@@ -102,7 +104,7 @@ def get_textcat_generator(difficulty, clarity, temperature, is_sample):
|
|
102 |
return textcat_generator
|
103 |
|
104 |
|
105 |
-
def get_labeller_generator(system_prompt, labels, multi_label):
|
106 |
generation_kwargs = {
|
107 |
"temperature": 0.01,
|
108 |
"max_new_tokens": MAX_NUM_TOKENS,
|
|
|
85 |
return prompt_generator
|
86 |
|
87 |
|
88 |
+
def get_textcat_generator(
|
89 |
+
difficulty: str, clarity: str, temperature: float, is_sample: bool
|
90 |
+
):
|
91 |
generation_kwargs = {
|
92 |
"temperature": temperature,
|
93 |
"max_new_tokens": 256 if is_sample else MAX_NUM_TOKENS,
|
|
|
104 |
return textcat_generator
|
105 |
|
106 |
|
107 |
+
def get_labeller_generator(system_prompt: str, labels: List[str], multi_label: bool):
|
108 |
generation_kwargs = {
|
109 |
"temperature": 0.01,
|
110 |
"max_new_tokens": MAX_NUM_TOKENS,
|