Update modeling_chatsakura.py
Browse files- modeling_chatsakura.py +11 -1
modeling_chatsakura.py
CHANGED
@@ -2,10 +2,20 @@ import torch
|
|
2 |
import torch.nn as nn
|
3 |
|
4 |
from gptq import *
|
5 |
-
from modelutils import *
|
6 |
from quant import *
|
7 |
from transformers import BloomForCausalLM as LM
|
8 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
class SakuraForCausalLM(LM):
|
10 |
def __init__(self,*args,**kwargs):
|
11 |
def noop(*args, **kwargs):
|
|
|
2 |
import torch.nn as nn
|
3 |
|
4 |
from gptq import *
|
5 |
+
# from modelutils import *
|
6 |
from quant import *
|
7 |
from transformers import BloomForCausalLM as LM
|
8 |
|
9 |
+
def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=''):
|
10 |
+
if type(module) in layers:
|
11 |
+
return {name: module}
|
12 |
+
res = {}
|
13 |
+
for name1, child in module.named_children():
|
14 |
+
res.update(find_layers(
|
15 |
+
child, layers=layers, name=name + '.' + name1 if name != '' else name1
|
16 |
+
))
|
17 |
+
return res
|
18 |
+
|
19 |
class SakuraForCausalLM(LM):
|
20 |
def __init__(self,*args,**kwargs):
|
21 |
def noop(*args, **kwargs):
|