Update modeling_chatsakura.py
Browse files- modeling_chatsakura.py +1 -11
modeling_chatsakura.py
CHANGED
@@ -2,19 +2,9 @@ import torch
|
|
2 |
import torch.nn as nn
|
3 |
|
4 |
from .gptq import *
|
5 |
-
|
6 |
from .quant import *
|
7 |
from transformers import BloomForCausalLM as LM
|
8 |
-
|
9 |
-
def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=''):
|
10 |
-
if type(module) in layers:
|
11 |
-
return {name: module}
|
12 |
-
res = {}
|
13 |
-
for name1, child in module.named_children():
|
14 |
-
res.update(find_layers(
|
15 |
-
child, layers=layers, name=name + '.' + name1 if name != '' else name1
|
16 |
-
))
|
17 |
-
return res
|
18 |
|
19 |
class SakuraForCausalLM(LM):
|
20 |
def __init__(self,*args,**kwargs):
|
|
|
2 |
import torch.nn as nn
|
3 |
|
4 |
from .gptq import *
|
5 |
+
from .modelutils import *
|
6 |
from .quant import *
|
7 |
from transformers import BloomForCausalLM as LM
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
|
9 |
class SakuraForCausalLM(LM):
|
10 |
def __init__(self,*args,**kwargs):
|