Upload TFBilma
Browse files- modeling_bilma.py +6 -3
- tf_model.h5 +1 -1
modeling_bilma.py
CHANGED
@@ -9,7 +9,7 @@ from typing import Dict
|
|
9 |
import re
|
10 |
import unicodedata
|
11 |
|
12 |
-
from
|
13 |
|
14 |
# copied from preprocessing.py
|
15 |
BLANK = ' '
|
@@ -37,6 +37,7 @@ class TFBilma(TFPreTrainedModel):
|
|
37 |
|
38 |
def __init__(self, config):
|
39 |
self.seq_max_length = config.seq_max_length
|
|
|
40 |
super().__init__(config)
|
41 |
#if config.weights == "spanish":
|
42 |
# my_resources = importlib_resources.files("hf_bilma")
|
@@ -76,8 +77,10 @@ class TFBilma(TFPreTrainedModel):
|
|
76 |
#if isinstance(tensor, dict) and len(tensor) == 0:
|
77 |
# return self.model(self.dummy_inputs)
|
78 |
ins = tf.cast(inputs["input_ids"], tf.float32)
|
79 |
-
|
80 |
-
|
|
|
|
|
81 |
return output
|
82 |
|
83 |
|
|
|
9 |
import re
|
10 |
import unicodedata
|
11 |
|
12 |
+
from configuration_bilma import BilmaConfig
|
13 |
|
14 |
# copied from preprocessing.py
|
15 |
BLANK = ' '
|
|
|
37 |
|
38 |
def __init__(self, config):
|
39 |
self.seq_max_length = config.seq_max_length
|
40 |
+
self.include_top = config.include_top
|
41 |
super().__init__(config)
|
42 |
#if config.weights == "spanish":
|
43 |
# my_resources = importlib_resources.files("hf_bilma")
|
|
|
77 |
#if isinstance(tensor, dict) and len(tensor) == 0:
|
78 |
# return self.model(self.dummy_inputs)
|
79 |
ins = tf.cast(inputs["input_ids"], tf.float32)
|
80 |
+
if self.include_top:
|
81 |
+
output = {"logits":self.model(ins)}
|
82 |
+
else:
|
83 |
+
output = {"last_hidden_state":self.model(ins)}
|
84 |
return output
|
85 |
|
86 |
|
tf_model.h5
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
size 156564220
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:75330683b2e51a65402cdd6b87de8d51b817f5924bfa2e8ce2c085d15b3b841b
|
3 |
size 156564220
|