鱼C论坛

 找回密码
 立即注册
楼主: xjtu_wong

[交流] JamesWONG的打卡计划

[复制链接]
 楼主| 发表于 2019-3-13 23:39:40 | 显示全部楼层
已经被魔法方法搞晕的小王
1EC029D0-E4A5-4E02-8C65-4A0ED812AE02.png
想知道小甲鱼最近在做啥?请访问 -> ilovefishc.com
回复 支持 反对

使用道具 举报

 楼主| 发表于 2019-3-14 23:27:11 | 显示全部楼层
关于写代码,还是背一些方法比较好用。。。晚安
F2BD686B-6ACB-4E8F-B291-0A60558F3026.png
想知道小甲鱼最近在做啥?请访问 -> ilovefishc.com
回复 支持 反对

使用道具 举报

 楼主| 发表于 2019-3-18 09:40:18 | 显示全部楼层
学院运动会小拿了个季军🥉哈哈哈,春天来了要多运动
E2C60770-783D-47A2-B7DD-7AA9FBE178E2.jpeg
6A8AAAE3-9E6E-4C2C-A97A-378EFE849132.png
想知道小甲鱼最近在做啥?请访问 -> ilovefishc.com
回复 支持 反对

使用道具 举报

 楼主| 发表于 2019-3-19 23:08:37 | 显示全部楼层
7683B233-CFB9-4E6E-A444-ADEAECAACE45.jpeg
想知道小甲鱼最近在做啥?请访问 -> ilovefishc.com
回复 支持 反对

使用道具 举报

 楼主| 发表于 2019-3-21 21:25:46 | 显示全部楼层
0321_2.jpg 0321_1.PNG
想知道小甲鱼最近在做啥?请访问 -> ilovefishc.com
回复 支持 反对

使用道具 举报

 楼主| 发表于 2019-3-24 23:29:48 | 显示全部楼层
rl.bmp nakagami.bmp lognormal.bmp
今天的仿真成果
想知道小甲鱼最近在做啥?请访问 -> ilovefishc.com
回复 支持 反对

