import torch
import torch.nn as nn
from einops import rearrange
from net.transformer_utils import *

# Cross Attention Block (CAB): multi-head cross attention computed over channels
# (transposed attention). Queries come from x, keys/values from y, so the cost is
# linear in the number of spatial positions.
class CAB(nn.Module):
    def __init__(self, dim, num_heads, bias):
        super(CAB, self).__init__()
        self.num_heads = num_heads
        # Learnable per-head scaling of the attention logits.
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))

        self.q = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
        self.q_dwconv = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim, bias=bias)
        self.kv = nn.Conv2d(dim, dim * 2, kernel_size=1, bias=bias)
        self.kv_dwconv = nn.Conv2d(dim * 2, dim * 2, kernel_size=3, stride=1, padding=1, groups=dim * 2, bias=bias)
        self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)

    def forward(self, x, y):
        b, c, h, w = x.shape

        # Point-wise + depth-wise convolutions: Q from x, K and V from y.
        q = self.q_dwconv(self.q(x))
        kv = self.kv_dwconv(self.kv(y))
        k, v = kv.chunk(2, dim=1)

        # Flatten spatial dimensions and split channels across heads.
        q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
        k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
        v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads)

        # L2-normalize along the spatial axis so Q·K^T is a cosine similarity.
        q = torch.nn.functional.normalize(q, dim=-1)
        k = torch.nn.functional.normalize(k, dim=-1)

        # Channel-to-channel attention map of shape (b, head, c, c).
        attn = (q @ k.transpose(-2, -1)) * self.temperature
        attn = nn.functional.softmax(attn, dim=-1)

        out = attn @ v
        out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w)
        out = self.project_out(out)
        return out

# Intensity Enhancement Layer (IEL): gated feed-forward block built from
# point-wise and depth-wise convolutions with Tanh-gated residual branches.
class IEL(nn.Module):
    def __init__(self, dim, ffn_expansion_factor=2.66, bias=False):
        super(IEL, self).__init__()
        hidden_features = int(dim * ffn_expansion_factor)

        self.project_in = nn.Conv2d(dim, hidden_features * 2, kernel_size=1, bias=bias)
        self.dwconv = nn.Conv2d(hidden_features * 2, hidden_features * 2, kernel_size=3, stride=1, padding=1, groups=hidden_features * 2, bias=bias)
        self.dwconv1 = nn.Conv2d(hidden_features, hidden_features, kernel_size=3, stride=1, padding=1, groups=hidden_features, bias=bias)
        self.dwconv2 = nn.Conv2d(hidden_features, hidden_features, kernel_size=3, stride=1, padding=1, groups=hidden_features, bias=bias)
        self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)

        self.Tanh = nn.Tanh()

    def forward(self, x):
        x = self.project_in(x)
        # Split the expanded features into two branches.
        x1, x2 = self.dwconv(x).chunk(2, dim=1)
        x1 = self.Tanh(self.dwconv1(x1)) + x1
        x2 = self.Tanh(self.dwconv2(x2)) + x2
        # Element-wise gating, then project back to the input dimension.
        x = x1 * x2
        x = self.project_out(x)
        return x

# Lightweight Cross Attention (LCA) for the HV branch:
# cross attention (CAB) between the two inputs, followed by an IEL block.
class HV_LCA(nn.Module):
    def __init__(self, dim, num_heads, bias=False):
        super(HV_LCA, self).__init__()
        self.gdfn = IEL(dim)  # IEL and CDL have the same structure
        self.norm = LayerNorm(dim)
        self.ffn = CAB(dim, num_heads, bias)

    def forward(self, x, y):
        x = x + self.ffn(self.norm(x), self.norm(y))
        # Note: no residual connection around the IEL here (compare I_LCA below).
        x = self.gdfn(self.norm(x))
        return x

# Lightweight Cross Attention for the intensity branch.
class I_LCA(nn.Module):
    def __init__(self, dim, num_heads, bias=False):
        super(I_LCA, self).__init__()
        self.norm = LayerNorm(dim)
        self.gdfn = IEL(dim)
        self.ffn = CAB(dim, num_heads, bias=bias)

    def forward(self, x, y):
        x = x + self.ffn(self.norm(x), self.norm(y))
        # Residual connection around the IEL, unlike HV_LCA.
        x = x + self.gdfn(self.norm(x))
        return x
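

# Minimal usage sketch (illustrative only, not part of the original file):
# the dim/num_heads values and the tensor shapes below are assumptions chosen
# so that dim is divisible by num_heads; the two inputs stand in for the
# feature maps of the two branches that are cross-attended.
if __name__ == "__main__":
    dim, num_heads = 36, 4
    feat_a = torch.randn(1, dim, 64, 64)  # hypothetical branch-A feature map
    feat_b = torch.randn(1, dim, 64, 64)  # hypothetical branch-B feature map

    hv_lca = HV_LCA(dim, num_heads)
    i_lca = I_LCA(dim, num_heads)

    out_hv = hv_lca(feat_a, feat_b)  # queries from feat_a, keys/values from feat_b
    out_i = i_lca(feat_b, feat_a)    # queries from feat_b, keys/values from feat_a
    print(out_hv.shape, out_i.shape)  # both keep the input shape (1, dim, 64, 64)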