|
马上注册,结交更多好友,享用更多功能^_^
您需要 登录 才可以下载或查看,没有账号?立即注册
x
以下是我的代码,为什么我的punish以及init_weight等函数在循环中没有变化?- import torch
- import torch.nn as nn
- import torch.nn.functional as f
- from torch.autograd import Variable
- import numpy as np
- import os
- import copy
- import random
- import pandas
- import matplotlib.pyplot as plt
- from torch.autograd import Function
- def ConvertTensor(vec,grad_flag=True):
- return torch.tensor(vec,dtype=torch.float32,requires_grad=grad_flag)
- def Punishment(geom,generator,flag=1):
- if flag==1:
- thresh1 = generator.thresh1 #smaller
- thresh2 = generator.thresh2 #larger
- w1 = generator.w1
- w2 = generator.w2
-
- ThreeN = len(geom)
- NAtom = int(ThreeN/3)
- geom = geom.view(NAtom,3)
- bond_length=[]
- for i in range(NAtom):
- for j in range(i+1,NAtom):
- diff = geom[i,:]-geom[j,:]
- diff = torch.linalg.norm(diff)
- bond_length.append(diff)
-
- bond_length = ConvertTensor(bond_length)
- dist_max = torch.max(bond_length)
- dist_min = torch.min(bond_length)
- #if all bond lengths are good
- res1 = 0
- res2 = 0
- #if any bond is too short or too long
- if dist_min < thresh1:
- res1 = -(dist_min - thresh1)
- if dist_max > thresh2:
- res2 = dist_max - thresh2
-
- return w1*res1 + w2*res2
- class Generator(nn.Module):
- """ Architecture of the Generator, uses res-blocks """
- def __init__(self):
- super().__init__()
-
- self.counter = 0
- self.progress = []
- self.NAtom = 12
- self.model = self.G_model()
- self.optimizer = torch.optim.Adam(self.parameters(),lr=0.001)
- self.thresh1 = 1.6
- self.thresh2 = 12.6
- #---------------------------------------------
- #the below are used for punishment function
- #for flag 1
- self.RandomInput = self.Gen_Random_input()
-
- #--------------------------------------------
- #self.model.apply(weight_init)
- print("Generator model")
- print(self.model)
- ###################################################
- def G_model(self):
- NAtom = self.NAtom
- activation_func = nn.Sigmoid()
- inp_dim = NAtom*3
- Normalizer = nn.LayerNorm
- model_flag =1
- #-----------------------------------------------
- if model_flag ==1: #simple
- model = nn.Sequential(
- nn.Linear(inp_dim, inp_dim),
- activation_func
- )
- if model_flag ==101: #simple
- model = nn.Sequential(
- nn.Linear(inp_dim, inp_dim),
- activation_func,
- Normalizer(inp_dim)
- )
- return model
- #-----------------------------------------------
- #######################################################
- def forward(self,input_tensor,weight,bias=0):
- out=self.model(input_tensor)
- out=torch.matmul(out,weight)+bias
- out=out+bias
- return out
- def Gen_Random_input(self):
- NAtom = self.NAtom
- mean = 0
- std = 2
- inp = torch.randn(NAtom*3)
- input_tensor = inp*std + mean
- input_tensor.requires_grad_(True)
- return input_tensor
-
- def train(self,target_flag,maxit=1000):
- thresh1 = self.thresh1
- thresh2 = self.thresh2
- input_tensor = self.Gen_Random_input()
- init_bias = torch.zeros_like(input_tensor)
- init_weight = torch.ones(self.NAtom*3,self.NAtom*3)
- self.w1=1
- self.w2=1
- for i in range(maxit):
-
- #generate fake data
- gen_tensor = self.forward(input_tensor,init_weight,init_bias)
- self.gen_tensor = gen_tensor
- #calculate the punishment of the fake data
- punish = Punishment(gen_tensor.detach(),self,flag=1)
- print(punish)
- #punish = torch.max(punish,torch.zeros_like(punish))
- loss = punish
- #get the network parameters
- params = self.state_dict()
- keys = list(params.keys())
- last_b = copy.deepcopy(params[keys[-1]])
- last_w = copy.deepcopy(params[keys[-2]])
- #do some check
- assert len(last_b) == self.NAtom*3, print("error length of last_b not equal to threeN, len(keys)=",len(last_b))
- assert len(last_w) == self.NAtom*3, print("error length of last_w not equal to threeN, len(keys)=",len(last_b))
- #update the init_bias and init_weight
- init_bias = ConvertTensor(last_b,True)
- init_weight = ConvertTensor(last_w,True)
- #run the backward training
- self.optimizer.zero_grad()
- loss.backward()
- self.optimizer.step()
- if punish > 10:
- init_weight = torch.zeros_like(init_weight)
- #plot stuff
- self.counter += 1
- if self.counter%10==0:
- self.progress.append(loss.item())
- # print(gen_tensor)
- return loss,gen_tensor
- def plot_progress(self):
- df = pandas.DataFrame(self.progress,columns=['loss'])
- # df.plot(ylim=[0,0.8],marker='.',grid=True,title="generator_loss" )
- df.plot(marker='.',grid=True,title="generator_loss" )
- # plt.savefig("generator_loss.png")
- plt.show()
- #-----
- if __name__=="__main__":
- generator = Generator()
- NAtom=12
- init_weight = torch.ones([NAtom*3, NAtom*3])
- geom = generator.forward(generator.RandomInput,weight=init_weight)
- maxit=500
- print("max iteration=",maxit)
- loss,gen_tensor = generator.train(target_flag=1,maxit=maxit)
-
- generator.plot_progress()
复制代码 |
|