MATLAB神经网络手写数字识别
本帖最后由 峥途 于 2024-11-6 22:36 编辑{:10_247:} 终于——敲完了{:10_266:}
这次分享的是开源代码,MATLAB软件神经网络在图像识别中的初级应用——手写数字识别:即给了一组手写数字的数据表(可以看一下下方的图例),然后用这个5000个样本的数据库训练这个算法,让它最后判断这个图片是什么数字,属于图像识别里的最简单、经典的了。
最后他的准确率高达95.04%,也算是挺不错的了,大概再多提供一些样本它可以更准确?
然后和昨天的灰度图属于一个数据库,超过2M,无法提交……,所以数据库不能上传嘞……附件里的是“fmincg.m”自带函数。
代码如下:
主函数部分:
%% 嘿嘿嘿,我胡汉三又回来啦!!!听完音乐会,san值100%~
%% 清除关闭
clear;close all;clc;
%% 读取可视化
load('ex3data1.mat');
randIndex = randperm(size(X,1));% 打乱X索引值,随机排序
seldata = X(randIndex(1:100),:);% 随机取打乱后的前100个样本
% 可视化
displayData(seldata);
%% 函数主体核心oneVsAll
lamda = 0.05;
num_labels = 10;
all_theta = oneVsAll(X,y,num_labels,lamda);
%% 预测
P = predict(X,all_theta);
fprintf('The accuracy is %.2f%%\n',mean(double(y==P))*100);
其他部分
%% 可视化函数
function = displayData(X,image_width)
%% 初始化各项参数
= size(X);
%% 设置image信息
if ~exist('image_width','var')||empty(image_width)
image_width = round(sqrt(n));
end
image_height = round(n/image_width);
%% 设置figure信息
colormap(gray);
figure_rows = floor(sqrt(m));
figure_cols = ceil(sqrt(m));
%% 初始化display_array
pad = 1;
display_array = -ones(pad+(image_height+pad)*figure_rows, ...
pad+(image_width+pad)*figure_cols);
%% 将X对应赋值给display_array
current_image = 1;
for row = 1:figure_rows
for col = 1:figure_cols
max_var = max(X(current_image,:));% 求当前图像的一行中最大值
display_array(pad+(row-1)*(image_height+pad)+(1:image_height), ...
pad+(col-1)*(image_width+pad)+(1:image_width))=...
reshape(X(current_image,:),image_height,image_width)/max_var;
current_image = current_image+1;
if current_image > m
break;
end
end
if current_image >m
break;
end
end
%% 画图
figurePane = imagesc(display_array,[-1,1]);
title('Random handwritten digits');
axis image off;
drawnow;
function = lrCostFunction(X,y,theta,lamda)
= size(X);
h = sigmoid(X*theta);
J = -(y'*log(h)+(1-y')*log(1-h))/m+lamda*(theta'*theta)/(2*m);
grad = X'*(h-y)/m+lamda*/m;
end
%% 逻辑回归one-Vs-All
function all_theta = oneVsAll(X,y,num_labels,lamda)
%% 初始设置
= size(X);
X = ;
all_theta = zeros(num_labels,n+1);
options = optimset('GradObj','On','MaxIter',50);
for K = 1:num_labels
init_theta = zeros(n+1,1);
costFun = @(the)lrCostFunction(X,(y==K),the,lamda);
theta = fmincg(costFun,init_theta,options);
all_theta(K,:) = theta';
end
function P = predict(X,theta)
= size(X);
X = ;
h = sigmoid(X*theta');
[~,P] = max(h,[],2);
end
function g = sigmoid(x)
g = 1./(1+exp(-x));
end
Ps:(1)就是附件里这个函数MATLAB里面不是自带的,但是一般都会提供给大家使用,这个保存一下就OK了。
(2)这个老登的数据库把所有y的值是0的自动定义成了10,不然他的函数预测部分会出错,我还自己想了半天→是怎么做的,但事实上,数据库主任提前把这个bug修复了。。
(3)大家感兴趣的可以随便找其他图片进行训练,只不过那个oneVsAll函数部分需要添加一个category,比手写数字这种复杂一丢丢,但是原理一样的啦~
小声逼逼:本来是想一小时独立敲完的,但是虽然差不多一个小时,可是one-Vs-All部分有些函数的输入参量啥的记不大清了,所以最后还是看了大佬给的开源文件,嘤嘤嘤{:10_296:} 收工~回家!
页:
[1]