鱼C论坛

 找回密码
 立即注册
查看: 718|回复: 1

[学习笔记] MATLAB神经网络手写数字识别

[复制链接]
发表于 2024-11-6 22:34:47 | 显示全部楼层 |阅读模式

马上注册,结交更多好友,享用更多功能^_^

您需要 登录 才可以下载或查看,没有账号?立即注册

x
本帖最后由 峥途 于 2024-11-6 22:36 编辑

终于——敲完了
这次分享的是开源代码,MATLAB软件神经网络在图像识别中的初级应用——手写数字识别:即给了一组手写数字的数据表(可以看一下下方的图例),然后用这个5000个样本的数据库训练这个算法,让它最后判断这个图片是什么数字,属于图像识别里的最简单、经典的了。
最后他的准确率高达95.04%,也算是挺不错的了,大概再多提供一些样本它可以更准确?
然后和昨天的灰度图属于一个数据库,超过2M,无法提交……,所以数据库不能上传嘞……附件里的是“fmincg.m”自带函数。
代码如下:
主函数部分:

  1. %% 嘿嘿嘿,我胡汉三又回来啦!!!听完音乐会,san值100%~
  2. %% 清除关闭
  3. clear;close all;clc;
  4. %% 读取可视化
  5. load('ex3data1.mat');
  6. randIndex = randperm(size(X,1));% 打乱X索引值,随机排序
  7. seldata = X(randIndex(1:100),:);% 随机取打乱后的前100个样本

  8. % 可视化
  9. displayData(seldata);
  10. %% 函数主体核心oneVsAll
  11. lamda = 0.05;
  12. num_labels = 10;
  13. all_theta = oneVsAll(X,y,num_labels,lamda);
  14. %% 预测
  15. P = predict(X,all_theta);
  16. fprintf('The accuracy is %.2f%%\n',mean(double(y==P))*100);
复制代码



其他部分
  1. %% 可视化函数
  2. function [figurePane,display_array] = displayData(X,image_width)
  3. %% 初始化各项参数
  4. [m,n] = size(X);

  5. %% 设置image信息
  6. if ~exist('image_width','var')||empty(image_width)
  7.     image_width = round(sqrt(n));
  8. end
  9. image_height = round(n/image_width);

  10. %% 设置figure信息
  11. colormap(gray);
  12. figure_rows = floor(sqrt(m));
  13. figure_cols = ceil(sqrt(m));
  14. %% 初始化display_array
  15. pad = 1;
  16. display_array = -ones(pad+(image_height+pad)*figure_rows, ...
  17.     pad+(image_width+pad)*figure_cols);
  18. %% 将X对应赋值给display_array
  19. current_image = 1;

  20. for row = 1:figure_rows
  21.     for col = 1:figure_cols
  22.         max_var = max(X(current_image,:));% 求当前图像的一行中最大值
  23.         display_array(pad+(row-1)*(image_height+pad)+(1:image_height), ...
  24.             pad+(col-1)*(image_width+pad)+(1:image_width))=...
  25.             reshape(X(current_image,:),image_height,image_width)/max_var;
  26.         current_image = current_image+1;
  27.         if current_image > m
  28.             break;
  29.         end
  30.     end
  31.     if current_image >m
  32.         break;
  33.     end
  34. end
  35. %% 画图
  36. figurePane = imagesc(display_array,[-1,1]);
  37. title('Random handwritten digits');
  38. axis image off;
  39. drawnow;
复制代码

  1. function [J,grad] = lrCostFunction(X,y,theta,lamda)
  2. [m,~] = size(X);
  3. h = sigmoid(X*theta);
  4. J = -(y'*log(h)+(1-y')*log(1-h))/m+lamda*(theta'*theta)/(2*m);
  5. grad = X'*(h-y)/m+lamda*[0;theta(2:end)]/m;
  6. end
复制代码


  1. %% 逻辑回归one-Vs-All
  2. function all_theta = oneVsAll(X,y,num_labels,lamda)
  3. %% 初始设置
  4. [m,n] = size(X);
  5. X = [ones(m,1),X];
  6. all_theta = zeros(num_labels,n+1);
  7. options = optimset('GradObj','On','MaxIter',50);

  8. for K = 1:num_labels
  9.     init_theta = zeros(n+1,1);
  10.     costFun = @(the)lrCostFunction(X,(y==K),the,lamda);
  11.     theta = fmincg(costFun,init_theta,options);
  12.     all_theta(K,:) = theta';
  13. end

复制代码

  1. function P = predict(X,theta)
  2. [m,~] = size(X);
  3. X = [ones(m,1),X];
  4. h = sigmoid(X*theta');
  5. [~,P] = max(h,[],2);
  6. end
复制代码

  1. function g = sigmoid(x)
  2. g = 1./(1+exp(-x));
  3. end
复制代码


Ps:(1)就是附件里这个函数MATLAB里面不是自带的,但是一般都会提供给大家使用,这个保存一下就OK了。
(2)这个老登的数据库把所有y的值是0的自动定义成了10,不然他的函数预测部分会出错,我还自己想了半天[0,9]→[1,10]是怎么做的,但事实上,数据库主任提前把这个bug修复了。。
(3)大家感兴趣的可以随便找其他图片进行训练,只不过那个oneVsAll函数部分需要添加一个category,比手写数字这种复杂一丢丢,但是原理一样的啦~

小声逼逼:本来是想一小时独立敲完的,但是虽然差不多一个小时,可是one-Vs-All部分有些函数的输入参量啥的记不大清了,所以最后还是看了大佬给的开源文件,嘤嘤嘤

随机一百个数字(真正随机数)

随机一百个数字(真正随机数)

fmincg.zip

3.15 KB, 下载次数: 0

售价: 1 鱼币  [记录]  [购买]

小甲鱼最新课程 -> https://ilovefishc.com
回复

使用道具 举报

 楼主| 发表于 2024-11-6 22:37:56 | 显示全部楼层
收工~回家!
小甲鱼最新课程 -> https://ilovefishc.com
回复 支持 反对

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

小黑屋|手机版|Archiver|鱼C工作室 ( 粤ICP备18085999号-1 | 粤公网安备 44051102000585号)

GMT+8, 2025-7-13 20:24

Powered by Discuz! X3.4

© 2001-2023 Discuz! Team.

快速回复 返回顶部 返回列表