import torch
import torch.nn as nn
import torch.nn.functional as f
from torch.autograd import Variable
import numpy as np
import os
import copy
import random
import pandas
import matplotlib.pyplot as plt
from torch.autograd import Function
def ConvertTensor(vec, grad_flag=True):
    """Return a new float32 tensor built from ``vec``.

    Args:
        vec: array-like data (list, tuple, scalar, ndarray) or an existing tensor.
        grad_flag: value for ``requires_grad`` on the returned tensor.

    Returns:
        A fresh float32 leaf tensor. It never shares storage or autograd
        history with ``vec``.
    """
    if isinstance(vec, torch.Tensor):
        # torch.tensor(tensor) produces the same result but emits a UserWarning
        # recommending this explicit detach-and-copy form.
        return vec.detach().to(dtype=torch.float32, copy=True).requires_grad_(grad_flag)
    return torch.tensor(vec, dtype=torch.float32, requires_grad=grad_flag)
class Generator(nn.Module):
    """Generator built from a small fully-connected stack (see ``G_model``).

    The network maps a flat vector of ``3 * NAtom`` values to another vector of
    the same length. ``Punishment`` interprets that vector as ``NAtom`` 3-D
    positions. ``train`` optimises the network so that the pairwise distances
    of its outputs fall inside ``[thresh1, thresh2]``.
    """

    def __init__(self):
        super().__init__()
        self.counter = 0      # total training iterations performed
        self.progress = []    # loss recorded every 10th iteration (for plotting)
        self.NAtom = 12
        self.model = self.G_model()
        self.optimizer = torch.optim.Adam(self.parameters(), lr=0.001)
        self.loss_func = nn.MSELoss()
        # Distance window used by Punishment: pair distances below thresh1 or
        # above thresh2 are penalised.
        self.thresh1 = 1.6
        self.thresh2 = 12.6
        # Penalty weights. They are set here so Punishment works before train()
        # is called; train() resets both to 1.
        self.w1 = 1
        self.w2 = 1
        # Fixed random input used by the punishment setup (original note: "for flag 1").
        self.RandomInput = self.Gen_Random_input()
        print("Generator model")
        print(self.model)

    def G_model(self, model_flag=1):
        """Build the network architecture selected by ``model_flag``.

        Args:
            model_flag: 1 -> Linear, ReLU, Linear, ReLU (the default);
                101 -> Linear, ReLU, LayerNorm.

        Returns:
            An ``nn.Sequential`` with ``3 * NAtom`` input and output features.

        Raises:
            ValueError: if ``model_flag`` names no known architecture.
        """
        inp_dim = self.NAtom * 3
        activation_func = nn.ReLU()
        if model_flag == 1:  # simple
            return nn.Sequential(
                nn.Linear(inp_dim, inp_dim),
                activation_func,
                nn.Linear(inp_dim, inp_dim),
                activation_func,
            )
        if model_flag == 101:  # simple, normalised output
            return nn.Sequential(
                nn.Linear(inp_dim, inp_dim),
                activation_func,
                nn.LayerNorm(inp_dim),
            )
        raise ValueError(f"unknown model_flag {model_flag!r}; expected 1 or 101")

    def forward(self, input_tensor, weight, bias=0):
        """Run the network, then apply the affine map ``out @ weight + bias``.

        Args:
            input_tensor: network input whose last dimension is ``3 * NAtom``.
            weight: matrix multiplied on the right of the network output.
            bias: tensor or scalar added once to the result (default 0).
        """
        out = self.model(input_tensor)
        return torch.matmul(out, weight) + bias

    def Punishment(self, geom):
        """Penalty for atom pairs that are too close together or too far apart.

        ``geom`` is reshaped to ``(NAtom, 3)``. With ``d_min``/``d_max`` the
        smallest/largest pairwise Euclidean distance, the result is::

            w1 * max(thresh1 - d_min, 0) + w2 * max(d_max - thresh2, 0)

        The result is always a 0-d tensor. It stays attached to ``geom``'s
        autograd graph, so it can be back-propagated into whatever produced
        ``geom``.
        """
        coords = geom.reshape(-1, 3)
        # All i<j distances, in the same order as a nested i/j loop.
        bond_length = torch.pdist(coords)
        if bond_length.numel() == 0:
            # Fewer than two atoms: no pair can violate the window.
            return geom.new_zeros(())
        too_short = torch.clamp(self.thresh1 - bond_length.min(), min=0)
        too_long = torch.clamp(bond_length.max() - self.thresh2, min=0)
        return self.w1 * too_short + self.w2 * too_long

    def Gen_Random_input(self):
        """Return a ``3 * NAtom`` vector drawn from N(0, 2**2) that requires grad."""
        mean = 0
        std = 2
        input_tensor = torch.randn(self.NAtom * 3) * std + mean
        input_tensor.requires_grad_(True)
        return input_tensor

    def train(self, target_flag=True, maxit=1000):
        """Optimise the network so its iterated outputs minimise ``Punishment``.

        Each iteration feeds the previous output back through ``forward``. The
        last Linear layer's weight and bias, copied at that point, become the
        extra affine map for the next iteration.

        Args:
            target_flag: training target selector; the loop does not currently
                read it (the script passes 1). A ``bool`` (or no argument) is
                treated as ``nn.Module.train(mode)``, so ``.train()``/``.eval()``
                keep their standard meaning instead of starting a training run.
            maxit: number of training iterations.

        Returns:
            ``(loss, gen_tensor)`` from the last iteration (``loss`` is None when
            ``maxit`` is 0). For a ``bool`` argument, returns ``self``, as
            ``nn.Module.train`` does.
        """
        if isinstance(target_flag, bool):
            return super().train(target_flag)
        threeN = self.NAtom * 3
        input_tensor = self.Gen_Random_input()
        init_bias = torch.zeros_like(input_tensor)
        init_weight = torch.ones(threeN, threeN)
        self.w1 = 1
        self.w2 = 1
        gen_tensor = input_tensor
        loss = None
        for _ in range(maxit):
            self.optimizer.zero_grad()
            # Detach the fed-back input so every iteration builds its own graph;
            # the previous graph was freed by the previous backward().
            gen_tensor = self.forward(gen_tensor.detach(), init_weight, init_bias)
            self.gen_tensor = gen_tensor
            # Keep gen_tensor attached so the penalty's gradient reaches the network.
            punish = self.Punishment(gen_tensor)
            print(punish)
            loss = self.loss_func(punish, torch.zeros_like(punish))
            params = self.state_dict()
            keys = list(params.keys())
            last_b = params[keys[-1]]
            last_w = params[keys[-2]]
            assert len(last_b) == threeN, (
                f"error length of last_b not equal to threeN, len(last_b)={len(last_b)}")
            assert len(last_w) == threeN, (
                f"error length of last_w not equal to threeN, len(last_w)={len(last_w)}")
            # Copy rather than alias, because optimizer.step() updates parameters in place.
            init_bias = last_b.detach().clone()
            init_weight = last_w.detach().clone()
            loss.backward()
            self.optimizer.step()
            self.counter += 1
            if self.counter % 10 == 0:
                self.progress.append(loss.item())
        return loss, gen_tensor

    def plot_progress(self):
        """Plot the loss values recorded every 10th training iteration."""
        df = pandas.DataFrame(self.progress, columns=['loss'])
        df.plot(marker='.', grid=True, title="generator_loss")
        plt.show()
#-----
if __name__ == "__main__":
    generator = Generator()
    # Take the system size from the generator so the weight shape always matches.
    NAtom = generator.NAtom
    init_weight = torch.ones([NAtom * 3, NAtom * 3])
    # Smoke-test one forward pass on the generator's fixed random input.
    geom = generator.forward(generator.RandomInput, weight=init_weight)
    maxit = 500
    print("max iteration=", maxit)
    loss, gen_tensor = generator.train(target_flag=1, maxit=maxit)
    generator.plot_progress()