turns out LayerNorm also has weight and bias and needs to be pre-multiplied and trained for hypernets
This commit is contained in:
parent
e4877722e3
commit
03a1e288c4
@ -52,7 +52,7 @@ class HypernetworkModule(torch.nn.Module):
|
||||
self.load_state_dict(state_dict)
|
||||
else:
|
||||
for layer in self.linear:
|
||||
if type(layer) == torch.nn.Linear:
|
||||
if type(layer) == torch.nn.Linear or type(layer) == torch.nn.LayerNorm:
|
||||
layer.weight.data.normal_(mean=0.0, std=0.01)
|
||||
layer.bias.data.zero_()
|
||||
|
||||
@ -80,7 +80,7 @@ class HypernetworkModule(torch.nn.Module):
|
||||
def trainables(self):
|
||||
layer_structure = []
|
||||
for layer in self.linear:
|
||||
if type(layer) == torch.nn.Linear:
|
||||
if type(layer) == torch.nn.Linear or type(layer) == torch.nn.LayerNorm:
|
||||
layer_structure += [layer.weight, layer.bias]
|
||||
return layer_structure
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user