使用道具 举报

 楼主| 发表于 2019-3-26 21:44:19 | 显示全部楼层
  1. %{
  2. 作者:JamesWONG
  3. %}
  4. clc
  5. clear all;
  6. %图像的横坐标为平均接收信噪比(???=10log10?),纵坐标为归一化容量(C/B)
  7. for each_SNR_dB = linspace(5,30,11)
  8.     index = ((each_SNR_dB-5)/2.5+1);
  9.     AWGN_lognormal(index)=log2(1+10^(0.1*(each_SNR_dB+8^2*log(10)/20)));
  10.     RX_CSI_lognormal_integrand=(@(x)log2(1+x).*10./((2*pi)^0.5.*x*8*log(10)).*exp(-(10*log10(x)-each_SNR_dB).^2/(2*(8)^2)));
  11.     RX_CSI_lognormal(index)=quadgk(RX_CSI_lognormal_integrand,0,inf);
  12.     TXnRX_CSI_lognormal_threshold_integrand=@(y) (quadgk((@(x)(1./y-1./x).*10./((2*pi)^0.5.*x*8*log(10)).*exp(-(10*log10(x)-each_SNR_dB).^2/(2*(8)^2))),y,inf)-1);
  13.     TXnRX_CSI_lognormal_threshold(index)=fzero(TXnRX_CSI_lognormal_threshold_integrand,1);  %求收发两端都已知CSI时的中断门限
  14.     TXnRX_CSI_lognormal_integrand = (@(x) log2(x/TXnRX_CSI_lognormal_threshold(index)).*10./((2*pi)^0.5.*x*8*log(10)).*exp(-(10*log10(x)-each_SNR_dB).^2/(2*(8)^2)));
  15.     TXnRX_CSI_lognormal(index) = quadgk(TXnRX_CSI_lognormal_integrand,TXnRX_CSI_lognormal_threshold(index),inf);
  16.     Zero_outage_integrand=@(y) (quadgk((@(x)(y./x).*10./((2*pi)^0.5.*x*8*log(10)).*exp(-(10*log10(x)-each_SNR_dB).^2/(2*(8)^2))),0,inf)-1);
  17.     SNR_Zero_outage(index)=fzero(Zero_outage_integrand,1);  %求零中断容量恒定信噪比
  18.     Zero_outage_logmormal(index)=log2(1+SNR_Zero_outage(index));
  19.     for k=(each_SNR_dB-10):(each_SNR_dB+10)  %最大中断容量时对中断门限进行遍历
  20.         Max_outage_integrand=@(y) (quadgk((@(x)(y./x).*10./((2*pi)^0.5.*x*8*log(10)).*exp(-(10*log10(x)-each_SNR_dB).^2/(2*(8)^2))),10^(0.1*k),inf)-1);
  21.         SNR_Max_outage(51-each_SNR_dB+k)=fzero(Max_outage_integrand,1);  %最大中断容量时求恒定信噪比
  22.         outrage_capacity(51-each_SNR_dB+k)=log2(1+SNR_Max_outage(51-each_SNR_dB+k))*(quadgk((@(x) 10./((2*pi)^0.5.*x*8*log(10)).*exp(-(10*log10(x)-each_SNR_dB).^2/(2*(8)^2))),10^(0.1*k),inf));
  23.     end
  24.     Maximum_outage_lognormal(index)=max(outrage_capacity); %取最大的中断容量
  25.     [ar_r,rcr_r,trcr_r,zon_r,mon_r] = Nakagami_fade(1,each_SNR_dB);
  26.     AWGN_Rayleigh(index) = ar_r;
  27.     RX_CSI_Rayleigh(index) = rcr_r;
  28.     TXnRX_CSI_Rayleigh(index) = trcr_r;
  29.     Zero_outage_Rayleigh(index) = zon_r;
  30.     Maximum_outage_Rayleigh(index) = mon_r;
  31.     [ar,rcr,trcr,zon,mon] = Nakagami_fade(2,each_SNR_dB);
  32.     AWGN_Nakagami(index) = ar;
  33.     RX_CSI_Nakagami(index) = rcr;
  34.     TXnRX_CSI_Nakagami(index) = trcr;
  35.     Zero_outage_Nakagami(index) = zon;
  36.     Maximum_outage_Nakagami(index) = mon;
  37. end
  38. figure(1)
  39. E_SNR_dB=linspace(5,30,11);
  40. plot(E_SNR_dB,AWGN_lognormal,'-*b',E_SNR_dB,RX_CSI_lognormal,'-+r',E_SNR_dB,TXnRX_CSI_lognormal,'-og',E_SNR_dB,Zero_outage_logmormal,'-.*y',E_SNR_dB,Maximum_outage_lognormal,'-ks');
  41. hold on
  42. legend('AWGN信道容量','RX CSI的香农容量','TX/RX CSI的香农容量','零中断容量','最大中断容量');
  43. xlabel('平均接收信噪比(dB)');
  44. ylabel('C/B (bit/s/Hz)');
  45. title('对数正态衰落下的信道容量');
  46. figure(2)
  47. E_SNR_dB=linspace(5,30,11);
  48. plot(E_SNR_dB,AWGN_Rayleigh,'-*b',E_SNR_dB,RX_CSI_Rayleigh,'-+r',E_SNR_dB,TXnRX_CSI_Rayleigh,'-og',E_SNR_dB,Zero_outage_Rayleigh,'-.*y',E_SNR_dB,Maximum_outage_Rayleigh,'-ks');hold on
  49. legend('AWGN信道容量','RX CSI的香农容量','TX/RX CSI的香农容量','零中断容量','最大中断容量');
  50. xlabel('平均接收信噪比(dB)');
  51. ylabel('C/B (bit/s/Hz)');
  52. title('瑞利衰落下的信道容量');
  53. figure(3)
  54. E_SNR_dB=linspace(5,30,11);
  55. plot(E_SNR_dB,AWGN_Nakagami,'-*b',E_SNR_dB,RX_CSI_Nakagami,'-+r',E_SNR_dB,TXnRX_CSI_Nakagami,'-og',E_SNR_dB,Zero_outage_Nakagami,'-.*y',E_SNR_dB,Maximum_outage_Nakagami,'-ks');hold on
  56. legend('AWGN信道容量','RX CSI的香农容量','TX/RX CSI的香农容量','零中断容量','最大中断容量');
  57. xlabel('平均接收信噪比(dB)');
  58. ylabel('C/B (bit/s/Hz)');
  59. title('Nakagami衰落下的信道容量');
