Update main.py
Browse files
main.py
CHANGED
@@ -15,6 +15,7 @@ class Item(BaseModel):
|
|
15 |
temperature: float = 0.6
|
16 |
max_new_tokens: int = 1024
|
17 |
top_p: float = 0.95
|
|
|
18 |
seed : int = 42
|
19 |
|
20 |
app = FastAPI()
|
@@ -56,7 +57,7 @@ def generate(item: Item):
|
|
56 |
tokenizer.convert_tokens_to_ids("<|eot_id|>")
|
57 |
]
|
58 |
|
59 |
-
outputs = model.generate(input_ids, eos_token_id=terminators, do_sample=True)
|
60 |
response = outputs[0][input_ids.shape[-1]:]
|
61 |
return tokenizer.decode(response, skip_special_tokens=True)
|
62 |
|
|
|
15 |
temperature: float = 0.6
|
16 |
max_new_tokens: int = 1024
|
17 |
top_p: float = 0.95
|
18 |
+
repetition_penalty: float = 1.0
|
19 |
seed : int = 42
|
20 |
|
21 |
app = FastAPI()
|
|
|
57 |
tokenizer.convert_tokens_to_ids("<|eot_id|>")
|
58 |
]
|
59 |
|
60 |
+
outputs = model.generate(input_ids, eos_token_id=terminators, do_sample=True, **generate_kwargs,)
|
61 |
response = outputs[0][input_ids.shape[-1]:]
|
62 |
return tokenizer.decode(response, skip_special_tokens=True)
|
63 |
|