Spaces:
Paused
Paused
ElPlaguister
commited on
Commit
Β·
d753197
1
Parent(s):
47eae50
Fix Model Parameter Init
Browse files- koalpaca.py +2 -0
- model.py +0 -4
- t5.py +2 -0
koalpaca.py
CHANGED
|
@@ -23,6 +23,8 @@ class KoAlpaca(Model):
|
|
| 23 |
self.INPUT_FORMAT = "### μ§λ¬Έ: <INPUT>\n\n### λ΅λ³:"
|
| 24 |
self.model.eval()
|
| 25 |
|
|
|
|
|
|
|
| 26 |
def generate(self, inputs):
|
| 27 |
inputs = self.INPUT_FORMAT.replace('<INPUT>', inputs)
|
| 28 |
output_ids = self.model.generate(
|
|
|
|
| 23 |
self.INPUT_FORMAT = "### μ§λ¬Έ: <INPUT>\n\n### λ΅λ³:"
|
| 24 |
self.model.eval()
|
| 25 |
|
| 26 |
+
super().__init__()
|
| 27 |
+
|
| 28 |
def generate(self, inputs):
|
| 29 |
inputs = self.INPUT_FORMAT.replace('<INPUT>', inputs)
|
| 30 |
output_ids = self.model.generate(
|
model.py
CHANGED
|
@@ -5,10 +5,6 @@ class Model:
|
|
| 5 |
placeholder:str="Input"):
|
| 6 |
self.name = name
|
| 7 |
self.placeholder = placeholder
|
| 8 |
-
self.model = None
|
| 9 |
-
self.tokenizer = None
|
| 10 |
-
self.gen_config = None
|
| 11 |
-
self.INPUT_FORMAT = None
|
| 12 |
self.SPETIAL_TOKENS = ["#νμ#", "#μ²μ#", "#(λ¨μ)μ²μ#", "#(λ¨μ)νμ#", "#(μ¬μ)μ²μ#", "(μ¬μ)νμ"]
|
| 13 |
|
| 14 |
def generate(self, inputs:str) -> str:
|
|
|
|
| 5 |
placeholder:str="Input"):
|
| 6 |
self.name = name
|
| 7 |
self.placeholder = placeholder
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
self.SPETIAL_TOKENS = ["#νμ#", "#μ²μ#", "#(λ¨μ)μ²μ#", "#(λ¨μ)νμ#", "#(μ¬μ)μ²μ#", "(μ¬μ)νμ"]
|
| 9 |
|
| 10 |
def generate(self, inputs:str) -> str:
|
t5.py
CHANGED
|
@@ -21,6 +21,8 @@ class T5(Model):
|
|
| 21 |
self.model.resize_token_embeddings(len(self.tokenizer))
|
| 22 |
self.model.config.max_length = max_target_length
|
| 23 |
self.tokenizer.model_max_length = max_target_length
|
|
|
|
|
|
|
| 24 |
|
| 25 |
def generate(self, inputs):
|
| 26 |
inputs = self.INPUT_FORMAT.replace("<INPUT>", inputs)
|
|
|
|
| 21 |
self.model.resize_token_embeddings(len(self.tokenizer))
|
| 22 |
self.model.config.max_length = max_target_length
|
| 23 |
self.tokenizer.model_max_length = max_target_length
|
| 24 |
+
|
| 25 |
+
super().__init__()
|
| 26 |
|
| 27 |
def generate(self, inputs):
|
| 28 |
inputs = self.INPUT_FORMAT.replace("<INPUT>", inputs)
|