复制代码

MATLAB程序,仿真无线信道容量对比(Goldsmith的教材)
  1. %{
  2. Nakagami_fade函数功能:给定Nakagami衰落的参数m以及平均接受信噪比,
  3.                       返回五种信道下的归一化容量C/B
  4. 示例:
  5. >>
  6. [ar_r,rcr_r,trcr_r,zon_r,mon_r] = Nakagami_fade(1,5)
  7. 得到:
  8. ar_r =2.0574
  9. rcr_r =1.7160
  10. trcr_r =1.8451
  11. zon_r = 0.0531
  12. mon_r =1.6577
  13. %}
  14. function[AWGN_Rayleigh,RX_CSI_Rayleigh,TXnRX_CSI_Rayleigh,Zero_outage_Nakagami,Maximum_outage_Nakagami] = Nakagami_fade(m,E_SNR_dB)
  15.         AWGN_Rayleigh=log2(1+10^(0.1*E_SNR_dB));
  16.         RX_CSI_Rayleigh_integrand=(@(x)log2(1+x).*(m ./10^(0.1*E_SNR_dB)).^m .*x.^(m-1) /gamma(m) .*exp(-(x.*m ./10^(0.1*E_SNR_dB))));
  17.         RX_CSI_Rayleigh=quadgk(RX_CSI_Rayleigh_integrand,0,inf);  
  18.         TXnRX_CSI_Rayleigh_threshold_integrand=@(y) (quadgk((@(x)(1./y-1./x).*(m ./10^(0.1*E_SNR_dB)).^m .*x.^(m-1) /gamma(m) .*exp(-(x.*m ./10^(0.1*E_SNR_dB)))),y,inf)-1);
  19.         TXnRX_CSI_Rayleigh_threshold=fzero(TXnRX_CSI_Rayleigh_threshold_integrand,0.55);  %求收发两端都已知CSI时的中断门限,通过迭代初值所得
  20.         TXnRX_CSI_Rayleigh_integrand=(@(x) log2(x/TXnRX_CSI_Rayleigh_threshold).*(m ./10^(0.1*E_SNR_dB)).^m .*x.^(m-1) /gamma(m) .*exp(-(x.*m ./10^(0.1*E_SNR_dB))));
  21.         TXnRX_CSI_Rayleigh=quadgk(TXnRX_CSI_Rayleigh_integrand,TXnRX_CSI_Rayleigh_threshold,inf);  
  22.         Zero_outage_integrand=@(y) (quadgk((@(x)(y./x).*(m ./10^(0.1*E_SNR_dB)).^m .*x.^(m-1) /gamma(m) .*exp(-(x.*m ./10^(0.1*E_SNR_dB)))),0,inf)-1);
  23.         SNR_Zero_outage=fzero(Zero_outage_integrand,1);   %求零中断容量恒定信噪比
  24.         Zero_outage_Nakagami=log2(1+SNR_Zero_outage);
  25.         for k=(E_SNR_dB-10):(E_SNR_dB+10)  %最大中断容量时对中断门限进行遍历
  26.             Max_outage_integrand=@(y) (quadgk((@(x)(y./x).*(m ./10^(0.1*E_SNR_dB)).^m .*x.^(m-1) /gamma(m) .*exp(-(x.*m ./10^(0.1*E_SNR_dB)))),10^(0.1*k),inf)-1);
  27.             SNR_Max_outage(31-E_SNR_dB+k)=fzero(Max_outage_integrand,1);  %最大中断容量时求恒定信噪比
  28.             max_outage_capacity(31-E_SNR_dB+k)=log2(1+SNR_Max_outage(31-E_SNR_dB+k))*(quadgk((@(x) (m ./10^(0.1*E_SNR_dB)).^m .*x.^(m-1)/gamma(m).*exp(-(x.*m ./10^(0.1*E_SNR_dB)))),10^(0.1*k),inf));
  29.         end
  30.         Maximum_outage_Nakagami=max(max_outage_capacity); %取最大的中断容量
  31. end
