OP | Posted 2023-11-29 12:24:05
I have now moved gen_CTable into forward, but the earlier problem is still not solved.
On a related note, gen_CRow_T is simply the result of flattening gen_CTable, as in the small sketch below.
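A minimal standalone sketch of that relationship (the NAtom value here is illustrative, not from my actual run):

import torch

NAtom = 3  # illustrative value
gen_CTable = torch.zeros(NAtom, NAtom)          # NAtom x NAtom connectivity table
gen_CRow_T = gen_CTable.view(1, NAtom * NAtom)  # the same data flattened to one row

The full code: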
import torch
import torch.nn as nn
import torch.nn.functional as f
from torch.autograd import Variable
import numpy as np
import os
import copy
import random
import pandas
import matplotlib.pyplot as plt
from torch.autograd import Function
from G02_dataset import myDataset
from G01_parameters import G_Parameter, MolPara
# in this attempt, use random noise as the input
class MyLoss2(nn.Module):
    def __init__(self, ref):
        super(MyLoss2, self).__init__()
        self.ref = ref

    def forward(self, gen_CTable):
        loss = ((gen_CTable - self.ref) ** 2).mean()
        return loss
class Generator(nn.Module):
    """ Architecture of the Generator, uses res-blocks """
    def __init__(self, para):
        super().__init__()
        self.activation_func = para.activation_func
        self.loss_func = para.loss_func
        self.mat_shape = para.mat_shape
        self.increment = para.increment
        self.lr = para.lr
        self.model_flag = para.model_flag
        self.Nlayer = para.Nlayer
        self.batch_size = para.batch_size
        self.Nsample = para.Nsample
        self.NAtom = para.NAtom
        self.thresh = para.thresh
        self.ref_CTable_T = torch.tensor(para.ref_CTable_T, dtype=torch.float32, requires_grad=False)
        self.ref_CRow_T = self.ref_CTable_T.view(1, self.NAtom * self.NAtom)
        self.model = self.G_model()
        # no loss function according to the green book
        # self.model.apply(weight_init)
        self.optimiser = para.optimizer(self.parameters(), lr=self.lr)
        self.counter = 0
        self.progress = []

    ##################################################
    def G_model(self):
        model_flag = self.model_flag
        Nlayer = self.Nlayer
        mat_shape = self.mat_shape
        increment = self.increment
        activation_func = self.activation_func
        if model_flag == 0:  # MNIST
            model = nn.Sequential(
                nn.Linear(100, 128, bias=True),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Linear(128, 256, bias=True),
                nn.BatchNorm1d(256, 0.8),  # note: the second positional argument is eps, not momentum
                nn.LeakyReLU(0.2, inplace=True),
                nn.Linear(256, 512, bias=True),
                nn.BatchNorm1d(512, 0.8),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Linear(512, 1024, bias=True),
                nn.BatchNorm1d(1024, 0.8),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Linear(1024, 784, bias=True),
                nn.Tanh()
            )
            print("generator", model)
        # -----------------------------------------------
        if model_flag == 1:  # simple
            inp_dim = mat_shape[0] * mat_shape[1]
            n1 = inp_dim
            n2 = inp_dim + increment
            n3 = inp_dim + increment * 3
            n4 = inp_dim + increment * 2  # currently unused
            model = nn.Sequential(
                nn.Linear(n1, n2),
                activation_func,
                nn.Linear(n2, n3),
                # nn.BatchNorm1d(n3, 0.8),
                activation_func,
                nn.Linear(n3, n2),
                # nn.BatchNorm1d(n2, 0.8),
                activation_func,
                nn.Linear(n2, n1),
                nn.Tanh()
            )
        # ----------------------------------------------
        if model_flag == 2:  # increase then decrease
            module_list = []
            inp_dim = mat_shape[0] * mat_shape[1]
            for i in range(Nlayer):
                module_list.append(nn.Linear(inp_dim, inp_dim + increment))
                module_list.append(activation_func)
                inp_dim += increment
            for i in range(Nlayer):
                module_list.append(nn.Linear(inp_dim, inp_dim - increment))
                module_list.append(activation_func)
                inp_dim -= increment
            model = nn.Sequential(*module_list)
        # -----------------------------------------------------
        if model_flag == 3:  # decrease then increase
            module_list = []
            inp_dim = mat_shape[0] * mat_shape[1]
            for i in range(Nlayer):
                module_list.append(nn.Linear(inp_dim, inp_dim - increment))
                module_list.append(activation_func)
                inp_dim -= increment
            for i in range(Nlayer):
                module_list.append(nn.Linear(inp_dim, inp_dim + increment))
                module_list.append(activation_func)
                inp_dim += increment
            model = nn.Sequential(*module_list)
        return model
    #######################################################
    def forward(self):
        NAtom = self.NAtom
        input_tensor = torch.rand(3 * NAtom)
        out = self.model(input_tensor)
        gen_CTable, gen_CRow_T = Tensor_Connectivity(self.NAtom, out)
        return out, gen_CRow_T
    def train(self):  # note: this shadows nn.Module.train()
        self.optimiser.zero_grad()
        gen_geom, gen_CRow_T = self.forward()
        flag_loss = 3
        if flag_loss == 1:  # use custom loss function (unused branch)
            batch_geom = [gen_geom]  # to maximally reuse existing code
            loss = self.myLoss(batch_geom)
            loss.requires_grad_(True)
        if flag_loss == 2:  # use custom loss class (unused branch)
            gen_CTable = torch.tensor(Tensor_Connectivity(self.NAtom, gen_geom), dtype=torch.float32)
            gen_CTable = gen_CTable.view(1, self.NAtom * self.NAtom)
            loss = MyLoss.apply(gen_CTable, self.ref_CTable_T)
            loss.requires_grad_(True)
        if flag_loss == 3:
            loss_fn = MyLoss2(self.ref_CRow_T)
            loss = loss_fn(gen_CRow_T)
        self.counter += 1
        if self.counter % 10 == 0:
            self.progress.append(loss.item())
        loss.backward()
        self.optimiser.step()
        return loss, gen_geom
    def plot_progress(self):
        df = pandas.DataFrame(self.progress, columns=['loss'])
        # df.plot(ylim=[0, 0.8], marker='.', grid=True, title="generator_loss")
        df.plot(marker='.', grid=True, title="generator_loss")
        plt.savefig("generator_loss.png")
        plt.show()
# -----
def CalcBondLength(geom, i, j):
    ri = geom[i, :]
    rj = geom[j, :]
    rij = ri - rj
    rij = torch.linalg.norm(rij)
    return rij

def CalcAllBondLength(geom):
    NAtom = geom.shape[0]
    res = []
    bond_lengths = {}
    for i in range(NAtom - 1):
        for j in range(i + 1, NAtom):
            rij = CalcBondLength(geom, i, j)
            res.append([[i, j], rij])
            bond_lengths[str([i, j])] = rij
    return res, bond_lengths
def Tensor_Connectivity(NAtom, flat_geom_T, thresh=1.6):
    geom_T = flat_geom_T.view(NAtom, 3)
    bondlengthlist, bond_length_dict = CalcAllBondLength(geom_T)
    ConnectTable = np.zeros((NAtom, NAtom), dtype=int)
    for i in range(NAtom):
        for j in range(NAtom):
            key = str([i, j])
            if key in bond_length_dict:
                value = bond_length_dict[key]
            else:
                continue
            if value <= thresh:
                ConnectTable[i, j] = 1
    ConnectTable = np.reshape(ConnectTable, (1, NAtom * NAtom))
    # NOTE: the table is built in NumPy, so this torch.tensor() call creates a
    # fresh leaf tensor; requires_grad=True does not reconnect it to the graph,
    # and gradients cannot flow back to the generator weights through it.
    gen_CRow_T = torch.tensor(ConnectTable, dtype=torch.float32, requires_grad=True)
    return ConnectTable, gen_CRow_T
if __name__ == "__main__":
    molinfo = MolPara()
    para = G_Parameter(molinfo)
    generator = Generator(para)
    print(molinfo)
    print(para)

    Epoch = 500
    NAtom = molinfo.NAtom
    for i in range(Epoch):
        loss, out = generator.train()
        if i % 50 == 0:
            print("Epoch ", i, "loss ", loss.item())
            print(out.detach())
        if abs(loss.item()) < 0.001:
            print("hurray")
            break
    generator.plot_progress()
    out = generator.forward()
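As an aside: if the stubborn problem is that the loss never backpropagates into the generator (which is what the torch.tensor() call inside Tensor_Connectivity would cause, see the NOTE there), one workaround is a soft connectivity table built entirely from torch ops. This is only a sketch under my own assumptions, not a drop-in replacement: the name soft_connectivity and the sharpness value 10.0 are my own choices, and the hard 0/1 table becomes a smooth approximation.

import torch

def soft_connectivity(flat_geom_T, NAtom, thresh=1.6, sharpness=10.0):
    """Differentiable stand-in for Tensor_Connectivity: every op is a torch op,
    so gradients can flow from the loss back to the generator weights."""
    geom = flat_geom_T.view(NAtom, 3)
    diff = geom.unsqueeze(1) - geom.unsqueeze(0)        # (NAtom, NAtom, 3) pairwise differences
    dist = torch.sqrt((diff ** 2).sum(-1) + 1e-12)      # pairwise distances; eps avoids NaN grads at 0
    soft = torch.sigmoid(sharpness * (thresh - dist))   # ~1 below thresh, ~0 above (soft 0/1)
    mask = torch.triu(torch.ones(NAtom, NAtom), diagonal=1)
    return (soft * mask).view(1, NAtom * NAtom)         # upper triangle only, like the original table

The sigmoid replaces the hard value <= thresh test; as sharpness grows the soft table approaches the hard one, at the cost of steeper (and eventually vanishing) gradients.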