Machine Learning: Linear Regression
Today I reviewed the linear regression algorithm and implemented the experiment on my own. I ran into several bugs along the way, and with some help from an AI assistant I finally got the code working.
Below are my linear regression code and the data set it uses.
PS:
(1) While writing this I explored the "normalize" function, which turns out to be convenient and quick. (A small sketch of what it computes follows this list.)
(2) The cost function does not converge to 0, not even after normalization. After looking it up, I learned that it is normal for J not to reach 0 when the data are noisy; in my run it converges to about 4.7. (See the short note after this list for why.)
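For note (1): a minimal sketch of what normalize(x,'zscore') computes, assuming the default sample standard deviation; x_manual is just an illustrative name I made up for the comparison.
% z-score normalization by hand: subtract the mean, divide by the standard deviation
x_manual = (x - mean(x)) ./ std(x);
% should agree with the built-in up to floating-point error
max(abs(x_manual - normalize(x,'zscore')))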
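For note (2), the reason is visible in the cost function itself: its minimum over theta is half the mean squared residual of the best possible line, so it can only be 0 when every point lies exactly on that line:
$$
\min_{\theta} J(\theta) = \min_{\theta}\,\frac{1}{2m}\sum_{i=1}^{m}\bigl(\theta_0 + \theta_1 x^{(i)} - y^{(i)}\bigr)^2 > 0
$$
for any data set that is not exactly collinear. Since this data is clearly scattered around the trend, a positive limit like the roughly 4.7 I observed is expected.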
%% Linear regression
%% Clean up the workspace
close all; clear; clc;
%% Load the data (and plot it)
data = load('ex1data1.txt');
x = data(:,1);
y = data(:,2);
m = size(x,1);
% Scatter plot of the data
figure(1);
plot(x,y,'rx','MarkerSize',10);
xlabel ('x');
ylabel ('y');
hold on;
%% Solve for the final parameters (print them & plot the fitted line; uses the cost function)
% Gradient descent
% Initial setup
x_norm = normalize(x,'zscore');      % z-score: subtract the mean, divide by the std
x_norm = [ones(m,1),x_norm];         % prepend a column of ones for the intercept term
fprintf('x_norm is: %f\n',x_norm(:,2));
theta = zeros(size(x_norm,2),1);
iterations = 3000;
alpha = 0.01;
% Run gradient descent
[J,theta] = gradientDescent(x_norm,y,theta,alpha,iterations);
fprintf('The final theta is: %f\n',theta);
fprintf('The final J is: %f\n',J(iterations));
% Convergence curve of the cost J over the iterations
figure(2);
plot(1:iterations,J,'b-','LineWidth',4);
% Fitted regression line
figure(1);
hold on;
plot(x,x_norm*theta,'b-');
title('Linear Regression');
legend('training data','linear fit');
hold off;
%% Surface and contour plots of the cost function
theta1_vals = linspace(1,10,100);
theta2_vals = linspace(-1,9,100);
J_vals = zeros(length(theta1_vals),length(theta2_vals));
for i = 1:length(theta1_vals)
    for j = 1:length(theta2_vals)
        t = [theta1_vals(i);theta2_vals(j)];
        J_vals(i,j) = computeCost(x_norm,y,t);
    end
end
% 3-D surface plot of J
J_vals = J_vals';   % transpose so rows run along theta_1 (the y-axis) for surf
figure(3);
surf(theta1_vals,theta2_vals,J_vals);
xlabel('\theta_0');ylabel('\theta_1');zlabel('J');
title('Cost function J');
% Contour plot of J with the gradient-descent solution marked
figure(4);
contour(theta1_vals,theta2_vals,J_vals);
xlabel('\theta_0'); ylabel('\theta_1');
hold on;
plot(theta(1), theta(2), 'rx', 'MarkerSize', 10, 'LineWidth', 2);
function [J,theta] = gradientDescent(x,y,theta,alpha,iterations)
% Batch gradient descent; returns the cost history J and the final theta.
J = zeros(iterations,1);
m = size(x,1);
% (normalization is already done in the main script)
for i = 1:iterations
    % vectorized update: theta <- theta - (alpha/m) * X' * (X*theta - y)
    theta = theta - alpha/m*x'*(x*theta-y);
    J(i) = computeCost(x,y,theta);
end
end
function J = computeCost(x,y,theta)
% Mean squared error cost: J = sum((X*theta - y).^2) / (2m)
m = length(y);
J = sum((x*theta-y).^2)/(2*m);
end
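As a quick sanity check (my own addition, not part of the original exercise): the closed-form least-squares solution should agree with the theta found by gradient descent, and its cost is the floor that J converges toward. Note that in an actual MATLAB script this snippet has to go above the local function definitions; it is placed here only for readability.
% closed-form least-squares fit via the backslash operator (normal equations)
theta_ls = x_norm \ y;
fprintf('least-squares theta: %f\n',theta_ls);
fprintf('cost at the least-squares optimum: %f\n',computeCost(x_norm,y,theta_ls));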
Data (ex1data1.txt):
6.1101,17.592
5.5277,9.1302
8.5186,13.662
7.0032,11.854
5.8598,6.8233
8.3829,11.886
7.4764,4.3483
8.5781,12
6.4862,6.5987
5.0546,3.8166
5.7107,3.2522
14.164,15.505
5.734,3.1551
8.4084,7.2258
5.6407,0.71618
5.3794,3.5129
6.3654,5.3048
5.1301,0.56077
6.4296,3.6518
7.0708,5.3893
6.1891,3.1386
20.27,21.767
5.4901,4.263
6.3261,5.1875
5.5649,3.0825
18.945,22.638
12.828,13.501
10.957,7.0467
13.176,14.692
22.203,24.147
5.2524,-1.22
6.5894,5.9966
9.2482,12.134
5.8918,1.8495
8.2111,6.5426
7.9334,4.5623
8.0959,4.1164
5.6063,3.3928
12.836,10.117
6.3534,5.4974
5.4069,0.55657
6.8825,3.9115
11.708,5.3854
5.7737,2.4406
7.8247,6.7318
7.0931,1.0463
5.0702,5.1337
5.8014,1.844
11.7,8.0043
5.5416,1.0179
7.5402,6.7504
5.3077,1.8396
7.4239,4.2885
7.6031,4.9981
6.3328,1.4233
6.3589,-1.4211
6.2742,2.4756
5.6397,4.6042
9.3102,3.9624
9.4536,5.4141
8.8254,5.1694
5.1793,-0.74279
21.279,17.929
14.908,12.054
18.959,17.054
7.2182,4.8852
8.2951,5.7442
10.236,7.7754
5.4994,1.0173
20.341,20.992
10.136,6.6799
7.3345,4.0259
6.0062,1.2784
7.2259,3.3411
5.0269,-2.6807
6.5479,0.29678
7.5386,3.8845
5.0365,5.7014
10.274,6.7526
5.1077,2.0576
5.7292,0.47953
5.1884,0.20421
6.3557,0.67861
9.7687,7.5435
6.5159,5.3436
8.5172,4.2415
9.1802,6.7981
6.002,0.92695
5.5204,0.152
5.0594,2.8214
5.7077,1.8451
7.6366,4.2959
5.8707,7.2029
5.3054,1.9869
8.2934,0.14454
13.394,9.0551
5.4369,0.61705