复制代码
想知道小甲鱼最近在做啥?请访问 -> ilovefishc.com
回复 支持 反对

使用道具 举报

 楼主| 发表于 2019-3-27 23:06:40 | 显示全部楼层
0327_2.PNG 0327_1.PNG
想知道小甲鱼最近在做啥?请访问 -> ilovefishc.com
回复 支持 反对

使用道具 举报

 楼主| 发表于 2019-4-25 21:32:40 | 显示全部楼层
0425_1.PNG
1、复习随机过程第二章第二节
2、英语打卡
最近某东买书打折。
想知道小甲鱼最近在做啥?请访问 -> ilovefishc.com
回复 支持 反对

使用道具 举报

 楼主| 发表于 2019-6-4 11:10:06 | 显示全部楼层
Eric Matthes课后习题9.7:
  1. #-*-coding:utf-8-*-
  2. #author:JamesW
  3. class User():
  4.     def __init__(self,first_name,last_name,**profile):
  5.         self.first_name = first_name
  6.         self.last_name = last_name
  7.         self.profile = profile
  8.         profile = {}
  9.         profile['first_name'] = self.first_name
  10.         profile['Last_name'] = self.last_name
  11.         for key,value in profile.items():
  12.             profile[key] = value

  13.     def describe_user(self):
  14.         for each in self.profile:
  15.             print(each+':'+self.profile[each])

  16.     def greet_user(self):
  17.         formatted_name = self.first_name + ' ' +self.last_name
  18.         print('Hello,%s!'%formatted_name)

  19. class Admin(User):
  20.     def __init__(self,first_name,last_name,*can_do,**profile):
  21.         super().__init__(first_name,last_name,**profile)
  22.         self.privileges = can_do
  23.     def show_privileges(self):
  24.         print('User %s %s is a Administrater, :'%(self.first_name,self.last_name))
  25.         for each in self.privileges:
  26.             print('He '+each)

  27. if __name__ == '__main__':
  28.     ad1 = Admin('James','Wong','can add post','can delete post',age='22',hometown='cz')
  29.     ad1.show_privileges()
  30.     ad1.describe_user()
  31.     ad1.greet_user()
复制代码
想知道小甲鱼最近在做啥?请访问 -> ilovefishc.com
回复 支持 反对

使用道具 举报

 楼主| 发表于 2019-6-4 11:35:20 | 显示全部楼层
Eric Matthes 习题9.8,将实例用作属性
  1. #-*-coding:utf-8-*-
  2. #author:JamesW
  3. class User():
  4.     def __init__(self,first_name,last_name,**profile):
  5.         self.first_name = first_name
  6.         self.last_name = last_name
  7.         self.profile = profile
  8.         profile = {}
  9.         profile['first_name'] = self.first_name
  10.         profile['Last_name'] = self.last_name
  11.         for key,value in profile.items():
  12.             profile[key] = value

  13.     def describe_user(self):
  14.         for each in self.profile:
  15.             print(each+':'+self.profile[each])

  16.     def greet_user(self):
  17.         formatted_name = self.first_name + ' ' +self.last_name
  18.         print('Hello,%s!'%formatted_name)

  19. class Privileges():
  20.     def __init__(self,*can_do):
  21.         self.privileges = can_do
  22.     def show_privileges(self):
  23.         print('This is a Administrater, :')
  24.         for each in self.privileges:
  25.             print('He '+each)

  26. class Admin(User):
  27.     def __init__(self,first_name,last_name,**profile):
  28.         super().__init__(first_name,last_name,**profile)
  29.         self.ad1 = Privileges('can add post','can delete post')

  30. if __name__ == '__main__':
  31.     ad = Admin('James','Wong',age='22',hometown='cz')
  32.     ad.ad1.show_privileges()
  33.     ad.describe_user()
  34.     ad.greet_user()
