FishC Forum (鱼C论坛)


[Solved] A neural-network program I found on GitHub, shared here; could anyone help check it for syntax errors?

Posted on 2016-11-30 21:04:31

Last edited by 一只菜鸟飞过来 on 2016-11-30 21:06
  1. import argparse
  2. import json
  3. import time
  4. import datetime
  5. import numpy as np
  6. import code
  7. import socket
  8. import os
  9. import sys
  10. import cPickle as pickle

  11. from imagernn.data_provider import getDataProvider
  12. from imagernn.solver import Solver
  13. from imagernn.imagernn_utils import decodeGenerator, eval_split

  14. def preProBuildWordVocab(sentence_iterator, word_count_threshold):
  15.   # count up all word counts so that we can threshold
  16.   # this shouldnt be too expensive of an operation
  17.   print('preprocessing word counts and creating vocab based on word count threshold %d' % (word_count_threshold))
  18.   t0 = time.time()
  19.   word_counts = {}
  20.   nsents = 0
  21.   for sent in sentence_iterator:
  22.     nsents += 1
  23.     for w in sent['tokens']:
  24.       word_counts[w] = word_counts.get(w, 0) + 1
  25.   vocab = [w for w in word_counts if word_counts[w] >= word_count_threshold]
  26.   print('filtered words from %d to %d in %.2fs' % (len(word_counts),len(vocab),time.time() - t0))

  27.   # with K distinct words:
  28.   # - there are K+1 possible inputs (START token and all the words)
  29.   # - there are K+1 possible outputs (END token and all the words)
  30.   # we use ixtoword to take predicted indices and map them to words for output visualization
  31.   # we use wordtoix to take raw words and get their index in word vector matrix
  32.   ixtoword = {}
  33.   ixtoword[0] = '.'  # period at the end of the sentence. make first dimension be end token
  34.   wordtoix = {}
  35.   wordtoix['#START#'] = 0 # make first vector be the start token
  36.   ix = 1
  37.   for w in vocab:
  38.     wordtoix[w] = ix
  39.     ixtoword[ix] = w
  40.     ix += 1

  41.   # compute bias vector, which is related to the log probability of the distribution
  42.   # of the labels (words) and how often they occur. We will use this vector to initialize
  43.   # the decoder weights, so that the loss function doesnt show a huge increase in performance
  44.   # very quickly (which is just the network learning this anyway, for the most part). This makes
  45.   # the visualizations of the cost function nicer because it doesn't look like a hockey stick.
  46.   # for example on Flickr8K, doing this brings down initial perplexity from ~2500 to ~170.
  47.   word_counts['.'] = nsents
  48.   bias_init_vector = np.array([1.0*word_counts[ixtoword[i]] for i in ixtoword])
  49.   bias_init_vector /= np.sum(bias_init_vector) # normalize to frequencies
  50.   bias_init_vector = np.log(bias_init_vector)
  51.   bias_init_vector -= np.max(bias_init_vector) # shift to nice numeric range
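        # assuming the decoder weights start small, the initial scores are ~ bd, and
        # softmax(bias_init_vector) exactly reproduces the word frequencies computed
        # above (the constant shift on the previous line cancels inside the softmax)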
  52.   return wordtoix, ixtoword, bias_init_vector
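  # (illustrative sketch, not part of the shared script) a tiny sanity check of
  # preProBuildWordVocab on two hand-made "sentences"; only words reaching the
  # threshold survive, and index 0 is reserved for the START/END token:
  #   toy = [{'tokens': ['a', 'cat']}, {'tokens': ['a', 'dog']}]
  #   w2i, i2w, b = preProBuildWordVocab(toy, word_count_threshold=2)
  #   # w2i == {'#START#': 0, 'a': 1}; i2w == {0: '.', 1: 'a'}; b.shape == (2,)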

  53. def RNNGenCost(batch, model, params, misc):
  54.   """ cost function, returns cost and gradients for model """
  55.   regc = params['regc'] # regularization cost
  56.   BatchGenerator = decodeGenerator(params)
  57.   wordtoix = misc['wordtoix']

  58.   # forward the RNN on each image sentence pair
  59.   # the generator returns a list of matrices that have word probabilities
  60.   # and a list of cache objects that will be needed for backprop
  61.   Ys, gen_caches = BatchGenerator.forward(batch, model, params, misc, predict_mode = False)

  62.   # compute softmax costs for all generated sentences, and the gradients on top
  63.   loss_cost = 0.0
  64.   dYs = []
  65.   logppl = 0.0
  66.   logppln = 0
  67.   for i,pair in enumerate(batch):
  68.     img = pair['image']
  69.     # ground truth indices for this sentence we expect to see
  70.     gtix = [ wordtoix[w] for w in pair['sentence']['tokens'] if w in wordtoix ]
  71.     gtix.append(0) # don't forget END token must be predicted in the end!
  72.     # fetch the predicted probabilities, as rows
  73.     Y = Ys[i]
  74.     maxes = np.amax(Y, axis=1, keepdims=True)
  75.     e = np.exp(Y - maxes) # for numerical stability shift into good numerical range
  76.     P = e / np.sum(e, axis=1, keepdims=True)
  77.     loss_cost += - np.sum(np.log(1e-20 + P[range(len(gtix)),gtix])) # note: add smoothing to not get infs
  78.     logppl += - np.sum(np.log2(1e-20 + P[range(len(gtix)),gtix])) # also accumulate log2 perplexities
  79.     logppln += len(gtix)

  80.     # lets be clever and optimize for speed here to derive the gradient in place quickly
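        # for softmax + cross-entropy the gradient w.r.t. the unnormalized scores is
        # dL/dY = P - onehot(ground truth), so subtracting 1 at each gtix index below
        # turns the probability matrix P into the gradient, row by row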
  81.     for iy,y in enumerate(gtix):
  82.       P[iy,y] -= 1 # softmax derivatives are pretty simple
  83.     dYs.append(P)

  84.   # backprop the RNN
  85.   grads = BatchGenerator.backward(dYs, gen_caches)

  86.   # add L2 regularization cost and gradients
  87.   reg_cost = 0.0
  88.   if regc > 0:   
  89.     for p in misc['regularize']:
  90.       mat = model[p]
  91.       reg_cost += 0.5 * regc * np.sum(mat * mat)
  92.       grads[p] += regc * mat

  93.   # normalize the cost and gradient by the batch size
  94.   batch_size = len(batch)
  95.   reg_cost /= batch_size
  96.   loss_cost /= batch_size
  97.   for k in grads: grads[k] /= batch_size

  98.   # return output in json
  99.   out = {}
  100.   out['cost'] = {'reg_cost' : reg_cost, 'loss_cost' : loss_cost, 'total_cost' : loss_cost + reg_cost}
  101.   out['grad'] = grads
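      # ppl2 below is 2 ** (total negative log2-probability / number of predicted words),
      # i.e. the standard per-word perplexity of this batch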
  102.   out['stats'] = { 'ppl2' : 2 ** (logppl / logppln)}
  103.   return out

  104. def main(params):
  105.   batch_size = params['batch_size']
  106.   dataset = params['dataset']
  107.   word_count_threshold = params['word_count_threshold']
  108.   do_grad_check = params['do_grad_check']
  109.   max_epochs = params['max_epochs']
  110.   host = socket.gethostname() # get computer hostname

  111.   # fetch the data provider
  112.   dp = getDataProvider(dataset)

  113.   misc = {} # stores various misc items that need to be passed around the framework

  114.   # go over all training sentences and find the vocabulary we want to use, i.e. the words that occur
  115.   # at least word_count_threshold number of times
  116.   misc['wordtoix'], misc['ixtoword'], bias_init_vector = preProBuildWordVocab(dp.iterSentences('train'), word_count_threshold)

  117.   # delegate the initialization of the model to the Generator class
  118.   BatchGenerator = decodeGenerator(params)
  119.   init_struct = BatchGenerator.init(params, misc)
  120.   model, misc['update'], misc['regularize'] = (init_struct['model'], init_struct['update'], init_struct['regularize'])

  121.   # force overwrite here. This is a bit of a hack, not happy about it
  122.   model['bd'] = bias_init_vector.reshape(1, bias_init_vector.size)

  123.   print('model init done.')
  124.   print('model has keys: ' + ', '.join(model.keys()))
  125.   print('updating: ' + ', '.join( '%s [%dx%d]' % (k, model[k].shape[0], model[k].shape[1]) for k in misc['update']))
  126.   print('updating: ' + ', '.join( '%s [%dx%d]' % (k, model[k].shape[0], model[k].shape[1]) for k in misc['regularize']))
  127.   print('number of learnable parameters total: %d' % (sum(model[k].shape[0] * model[k].shape[1] for k in misc['update']))
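           # note: this print is missing one closing ')', so the parser keeps reading
           # into the next statement and reports "invalid syntax" on the if below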

  128.   if params.get('init_model_from',''):
  129.         # load checkpoint
  130.         checkpoint = pickle.load(open(params['init_model_from'], 'rb'))
  131.         model = checkpoint['model'] # overwrite the model
  132.         
  133.   # initialize the Solver and the cost function
  134.   solver = Solver()
  135.   def costfun(batch, model):
  136.     # wrap the cost function to abstract some things away from the Solver
  137.     return RNNGenCost(batch, model, params, misc)

  138.   # calculate how many iterations we need
  139.   num_sentences_total = dp.getSplitSize('train', ofwhat = 'sentences')
  140.   num_iters_one_epoch = num_sentences_total / batch_size
  141.   max_iters = max_epochs * num_iters_one_epoch
  142.   eval_period_in_epochs = params['eval_period']
  143.   eval_period_in_iters = max(1, int(num_iters_one_epoch * eval_period_in_epochs))
  144.   abort = False
  145.   top_val_ppl2 = -1
  146.   smooth_train_ppl2 = len(misc['ixtoword']) # initially size of dictionary of confusion
  147.   val_ppl2 = len(misc['ixtoword'])
  148.   last_status_write_time = 0 # for writing worker job status reports
  149.   json_worker_status = {}
  150.   json_worker_status['params'] = params
  151.   json_worker_status['history'] = []
  152.   for it in xrange(max_iters):
  153.     if abort: break
  154.     t0 = time.time()
  155.     # fetch a batch of data
  156.     batch = [dp.sampleImageSentencePair() for i in xrange(batch_size)]
  157.     # evaluate cost, gradient and perform parameter update
  158.     step_struct = solver.step(batch, model, costfun, **params)
  159.     cost = step_struct['cost']
  160.     dt = time.time() - t0

  161.     # print training statistics
  162.     train_ppl2 = step_struct['stats']['ppl2']
  163.     smooth_train_ppl2 = 0.99 * smooth_train_ppl2 + 0.01 * train_ppl2 # smooth exponentially decaying moving average
  164.     if it == 0: smooth_train_ppl2 = train_ppl2 # start out where we start out
  165.     epoch = it * 1.0 / num_iters_one_epoch
  166.     print '%d/%d batch done in %.3fs. at epoch %.2f. loss cost = %f, reg cost = %f, ppl2 = %.2f (smooth %.2f)' \
  167.           % (it, max_iters, dt, epoch, cost['loss_cost'], cost['reg_cost'], \
  168.              train_ppl2, smooth_train_ppl2)

  169.     # perform gradient check if desired, with a bit of a burnin time (10 iterations)
  170.     if it == 10 and do_grad_check:
  171.       print 'disabling dropout for gradient check...'
  172.       params['drop_prob_encoder'] = 0
  173.       params['drop_prob_decoder'] = 0
  174.       solver.gradCheck(batch, model, costfun)
  175.       print 'done gradcheck, exiting.'
  176.       sys.exit() # hmmm. probably should exit here

  177.     # detect if loss is exploding and kill the job if so
  178.     total_cost = cost['total_cost']
  179.     if it == 0:
  180.       total_cost0 = total_cost # store this initial cost
  181.     if total_cost > total_cost0 * 2:
  182.       print 'Aborting, cost seems to be exploding. Run gradcheck? Lower the learning rate?'
  183.       abort = True # set the abort flag, we'll break out

  184.     # logging: write JSON files for visual inspection of the training
  185.     tnow = time.time()
  186.     if tnow > last_status_write_time + 60*1: # every now and then lets write a report
  187.       last_status_write_time = tnow
  188.       jstatus = {}
  189.       jstatus['time'] = datetime.datetime.now().isoformat()
  190.       jstatus['iter'] = (it, max_iters)
  191.       jstatus['epoch'] = (epoch, max_epochs)
  192.       jstatus['time_per_batch'] = dt
  193.       jstatus['smooth_train_ppl2'] = smooth_train_ppl2
  194.       jstatus['val_ppl2'] = val_ppl2 # just write the last available one
  195.       jstatus['train_ppl2'] = train_ppl2
  196.       json_worker_status['history'].append(jstatus)
  197.       status_file = os.path.join(params['worker_status_output_directory'], host + '_status.json')
  198.       try:
  199.         json.dump(json_worker_status, open(status_file, 'w'))
  200.       except Exception, e: # todo be more clever here
  201.         print 'tried to write worker status into %s but got error:' % (status_file, )
  202.         print e

  203.     # perform perplexity evaluation on the validation set and save a model checkpoint if it's good
  204.     is_last_iter = (it+1) == max_iters
  205.     if (((it+1) % eval_period_in_iters) == 0 and it < max_iters - 5) or is_last_iter:
  206.       val_ppl2 = eval_split('val', dp, model, params, misc) # perform the evaluation on VAL set
  207.       print 'validation perplexity = %f' % (val_ppl2, )
  208.       
  209.       # abort training if the perplexity is no good
  210.       min_ppl_or_abort = params['min_ppl_or_abort']
  211.       if val_ppl2 > min_ppl_or_abort and min_ppl_or_abort > 0:
  212.         print 'aborting job because validation perplexity %f < %f' % (val_ppl2, min_ppl_or_abort)
  213.         abort = True # abort the job

  214.       write_checkpoint_ppl_threshold = params['write_checkpoint_ppl_threshold']
  215.       if val_ppl2 < top_val_ppl2 or top_val_ppl2 < 0:
  216.         if val_ppl2 < write_checkpoint_ppl_threshold or write_checkpoint_ppl_threshold < 0:
  217.           # if we beat a previous record or if this is the first time
  218.           # AND we also beat the user-defined threshold or it doesnt exist
  219.           top_val_ppl2 = val_ppl2
  220.           filename = 'model_checkpoint_%s_%s_%s_%.2f.p' % (dataset, host, params['fappend'], val_ppl2)
  221.           filepath = os.path.join(params['checkpoint_output_directory'], filename)
  222.           checkpoint = {}
  223.           checkpoint['it'] = it
  224.           checkpoint['epoch'] = epoch
  225.           checkpoint['model'] = model
  226.           checkpoint['params'] = params
  227.           checkpoint['perplexity'] = val_ppl2
  228.           checkpoint['wordtoix'] = misc['wordtoix']
  229.           checkpoint['ixtoword'] = misc['ixtoword']
  230.           try:
  231.             pickle.dump(checkpoint, open(filepath, "wb"))
  232.             print 'saved checkpoint in %s' % (filepath, )
  233.           except Exception, e: # todo be more clever here
  234.             print 'tried to write checkpoint into %s but got error: ' % (filepath, )
  235.             print e


  236. if __name__ == "__main__":

  237.   parser = argparse.ArgumentParser()

  238.   # global setup settings, and checkpoints
  239.   parser.add_argument('-d', '--dataset', dest='dataset', default='flickr8k', help='dataset: flickr8k/flickr30k')
  240.   parser.add_argument('-a', '--do_grad_check', dest='do_grad_check', type=int, default=0, help='perform gradcheck? program will block for visual inspection and will need manual user input')
  241.   parser.add_argument('--fappend', dest='fappend', type=str, default='baseline', help='append this string to checkpoint filenames')
  242.   parser.add_argument('-o', '--checkpoint_output_directory', dest='checkpoint_output_directory', type=str, default='cv/', help='output directory to write checkpoints to')
  243.   parser.add_argument('--worker_status_output_directory', dest='worker_status_output_directory', type=str, default='status/', help='directory to write worker status JSON blobs to')
  244.   parser.add_argument('--write_checkpoint_ppl_threshold', dest='write_checkpoint_ppl_threshold', type=float, default=-1, help='ppl threshold above which we dont bother writing a checkpoint to save space')
  245.   parser.add_argument('--init_model_from', dest='init_model_from', type=str, default='', help='initialize the model parameters from some specific checkpoint?')
  246.   
  247.   # model parameters
  248.   parser.add_argument('--generator', dest='generator', type=str, default='lstm', help='generator to use')
  249.   parser.add_argument('--image_encoding_size', dest='image_encoding_size', type=int, default=256, help='size of the image encoding')
  250.   parser.add_argument('--word_encoding_size', dest='word_encoding_size', type=int, default=256, help='size of word encoding')
  251.   parser.add_argument('--hidden_size', dest='hidden_size', type=int, default=256, help='size of hidden layer in generator RNNs')
  252.   # lstm-specific params
  253.   parser.add_argument('--tanhC_version', dest='tanhC_version', type=int, default=0, help='use tanh version of LSTM?')
  254.   # rnn-specific params
  255.   parser.add_argument('--rnn_relu_encoders', dest='rnn_relu_encoders', type=int, default=0, help='relu encoders before going to RNN?')
  256.   parser.add_argument('--rnn_feed_once', dest='rnn_feed_once', type=int, default=0, help='feed image to the rnn only single time?')

  257.   # optimization parameters
  258.   parser.add_argument('-c', '--regc', dest='regc', type=float, default=1e-8, help='regularization strength')
  259.   parser.add_argument('-m', '--max_epochs', dest='max_epochs', type=int, default=50, help='number of epochs to train for')
  260.   parser.add_argument('--solver', dest='solver', type=str, default='rmsprop', help='solver type: vanilla/adagrad/adadelta/rmsprop')
  261.   parser.add_argument('--momentum', dest='momentum', type=float, default=0.0, help='momentum for vanilla sgd')
  262.   parser.add_argument('--decay_rate', dest='decay_rate', type=float, default=0.999, help='decay rate for adadelta/rmsprop')
  263.   parser.add_argument('--smooth_eps', dest='smooth_eps', type=float, default=1e-8, help='epsilon smoothing for rmsprop/adagrad/adadelta')
  264.   parser.add_argument('-l', '--learning_rate', dest='learning_rate', type=float, default=1e-3, help='solver learning rate')
  265.   parser.add_argument('-b', '--batch_size', dest='batch_size', type=int, default=100, help='batch size')
  266.   parser.add_argument('--grad_clip', dest='grad_clip', type=float, default=5, help='clip gradients (normalized by batch size)? elementwise. if positive, at what threshold?')
  267.   parser.add_argument('--drop_prob_encoder', dest='drop_prob_encoder', type=float, default=0.5, help='what dropout to apply right after the encoder to an RNN/LSTM')
  268.   parser.add_argument('--drop_prob_decoder', dest='drop_prob_decoder', type=float, default=0.5, help='what dropout to apply right before the decoder in an RNN/LSTM')

  269.   # data preprocessing parameters
  270.   parser.add_argument('--word_count_threshold', dest='word_count_threshold', type=int, default=5, help='if a word occurs less than this number of times in training data, it is discarded')

  271.   # evaluation parameters
  272.   parser.add_argument('-p', '--eval_period', dest='eval_period', type=float, default=1.0, help='in units of epochs, how often do we evaluate on val set?')
  273.   parser.add_argument('--eval_batch_size', dest='eval_batch_size', type=int, default=100, help='for faster validation performance evaluation, what batch size to use on val img/sentences?')
  274.   parser.add_argument('--eval_max_images', dest='eval_max_images', type=int, default=-1, help='for efficiency we can use a smaller number of images to get validation error')
  275.   parser.add_argument('--min_ppl_or_abort', dest='min_ppl_or_abort', type=float , default=-1, help='if validation perplexity is below this threshold the job will abort')

  276.   args = parser.parse_args()
  277.   params = vars(args) # convert to ordinary dict
  278.   print 'parsed parameters:'
  279.   print json.dumps(params, indent = 2)
  280.   main(params)
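For reference, assuming the script is saved as driver.py (that file name is my guess; it is not given in the post), the argparse section at the bottom means it would be launched roughly like this, with every flag taken from the add_argument calls above:

    python driver.py --dataset flickr8k --batch_size 100 --max_epochs 50 --fappend baseline

Those values simply restate the parser defaults, so running the script with no flags at all would be equivalent.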


When the program reaches the if statement on line 148, it throws an error: invalid syntax.

SixPy | Posted on 2016-12-1 09:07:35 | Accepted answer
The statement right before it has unmatched parentheses.
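In CPython an unclosed '(' makes the parser keep scanning into the next statement, so the SyntaxError is reported on the following if line rather than on the print itself. A minimal sketch of the fix, using a small stand-in dict in place of the real model and misc['update'] from the script so the line runs on its own:

    import numpy as np
    model = {'We': np.zeros((3, 4)), 'bd': np.zeros((1, 5))}  # stand-in parameters
    update = ['We', 'bd']
    # posted line 127, one ')' short:
    #   print('number of learnable parameters total: %d' % (sum(model[k].shape[0] * model[k].shape[1] for k in misc['update']))
    # balanced version:
    print('number of learnable parameters total: %d' % (sum(model[k].shape[0] * model[k].shape[1] for k in update)))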

一只菜鸟飞过来 (OP) | Posted on 2016-12-1 11:03:27
SixPy posted on 2016-12-1 09:07:
The statement right before it has unmatched parentheses.

You're right. I kept staring at that if statement the whole time and overlooked the line above it. Can you tell what this program is actually for?

Posted on 2017-7-7 14:46:46
SixPy posted on 2016-12-1 09:07:
The statement right before it has unmatched parentheses.

No wonder that avatar looked so familiar; it took me a month to realize it's 猫梓酱.
