a more strict check for activation type and a more reasonable check for type of layer in hypernets
This commit is contained in:
parent
a26fc2834c
commit
c23f666dba
@ -32,10 +32,16 @@ class HypernetworkModule(torch.nn.Module):
|
|||||||
linears = []
|
linears = []
|
||||||
for i in range(len(layer_structure) - 1):
|
for i in range(len(layer_structure) - 1):
|
||||||
linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
|
linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
|
||||||
|
|
||||||
if activation_func == "relu":
|
if activation_func == "relu":
|
||||||
linears.append(torch.nn.ReLU())
|
linears.append(torch.nn.ReLU())
|
||||||
if activation_func == "leakyrelu":
|
elif activation_func == "leakyrelu":
|
||||||
linears.append(torch.nn.LeakyReLU())
|
linears.append(torch.nn.LeakyReLU())
|
||||||
|
elif activation_func == 'linear' or activation_func is None:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
raise RuntimeError(f'hypernetwork uses an unsupported activation function: {activation_func}')
|
||||||
|
|
||||||
if add_layer_norm:
|
if add_layer_norm:
|
||||||
linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
|
linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
|
||||||
|
|
||||||
@ -46,7 +52,7 @@ class HypernetworkModule(torch.nn.Module):
|
|||||||
self.load_state_dict(state_dict)
|
self.load_state_dict(state_dict)
|
||||||
else:
|
else:
|
||||||
for layer in self.linear:
|
for layer in self.linear:
|
||||||
if not "ReLU" in layer.__str__():
|
if type(layer) == torch.nn.Linear:
|
||||||
layer.weight.data.normal_(mean=0.0, std=0.01)
|
layer.weight.data.normal_(mean=0.0, std=0.01)
|
||||||
layer.bias.data.zero_()
|
layer.bias.data.zero_()
|
||||||
|
|
||||||
@ -74,7 +80,7 @@ class HypernetworkModule(torch.nn.Module):
|
|||||||
def trainables(self):
|
def trainables(self):
|
||||||
layer_structure = []
|
layer_structure = []
|
||||||
for layer in self.linear:
|
for layer in self.linear:
|
||||||
if not "ReLU" in layer.__str__():
|
if type(layer) == torch.nn.Linear:
|
||||||
layer_structure += [layer.weight, layer.bias]
|
layer_structure += [layer.weight, layer.bias]
|
||||||
return layer_structure
|
return layer_structure
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user