复制代码
想知道小甲鱼最近在做啥?请访问 -> ilovefishc.com
回复 支持 反对

使用道具 举报

 楼主| 发表于 2019-6-4 20:11:07 | 显示全部楼层
Json 和 Pickle:
python的pickle模块实现了python的所有数据序列和反序列化。基本上功能使用和JSON模块没有太大区别,方法也同样是dumps/dump和loads/load。cPickle是pickle模块的C语言编译版本相对速度更快。
与JSON不同的是pickle不是用于多种语言间的数据传输,它仅作为python对象的持久化或者python程序间进行互相传输对象的方法,因此它支持了python所有的数据类型。
---------------------
作者:shuyededenghou
来源:CSDN
原文:https://blog.csdn.net/shuyededenghou/article/details/75923353
版权声明:本文为博主原创文章,转载请附上博文链接!
JSON(JavaScript Object Notation) 是一种轻量级的数据交换格式。它基于ECMAScript的一个子集。 JSON采用完全独立于语言的文本格式,但是也使用了类似于C语言家族的习惯(包括C、C++、Java、JavaScript、Perl、Python等)。这些特性使JSON成为理想的数据交换语言。易于人阅读和编写,同时也易于机器解析和生成(一般用于提升网络传输速率)。
JSON在python中分别由list和dict组成。
---------------------
作者:shuyededenghou
来源:CSDN
原文:https://blog.csdn.net/shuyededenghou/article/details/75923353
版权声明:本文为博主原创文章,转载请附上博文链接!
json与pickle模块是将Python中的数据进行序列化,便于存取与传输。

处理文件时,考虑多行数据的存取,如换行符的使用
json序列化时,是以字符串的形式存取,所以对于字典的存取时,关键字或值经过序列化与反序列化后,类型都是字符串,有些表面看起来是数值,其实还是字符串,在使用时一定要注意类型的转换
---------------------
作者:shuyededenghou
来源:CSDN
原文:https://blog.csdn.net/shuyededenghou/article/details/75923353
版权声明:本文为博主原创文章,转载请附上博文链接!


Eric Matthes习题10_11
  1. #-*-coding:utf-8-*-
  2. #author:JamesW
  3. import json
  4. num = input("plz input ur favorite number:")
  5. filename = 'num.json'
  6. with open(filename,'w') as f:
  7.     json.dump(num,f)
  8.     print("file saved!")

  9. with open(filename) as rf:
  10.     fav_num = json.load(rf)
  11.     print("ur favorite number is:"+fav_num)
复制代码
想知道小甲鱼最近在做啥?请访问 -> ilovefishc.com
回复 支持 反对

使用道具 举报

 楼主| 发表于 2019-6-5 09:27:13 | 显示全部楼层
Eric Marrhes 习题10_13
  1. #-*-coding:utf-8-*-
  2. #author:JamesW
  3. import json

  4. def get_stored_uesrname():
  5.     filename = 'username.json'
  6.     try:
  7.         with open(filename) as f_obj:
  8.             username = json.load(f_obj)
  9.     except FileNotFoundError:
  10.         return None
  11.     else:
  12.         return username

  13. def get_new_username():
  14.     username = input("What's ur name?")
  15.     filename = 'username.json'
  16.     with open(filename,'w') as f_obj:
  17.         json.dump(username,f_obj)
  18.     return username

  19. def greet_user():
  20.     """问候用户,指出其姓名"""
  21.     username = get_stored_uesrname()
  22.     if username:
  23.         flag = input('R u %s ?(Y/N)'%username)
  24.         if flag == 'Y':
  25.             print('Welcome back,'+username+'!')
  26.         elif flag == 'N':
  27.             username = get_new_username()
  28.     else:
  29.         username = get_new_username()
  30.         print("We'll remeber you when you come back,"+username+'!')

  31. if __name__ == '__main__':
  32.     greet_user()
