下面是我的网络类（Generator）定义，不确定其中是否有问题。
# Also: how can one check whether the network parameters are actually being
# updated?  Take gen.snapshot_parameters() before optimizer.step() and call
# gen.changed_parameters(snapshot) afterwards (both defined below).
class Generator(nn.Module):
    """Fully connected generator configured by ``para`` (no residual blocks).

    The network is an ``nn.Sequential`` of ``Linear`` + activation layers whose
    width grows by ``para.increment`` for ``para.Nlayer`` layers, then shrinks
    back by the same amount, followed by a final square ``Linear`` layer.
    Input and output therefore both have ``para.NAtom * 3`` features
    (presumably xyz coordinates of ``NAtom`` atoms -- TODO confirm).

    ``para`` must provide ``model_flag``, ``Nlayer``, ``NAtom``, ``increment``,
    ``activation_func`` (an ``nn.Module`` instance), ``optimizer`` (an optimizer
    class), ``lr``, ``Gflag_network`` and ``configs[0].mean`` / ``.std``.
    """

    def __init__(self, para):
        super().__init__()
        # NOTE: ``para`` is stored by reference and its ``counter``/``progress``
        # fields are reset here, so any other object sharing the same ``para``
        # shares (and clobbers) these training statistics.
        self.info = para
        self.info.counter = 0
        self.info.progress = []
        self.model = self.G_model()
        self.flag_network = para.Gflag_network  # =1, use network, =2 not use network
        random_input = self.Gen_Random_input()
        if self.flag_network == 2:
            # Without the network no weight ever receives a gradient, so an
            # optimizer over the model weights would update nothing.  Register
            # the input as a parameter *before* building the optimizer so that
            # it is what gets optimised in this mode.
            random_input = nn.Parameter(random_input)
        self.RandomInput = random_input
        self.optimizer = para.optimizer(self.parameters(), lr=para.lr)
        # self.model.apply(weight_init)
        print("Generator model")
        print(self.model)

    def G_model(self):
        """Build the ``nn.Sequential`` described by ``self.info``.

        Only ``model_flag == 2`` ("increase then decrease") is implemented.

        Raises:
            ValueError: if ``self.info.model_flag`` is not 2.
        """
        info = self.info
        model_flag = info.model_flag
        if model_flag != 2:
            # Any other value used to fall through to an unbound ``model``.
            raise ValueError(
                f"unsupported model_flag {model_flag!r}; only 2 is implemented"
            )
        Nlayer = info.Nlayer
        increment = info.increment
        # NOTE: the *same* activation module instance follows every Linear
        # layer.  Harmless for stateless activations (ReLU, Tanh, ...), but a
        # learnable one (e.g. PReLU) would share its parameters across layers.
        activation_func = info.activation_func
        inp_dim = info.NAtom * 3
        module_list = []
        # Widen by ``increment`` per layer ...
        for _ in range(Nlayer):
            module_list.append(nn.Linear(inp_dim, inp_dim + increment))
            module_list.append(activation_func)
            inp_dim += increment
        # ... then narrow back down to the input width.
        for _ in range(Nlayer):
            module_list.append(nn.Linear(inp_dim, inp_dim - increment))
            module_list.append(activation_func)
            inp_dim -= increment
        module_list.append(nn.Linear(inp_dim, inp_dim))
        return nn.Sequential(*module_list)

    def forward(self, input_tensor, flag=1):
        """Return the generator output for ``input_tensor``.

        Args:
            input_tensor: tensor fed to the network.
            flag: 1 -> pass through ``self.model``; 2 -> return the input
                unchanged (network bypassed).

        Raises:
            ValueError: for any other ``flag`` (previously an UnboundLocalError).
        """
        if flag == 1:
            return self.model(input_tensor)
        if flag == 2:
            return input_tensor
        raise ValueError(
            f"unsupported flag {flag!r}; expected 1 (use network) or 2 (bypass)"
        )

    def Gen_Random_input(self):
        """Return ``NAtom * 3`` standard-normal samples scaled by ``std`` and shifted by ``mean``.

        ``mean``/``std`` are read from ``self.info.configs[0]``.  The result is
        a leaf tensor with ``requires_grad=True``.  It is not registered with
        the module, so ``Module.to(device)`` will not move it unless the caller
        registers it as a parameter/buffer.
        """
        NAtom = self.info.NAtom
        mean = self.info.configs[0].mean
        std = self.info.configs[0].std
        inp = torch.randn(NAtom * 3)
        input_tensor = inp * std + mean
        input_tensor.requires_grad_(True)
        return input_tensor

    def train(self, thresh1=True, thresh2=None, maxit=1000):
        """Run the optimisation loop, or set training mode like ``nn.Module.train``.

        This name overrides ``nn.Module.train(mode=True)``, which PyTorch calls
        itself (``Module.eval()`` does ``self.train(False)``).  A signature that
        requires thresholds makes ``eval()`` raise ``TypeError``, so both uses
        are supported:

        * ``train(thresh1, thresh2[, maxit])`` -> :meth:`fit`;
        * ``train()`` / ``train(mode)`` (no ``thresh2``) -> ``nn.Module.train``.
        """
        if thresh2 is None:
            return super().train(thresh1)
        return self.fit(thresh1, thresh2, maxit)

    def fit(self, thresh1, thresh2, maxit=1000):
        """Optimise until the penalty drops below 2 or ``maxit`` steps are done.

        Args:
            thresh1, thresh2: thresholds forwarded to ``Calc_bond_length``.
            maxit: maximum number of optimisation steps (must be >= 1).

        Returns:
            ``(loss, gen_tensor)`` from the last iteration executed; both were
            computed *before* that iteration's ``optimizer.step()``.

        Raises:
            ValueError: if ``maxit < 1`` (nothing would be computed to return).

        Every 10th iteration (counted across calls via ``self.info.counter``)
        the loss value is appended to ``self.info.progress``.
        """
        if maxit < 1:
            raise ValueError(f"maxit must be >= 1, got {maxit!r}")
        for _ in range(maxit):
            self.optimizer.zero_grad()
            input_tensor = self.RandomInput
            print(input_tensor)
            # Call the module rather than .forward() so registered hooks run.
            gen_tensor = self(input_tensor, self.flag_network)
            # Presumably Calc_bond_length returns a scalar tensor that depends
            # on gen_tensor -- TODO confirm.  If it does not, no gradient
            # reaches the parameters; check with snapshot_parameters() /
            # changed_parameters() or by inspecting p.grad.
            loss = Calc_bond_length(gen_tensor, thresh1, thresh2)
            loss_value = loss.item()
            self.info.counter += 1
            if self.info.counter % 10 == 0:
                self.info.progress.append(loss_value)
            loss.backward()
            self.optimizer.step()
            if loss_value < 2:
                print("break")
                break
        return loss, gen_tensor

    def snapshot_parameters(self):
        """Return ``{name: detached copy}`` for every parameter of the module."""
        return {name: p.detach().clone() for name, p in self.named_parameters()}

    def changed_parameters(self, snapshot):
        """Return the names of parameters whose values differ from ``snapshot``.

        Typical use, to verify that training really updates the weights::

            before = gen.snapshot_parameters()
            # ... loss.backward(); gen.optimizer.step() ...
            print(gen.changed_parameters(before))

        An empty list means nothing moved; then look at ``p.grad`` for each
        parameter -- ``None`` means the loss is not connected to it.
        """
        return [
            name
            for name, p in self.named_parameters()
            if name in snapshot and not torch.equal(p.detach(), snapshot[name])
        ]
|