鱼C论坛

 找回密码
 立即注册
查看: 75|回复: 0

[经验总结] 机器学习——线性回归算法

[复制链接]
发表于 3 天前 | 显示全部楼层 |阅读模式

马上注册,结交更多好友,享用更多功能^_^

您需要 登录 才可以下载或查看,没有账号?立即注册

x
本帖最后由 峥途 于 2024-10-20 16:34 编辑

机器学习——线性回归算法

今天复习了线性回归算法,自己独立完成了线性回归的实验,但是出了好几次bug,在AI的辅助下才最终成功修改代码为。
如下是我的线性回归代码和所用数据库。
PS:
(1)在这个里面,新研究了函数“normalize”,感觉方便又快捷。
(2)这个代价函数最终并不是收敛到0的,哪怕我归一化后它也不是收敛到0的,经过查找资料,发现J在数据不好的时候收敛不到0是正常的。最后收敛到4.7左右。
%% 线性回归
%% 清理数据
close;clear all;clc;
%% 录入数据(画图)
data = load('ex1data1.txt');
x = data(:,1);
y = data(:,2);
m = size(x,1);

% 绘制数据散点图
figure(1);
plot(x,y,'rx','MarkerSize',10);
xlabel ('x');
ylabel ('y');
hold on;

%% 求解最终梯度(并输出&绘制拟合图像) (包含代价函数)
% 梯度下降
% 初始参数
x_norm = normalize(x,'zscore');
x_norm = [ones(m,1),x_norm];
fprintf('x_norm is :%f\n',x_norm(:,2));
theta = zeros(size(x_norm,2),1);
iterations = 3000;
alpha = 0.01;

% 梯度下降
[J,theta] = gradientDescent(x_norm,y,theta,alpha,iterations);
fprintf('The final theta is: %f\n',theta);
fprintf('The final J is: %f\n',J(iterations));

% 梯度变化图
figure(2);
plot(1:iterations,J,'b-','LineWidth',4);

% 线性回归直线
figure(1);
hold on;
plot(x,x_norm*theta,'b-');
title('Linear Regression');
legend('real','predict');
hold off;

%% 绘制代价函数的三维图像以及梯度图
theta1_vals = linspace(1,10,100);
theta2_vals = linspace(-1,9,100);

J_vals = zeros(length(theta1_vals),length(theta2_vals));

for i = 1:length(theta2_vals)
    for j = 1:length(theta2_vals)
        t = [theta1_vals(i);theta2_vals(j)];
        J_vals(j,i) = computeCost(x_norm,y,t);
    end
end

% 绘制三维J图像
J_vals = J_vals';
figure(3);
surf(theta1_vals,theta2_vals,J_vals);
xlabel('x');ylabel('y');zlabel('z');
title('J');

% 绘制梯形图
figure(4);
contour(theta1_vals,theta2_vals,J_vals);
xlabel('\theta_0'); ylabel('\theta_1');
hold on;
plot(theta(1), theta(2), 'rx', 'MarkerSize', 10, 'LineWidth', 2);

function [J,theta] = gradientDescent(x,y,theta,alpha,iterations)
J = zeros(iterations,1);
m = size(x,1);
% 归一化
% x_norm = normalize(x,'zscore');
for i = 1:iterations
    theta = theta - alpha/m*x'*(x*theta-y);
    J(i) = computeCost(x,y,theta);
end

end


function J = computeCost(x,y,theta)
m = length(y);
J = sum((x*theta-y).^2)/(2*m);
end

数据:
6.1101,17.592
5.5277,9.1302
8.5186,13.662
7.0032,11.854
5.8598,6.8233
8.3829,11.886
7.4764,4.3483
8.5781,12
6.4862,6.5987
5.0546,3.8166
5.7107,3.2522
14.164,15.505
5.734,3.1551
8.4084,7.2258
5.6407,0.71618
5.3794,3.5129
6.3654,5.3048
5.1301,0.56077
6.4296,3.6518
7.0708,5.3893
6.1891,3.1386
20.27,21.767
5.4901,4.263
6.3261,5.1875
5.5649,3.0825
18.945,22.638
12.828,13.501
10.957,7.0467
13.176,14.692
22.203,24.147
5.2524,-1.22
6.5894,5.9966
9.2482,12.134
5.8918,1.8495
8.2111,6.5426
7.9334,4.5623
8.0959,4.1164
5.6063,3.3928
12.836,10.117
6.3534,5.4974
5.4069,0.55657
6.8825,3.9115
11.708,5.3854
5.7737,2.4406
7.8247,6.7318
7.0931,1.0463
5.0702,5.1337
5.8014,1.844
11.7,8.0043
5.5416,1.0179
7.5402,6.7504
5.3077,1.8396
7.4239,4.2885
7.6031,4.9981
6.3328,1.4233
6.3589,-1.4211
6.2742,2.4756
5.6397,4.6042
9.3102,3.9624
9.4536,5.4141
8.8254,5.1694
5.1793,-0.74279
21.279,17.929
14.908,12.054
18.959,17.054
7.2182,4.8852
8.2951,5.7442
10.236,7.7754
5.4994,1.0173
20.341,20.992
10.136,6.6799
7.3345,4.0259
6.0062,1.2784
7.2259,3.3411
5.0269,-2.6807
6.5479,0.29678
7.5386,3.8845
5.0365,5.7014
10.274,6.7526
5.1077,2.0576
5.7292,0.47953
5.1884,0.20421
6.3557,0.67861
9.7687,7.5435
6.5159,5.3436
8.5172,4.2415
9.1802,6.7981
6.002,0.92695
5.5204,0.152
5.0594,2.8214
5.7077,1.8451
7.6366,4.2959
5.8707,7.2029
5.3054,1.9869
8.2934,0.14454
13.394,9.0551
5.4369,0.61705

评分

参与人数 1鱼币 +5 收起 理由
某一个“天” + 5 鱼C有你更精彩^_^

查看全部评分

本帖被以下淘专辑推荐:

想知道小甲鱼最近在做啥?请访问 -> ilovefishc.com
回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

小黑屋|手机版|Archiver|鱼C工作室 ( 粤ICP备18085999号-1 | 粤公网安备 44051102000585号)

GMT+8, 2024-10-23 09:33

Powered by Discuz! X3.4

© 2001-2023 Discuz! Team.

快速回复 返回顶部 返回列表