sdiazlor commited on
Commit
93e464a
·
1 Parent(s): c76dc11

typing textcat

Browse files
src/synthetic_dataset_generator/pipelines/textcat.py CHANGED
@@ -85,7 +85,9 @@ def get_prompt_generator():
85
  return prompt_generator
86
 
87
 
88
- def get_textcat_generator(difficulty, clarity, temperature, is_sample):
 
 
89
  generation_kwargs = {
90
  "temperature": temperature,
91
  "max_new_tokens": 256 if is_sample else MAX_NUM_TOKENS,
@@ -102,7 +104,7 @@ def get_textcat_generator(difficulty, clarity, temperature, is_sample):
102
  return textcat_generator
103
 
104
 
105
- def get_labeller_generator(system_prompt, labels, multi_label):
106
  generation_kwargs = {
107
  "temperature": 0.01,
108
  "max_new_tokens": MAX_NUM_TOKENS,
 
85
  return prompt_generator
86
 
87
 
88
+ def get_textcat_generator(
89
+ difficulty: str, clarity: str, temperature: float, is_sample: bool
90
+ ):
91
  generation_kwargs = {
92
  "temperature": temperature,
93
  "max_new_tokens": 256 if is_sample else MAX_NUM_TOKENS,
 
104
  return textcat_generator
105
 
106
 
107
+ def get_labeller_generator(system_prompt: str, labels: List[str], multi_label: bool):
108
  generation_kwargs = {
109
  "temperature": 0.01,
110
  "max_new_tokens": MAX_NUM_TOKENS,