复制代码
想知道小甲鱼最近在做啥?请访问 -> ilovefishc.com
回复 支持 反对

使用道具 举报

 楼主| 发表于 2019-8-8 11:25:13 | 显示全部楼层
分享多元线性回归梯度下降法更新权值推导 SGD.jpg
想知道小甲鱼最近在做啥?请访问 -> ilovefishc.com
回复 支持 反对

使用道具 举报

 楼主| 发表于 2019-8-28 10:23:55 | 显示全部楼层
根据奥莱利《深度学习入门》中学习到的神经网络,其中激活函数和MNIST数据集以经预先保存为包,网络的权值与偏置预先设定。分类结果准确率为0.9352.

  1. # coding: utf-8
  2. import sys, os
  3. sys.path.append(os.pardir)  # 为了导入父目录的文件而进行的设定
  4. import numpy as np
  5. import pickle
  6. from dataset.mnist import load_mnist
  7. from common.functions import sigmoid, softmax


  8. def get_data():
  9.     (x_train, t_train), (x_test, t_test) = load_mnist(normalize=True, flatten=True, one_hot_label=False)
  10.     return x_test, t_test #x为数据,t为标签


  11. def init_network():
  12.     with open("sample_weight.pkl", 'rb') as f:
  13.         network = pickle.load(f)
  14.     return network


  15. def predict(network, x):
  16.     W1, W2, W3 = network['W1'], network['W2'], network['W3']
  17.     b1, b2, b3 = network['b1'], network['b2'], network['b3']

  18.     a1 = np.dot(x, W1) + b1 # (1*784)*(784*50)+50
  19.     z1 = sigmoid(a1)
  20.     a2 = np.dot(z1, W2) + b2 # (1*50)*(50*100)+100
  21.     z2 = sigmoid(a2)
  22.     a3 = np.dot(z2, W3) + b3 # (1*100)*(100*10)+10
  23.     y = softmax(a3) # 1*10 (one-hot)

  24.     return y


  25. x, t = get_data() #x保存数据,t保存标签
  26. network = init_network() #导入预存的权值矩阵!难怪有这么高的精度。
  27. print(network.keys()) #['b2', 'W1', 'b1', 'W2', 'W3', 'b3']
  28. accuracy_cnt = 0
  29. for i in range(len(x)):
  30.     y = predict(network, x[i])
  31.     p= np.argmax(y) # 获取概率最高的元素的索引
  32.     if p == t[i]:
  33.         accuracy_cnt += 1

  34. print("Accuracy:" + str(float(accuracy_cnt) / len(x)))
复制代码


偏置: sample_weight.zip (165.56 KB, 下载次数: 0) common.zip (11.96 KB, 下载次数: 0)



想知道小甲鱼最近在做啥?请访问 -> ilovefishc.com
回复 支持 反对

使用道具 举报

 楼主| 发表于 2019-8-28 10:48:30 | 显示全部楼层
