|

楼主 |
发表于 2024-3-25 09:04:25
|
显示全部楼层
下面是我的网络类定义，不知道有没有什么问题。
另外，怎样查看网络参数在训练中有没有被更新？
class Generator(nn.Module):
    """Fully connected generator mapping a ``3 * NAtom`` vector to a same-size vector.

    The network is a plain MLP (no residual connections): for ``model_flag == 2``
    the width grows by ``increment`` per layer for ``Nlayer`` layers, shrinks back
    symmetrically, and ends with a square ``nn.Linear``.

    Args:
        para: configuration object (presumably a namespace/dataclass -- confirm
            with the caller). Attributes read: ``model_flag``, ``Nlayer``,
            ``NAtom``, ``increment``, ``activation_func``, ``optimizer``, ``lr``,
            ``Gflag_network`` and ``configs[0].mean`` / ``configs[0].std``.
            NOTE: ``para`` is stored by reference and ``counter`` / ``progress``
            are written onto it, so the caller's object is mutated.
    """

    def __init__(self, para) -> None:
        super().__init__()

        self.info = para
        # Training bookkeeping lives on the (shared) config object.
        self.info.counter = 0
        self.info.progress = []
        self.model = self.G_model()
        # Created after self.model so self.parameters() includes the MLP weights.
        self.optimizer = para.optimizer(self.parameters(), lr=para.lr)
        # 1 = pass the input through the network, 2 = bypass the network.
        self.flag_network = para.Gflag_network
        self.RandomInput = self.Gen_Random_input()
        # Optional custom init: self.model.apply(weight_init)
        print("Generator model")
        print(self.model)

    def G_model(self) -> nn.Sequential:
        """Build the MLP described by ``self.info``.

        Returns:
            ``nn.Sequential`` whose input and output widths are ``3 * NAtom``.

        Raises:
            ValueError: if ``model_flag`` is not 2 (the only implemented layout).
        """
        info = self.info
        if info.model_flag != 2:
            # Previously this fell through to ``return model`` with ``model``
            # unbound and surfaced as an UnboundLocalError.
            raise ValueError(
                f"Unsupported model_flag {info.model_flag!r}; "
                "only 2 (increase then decrease) is implemented"
            )
        increment = info.increment
        # NOTE: the same activation module instance is inserted after every
        # Linear. That is fine for stateless activations (e.g. nn.ReLU), but a
        # parametric one (e.g. nn.PReLU) would share its weights across layers.
        act = info.activation_func
        width = info.NAtom * 3

        layers = []
        for _ in range(info.Nlayer):  # widen
            layers += [nn.Linear(width, width + increment), act]
            width += increment
        for _ in range(info.Nlayer):  # narrow back to 3 * NAtom
            layers += [nn.Linear(width, width - increment), act]
            width -= increment
        layers.append(nn.Linear(width, width))  # final projection, no activation
        return nn.Sequential(*layers)

    def forward(self, input_tensor: torch.Tensor, flag: int = 1) -> torch.Tensor:
        """Run the generator.

        Args:
            input_tensor: tensor of length ``3 * NAtom`` (the width of the first
                Linear layer).
            flag: 1 to pass through ``self.model``; 2 to return ``input_tensor``
                unchanged (network bypassed).

        Raises:
            ValueError: for any other ``flag`` (previously an UnboundLocalError).
        """
        if flag == 1:
            return self.model(input_tensor)
        if flag == 2:
            return input_tensor
        raise ValueError(f"Unknown flag {flag!r}; expected 1 (network) or 2 (bypass)")

    def Gen_Random_input(self) -> torch.Tensor:
        """Draw a ``3 * NAtom`` Gaussian vector with mean/std from ``configs[0]``.

        The tensor is a plain attribute, not a registered buffer or parameter:
        ``.to(device)`` does not move it, and ``self.optimizer`` does not update it.
        """
        NAtom = self.info.NAtom
        mean = self.info.configs[0].mean
        std = self.info.configs[0].std
        input_tensor = torch.randn(NAtom * 3) * std + mean
        # Presumably needed so loss.backward() still works in bypass mode
        # (flag 2), where the loss does not depend on any parameter -- confirm.
        input_tensor.requires_grad_(True)
        return input_tensor

    def fit(self, thresh1, thresh2, maxit: int = 1000, *,
            stop_loss: float = 2.0, verbose: bool = False):
        """Optimize the network to reduce the bond-length penalty of its output.

        Each iteration does the following:
        - feeds the fixed ``self.RandomInput`` through
          ``forward(..., self.flag_network)``;
        - uses ``Calc_bond_length(gen_tensor, thresh1, thresh2)`` as the loss;
        - takes one optimizer step.

        Every 10th iteration, counted across calls via ``self.info.counter``,
        the loss value is appended to ``self.info.progress``.

        How to check that the weights actually change:
        - Snapshot ``[p.detach().clone() for p in self.parameters()]`` before
          calling and compare them with ``torch.equal`` afterwards.
        - Alternatively, inspect ``p.grad`` after ``backward()``.

        When ``self.flag_network == 2`` the network is bypassed, so no parameter
        receives a gradient and ``optimizer.step()`` leaves the weights unchanged.

        Args:
            thresh1, thresh2: forwarded unchanged to ``Calc_bond_length``
                (presumably lower/upper bond-length bounds -- confirm).
            maxit: maximum number of optimization steps.
            stop_loss: stop early once the loss drops below this value.
            verbose: print the (unchanging) input tensor every iteration.

        Returns:
            ``(loss, gen_tensor)`` from the last iteration run. ``gen_tensor``
            was computed before that iteration's optimizer step.
            ``(None, None)`` if ``maxit <= 0``.
        """
        loss = gen_tensor = None
        input_tensor = self.RandomInput  # the same tensor every iteration
        for _ in range(maxit):
            self.optimizer.zero_grad()
            if verbose:
                print(input_tensor)
            gen_tensor = self.forward(input_tensor, self.flag_network)
            loss = Calc_bond_length(gen_tensor, thresh1, thresh2)
            loss_value = loss.item()

            self.info.counter += 1
            if self.info.counter % 10 == 0:
                self.info.progress.append(loss_value)

            loss.backward()
            self.optimizer.step()
            if loss_value < stop_loss:
                print("break")
                break
        return loss, gen_tensor

    def train(self, thresh1=True, thresh2=None, maxit: int = 1000, *,
              stop_loss: float = 2.0, verbose: bool = False):
        """Dispatch between ``nn.Module.train(mode)`` and the optimization loop.

        ``nn.Module.eval()`` calls ``self.train(False)``, and ``model.train()``
        is the standard way to enter training mode. The previous override
        required ``thresh2`` and broke both with a ``TypeError``.

        - When ``thresh2`` is omitted, this behaves like ``nn.Module.train``,
          with ``thresh1`` used as ``mode``, and returns ``self``.
        - Otherwise it runs :meth:`fit` and returns ``(loss, gen_tensor)`` as
          before.
        """
        if thresh2 is None:
            return super().train(thresh1)
        return self.fit(thresh1, thresh2, maxit, stop_loss=stop_loss, verbose=verbose)
复制代码 |
|