|
马上注册,结交更多好友,享用更多功能^_^
您需要 登录 才可以下载或查看,没有账号?立即注册
x
本帖最后由 峥途 于 2024-11-6 22:36 编辑
终于——敲完了
这次分享的是开源代码,MATLAB软件神经网络在图像识别中的初级应用——手写数字识别:即给了一组手写数字的数据表(可以看一下下方的图例),然后用这个5000个样本的数据库训练这个算法,让它最后判断这个图片是什么数字,属于图像识别里的最简单、经典的了。
最后他的准确率高达95.04%,也算是挺不错的了,大概再多提供一些样本它可以更准确?
然后和昨天的灰度图属于一个数据库,超过2M,无法提交……,所以数据库不能上传嘞……附件里的是“fmincg.m”自带函数。
代码如下:
主函数部分:
- %% 嘿嘿嘿,我胡汉三又回来啦!!!听完音乐会,san值100%~
- %% 清除关闭
- clear;close all;clc;
- %% 读取可视化
- load('ex3data1.mat');
- randIndex = randperm(size(X,1));% 打乱X索引值,随机排序
- seldata = X(randIndex(1:100),:);% 随机取打乱后的前100个样本
- % 可视化
- displayData(seldata);
- %% 函数主体核心oneVsAll
- lamda = 0.05;
- num_labels = 10;
- all_theta = oneVsAll(X,y,num_labels,lamda);
- %% 预测
- P = predict(X,all_theta);
- fprintf('The accuracy is %.2f%%\n',mean(double(y==P))*100);
复制代码
其他部分
- %% 可视化函数
- function [figurePane,display_array] = displayData(X,image_width)
- %% 初始化各项参数
- [m,n] = size(X);
- %% 设置image信息
- if ~exist('image_width','var')||empty(image_width)
- image_width = round(sqrt(n));
- end
- image_height = round(n/image_width);
- %% 设置figure信息
- colormap(gray);
- figure_rows = floor(sqrt(m));
- figure_cols = ceil(sqrt(m));
- %% 初始化display_array
- pad = 1;
- display_array = -ones(pad+(image_height+pad)*figure_rows, ...
- pad+(image_width+pad)*figure_cols);
- %% 将X对应赋值给display_array
- current_image = 1;
- for row = 1:figure_rows
- for col = 1:figure_cols
- max_var = max(X(current_image,:));% 求当前图像的一行中最大值
- display_array(pad+(row-1)*(image_height+pad)+(1:image_height), ...
- pad+(col-1)*(image_width+pad)+(1:image_width))=...
- reshape(X(current_image,:),image_height,image_width)/max_var;
- current_image = current_image+1;
- if current_image > m
- break;
- end
- end
- if current_image >m
- break;
- end
- end
- %% 画图
- figurePane = imagesc(display_array,[-1,1]);
- title('Random handwritten digits');
- axis image off;
- drawnow;
复制代码
- function [J,grad] = lrCostFunction(X,y,theta,lamda)
- [m,~] = size(X);
- h = sigmoid(X*theta);
- J = -(y'*log(h)+(1-y')*log(1-h))/m+lamda*(theta'*theta)/(2*m);
- grad = X'*(h-y)/m+lamda*[0;theta(2:end)]/m;
- end
复制代码
- %% 逻辑回归one-Vs-All
- function all_theta = oneVsAll(X,y,num_labels,lamda)
- %% 初始设置
- [m,n] = size(X);
- X = [ones(m,1),X];
- all_theta = zeros(num_labels,n+1);
- options = optimset('GradObj','On','MaxIter',50);
- for K = 1:num_labels
- init_theta = zeros(n+1,1);
- costFun = @(the)lrCostFunction(X,(y==K),the,lamda);
- theta = fmincg(costFun,init_theta,options);
- all_theta(K,:) = theta';
- end
复制代码
- function P = predict(X,theta)
- [m,~] = size(X);
- X = [ones(m,1),X];
- h = sigmoid(X*theta');
- [~,P] = max(h,[],2);
- end
复制代码
- function g = sigmoid(x)
- g = 1./(1+exp(-x));
- end
复制代码
Ps:(1)就是附件里这个函数MATLAB里面不是自带的,但是一般都会提供给大家使用,这个保存一下就OK了。
(2)这个老登的数据库把所有y的值是0的自动定义成了10,不然他的函数预测部分会出错,我还自己想了半天[0,9]→[1,10]是怎么做的,但事实上,数据库主任提前把这个bug修复了。。
(3)大家感兴趣的可以随便找其他图片进行训练,只不过那个oneVsAll函数部分需要添加一个category,比手写数字这种复杂一丢丢,但是原理一样的啦~
小声逼逼:本来是想一小时独立敲完的,但是虽然差不多一个小时,可是one-Vs-All部分有些函数的输入参量啥的记不大清了,所以最后还是看了大佬给的开源文件,嘤嘤嘤 |
|