引入批处理后,实现将原来的单个训练改进为多组数据同时训练,精度不变的情况下缩短运算时间。
局部晚期统计,可以缩短三分之二的时间。
  1. # -*- coding: utf-8 -*-
  2. """
  3. Created on Wed Aug 28 10:33:32 2019

  4. @author: u
  5. """

  6. # coding: utf-8
  7. import time
  8. import sys, os
  9. sys.path.append(os.pardir)  # 为了导入父目录的文件而进行的设定
  10. import numpy as np
  11. import pickle
  12. from dataset.mnist import load_mnist
  13. from common.functions import sigmoid, softmax

  14. batch_size = 100

  15. def get_data():
  16.     (x_train, t_train), (x_test, t_test) = load_mnist(normalize=True, flatten=True, one_hot_label=False)
  17.     return x_test, t_test #x为数据,t为标签


  18. def init_network():
  19.     with open("sample_weight.pkl", 'rb') as f:
  20.         network = pickle.load(f)
  21.     return network


  22. def predict(network, x):
  23.     W1, W2, W3 = network['W1'], network['W2'], network['W3']
  24.     b1, b2, b3 = network['b1'], network['b2'], network['b3']

  25.     a1 = np.dot(x, W1) + b1 # (1*784)*(784*50)+50
  26.     z1 = sigmoid(a1)
  27.     a2 = np.dot(z1, W2) + b2 # (1*50)*(50*100)+100
  28.     z2 = sigmoid(a2)
  29.     a3 = np.dot(z2, W3) + b3 # (1*100)*(100*10)+10
  30.     y = softmax(a3) # 1*10 (one-hot)

  31.     return y

  32. time_start=time.time()
  33. x, t = get_data() #x保存数据,t保存标签
  34. network = init_network() #导入预存的权值矩阵!难怪有这么高的精度。
  35. print(network.keys()) #['b2', 'W1', 'b1', 'W2', 'W3', 'b3']

  36. accuracy_cnt = 0
  37. for i in range(0, len(x), batch_size):
  38.        x_batch = x[i:i+batch_size]
  39.        #print(x_batch.shape) #(100, 784)
  40.        y_batch = predict(network, x_batch)
  41.        #print(y_batch.shape) #(100,10)
  42.        p = np.argmax(y_batch, axis=1)
  43.        accuracy_cnt += np.sum(p == t[i:i+batch_size])
  44.       
  45. time_end=time.time()
  46. print('totally cost',time_end-time_start)
  47. print("Accuracy:" + str(float(accuracy_cnt) / len(x)))
复制代码
想知道小甲鱼最近在做啥?请访问 -> ilovefishc.com
回复 支持 反对

使用道具 举报

 楼主| 发表于 2019-8-30 10:07:41 | 显示全部楼层
加入SGD后具有学习参数的神经网络。此处计算梯度采用数值计算法,可以看到运行时间较慢:
  1. total time cost 38.359713554382324
复制代码

若采用梯度反向传播则可以提高代码效率。
  1. # coding: utf-8
  2. import sys, os
  3. sys.path.append(os.pardir)  # 为了导入父目录的文件而进行的设定
  4. import numpy as np
  5. import matplotlib.pyplot as plt
  6. import time
  7. from dataset.mnist import load_mnist
  8. from two_layer_net import TwoLayerNet

  9. # 读入数据
  10. (x_train, t_train), (x_test, t_test) = load_mnist(normalize=True, one_hot_label=True)

  11. network = TwoLayerNet(input_size=784, hidden_size=50, output_size=10)

  12. iters_num = 10000  # 适当设定循环的次数
  13. train_size = x_train.shape[0]
  14. batch_size = 100
  15. learning_rate = 0.1

  16. train_loss_list = []
  17. train_acc_list = []
  18. test_acc_list = []

  19. iter_per_epoch = max(train_size / batch_size, 1)
  20. time_start=time.time()
  21. time_before = time_start

  22. for i in range(iters_num):
  23.     batch_mask = np.random.choice(train_size, batch_size)
  24.     x_batch = x_train[batch_mask]
  25.     t_batch = t_train[batch_mask]
  26.    
  27.     # 计算梯度
  28.     #grad = network.numerical_gradient(x_batch, t_batch)
  29.     grad = network.gradient(x_batch, t_batch)
  30.    
  31.     # 更新参数
  32.     for key in ('W1', 'b1', 'W2', 'b2'):
  33.         network.params[key] -= learning_rate * grad[key]
  34.    
  35.     loss = network.loss(x_batch, t_batch)
  36.     train_loss_list.append(loss)
  37.    
  38.     if i % iter_per_epoch == 0:
  39.         train_acc = network.accuracy(x_train, t_train)
  40.         test_acc = network.accuracy(x_test, t_test)
  41.         train_acc_list.append(train_acc)
  42.         test_acc_list.append(test_acc)
  43.         print("train acc, test acc | " + str(train_acc) + ", " + str(test_acc))
  44.         time_end_this = time.time()
  45.         print('this epoch time cost',time_end_this-time_before)
  46.         time_before = time_end_this

  47. time_end = time.time()
  48. print('total time cost',time_end-time_start)
  49. # 绘制图形
  50. markers = {'train': 'o', 'test': 's'}
  51. x = np.arange(len(train_acc_list))
  52. plt.plot(x, train_acc_list, label='train acc')
  53. plt.plot(x, test_acc_list, label='test acc', linestyle='--')
  54. plt.xlabel("epochs")
  55. plt.ylabel("accuracy")
  56. plt.ylim(0, 1.0)
  57. plt.legend(loc='lower right')
  58. plt.show()
复制代码
想知道小甲鱼最近在做啥?请访问 -> ilovefishc.com
回复 支持 反对

使用道具 举报

 楼主| 发表于 2019-8-30 10:10:11 | 显示全部楼层
数值法计算梯度的类:
  1. # coding: utf-8
  2. import numpy as np

  3. def _numerical_gradient_1d(f, x):
  4.     h = 1e-4 # 0.0001
  5.     grad = np.zeros_like(x)
  6.    
  7.     for idx in range(x.size):
  8.         tmp_val = x[idx]
  9.         x[idx] = float(tmp_val) + h
  10.         fxh1 = f(x) # f(x+h)
  11.         
  12.         x[idx] = tmp_val - h
  13.         fxh2 = f(x) # f(x-h)
  14.         grad[idx] = (fxh1 - fxh2) / (2*h)
  15.         
  16.         x[idx] = tmp_val # 还原值
  17.         
  18.     return grad


  19. def numerical_gradient_2d(f, X):
  20.     if X.ndim == 1:
  21.         return _numerical_gradient_1d(f, X)
  22.     else:
  23.         grad = np.zeros_like(X)
  24.         
  25.         for idx, x in enumerate(X):
  26.             grad[idx] = _numerical_gradient_1d(f, x)
  27.         
  28.         return grad


  29. def numerical_gradient(f, x):
  30.     h = 1e-4 # 0.0001
  31.     grad = np.zeros_like(x)
  32.    
  33.     it = np.nditer(x, flags=['multi_index'], op_flags=['readwrite'])
  34.     while not it.finished:
  35.         idx = it.multi_index
  36.         tmp_val = x[idx]
  37.         x[idx] = float(tmp_val) + h
  38.         fxh1 = f(x) # f(x+h)
  39.         
  40.         x[idx] = tmp_val - h
  41.         fxh2 = f(x) # f(x-h)
  42.         grad[idx] = (fxh1 - fxh2) / (2*h)
  43.         
  44.         x[idx] = tmp_val # 还原值
  45.         it.iternext()   
  46.         
  47.     return grad
复制代码
想知道小甲鱼最近在做啥?请访问 -> ilovefishc.com
回复 支持 反对

使用道具 举报

 楼主| 发表于 2019-9-6 17:04:07 | 显示全部楼层
1.读完
R. Wen, J. Tang, T. Q. S. Quek, G. Feng, G. Wang, and W. Tan, “Robust Network Slicing in Software-Defined 5G Networks,” 2017 IEEE Glob. Commun. Conf. GLOBECOM 2017 - Proc., vol. 2018-January, pp. 1–6, 2018.
并总结;
2.重构网络SDN架构实现前两章;
3.行书入门一页
想知道小甲鱼最近在做啥?请访问 -> ilovefishc.com
回复 支持 反对

使用道具 举报

 楼主| 发表于 2019-9-8 14:55:45 | 显示全部楼层
1.
重构网络SDN架构实现
读完第二章;
2.读完
R. Wen et al., “On Robustness of Network Slicing for Next-Generation Mobile Networks,” IEEE Trans. Commun., vol. 67, no. 1, pp. 430–444, 2019.

3.行书入门一页;
4.组会PPT提纲。
想知道小甲鱼最近在做啥?请访问 -> ilovefishc.com
回复 支持 反对

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

小黑屋|手机版|Archiver|鱼C工作室 ( 粤ICP备18085999号-1 | 粤公网安备 44051102000585号)

GMT+8, 2024-5-15 01:54

Powered by Discuz! X3.4

© 2001-2023 Discuz! Team.

快速回复 返回顶部 返回列表