|
马上注册,结交更多好友,享用更多功能^_^
您需要 登录 才可以下载或查看,没有账号?立即注册
x
import numpy as np
import random
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
# Sample-count parameter used when generating the synthetic data set below.
m = 10 ** 3
def f(x, noise=1000.0):
    """Return a noisy sample of the ground-truth line y = 8*x + 9.

    Uniform noise drawn with ``random.uniform(-noise, noise)`` is added to the
    exact line value, so the generated data has a known slope (8) and
    intercept (9) for the gradient-descent fit to recover.

    Args:
        x: Input value.
        noise: Half-width of the uniform noise band. The default of 1000 is
            the amplitude previously hard-coded; pass 0 for noise-free samples.

    Returns:
        8*x + 9 plus the sampled noise, as a float.
    """
    return 8.0 * x + 9.0 + random.uniform(-noise, noise)
def E(a, b, P):
    """Return the mean squared error of the line y = a*x + b over points P.

    Args:
        a: Slope of the candidate line.
        b: Intercept of the candidate line.
        P: Sized iterable of points; for each point ``p``, ``p[0]`` is x and
            ``p[1]`` is y.

    Returns:
        mean((a*x + b - y) ** 2) over all points.

    Raises:
        ZeroDivisionError: if P is empty.
    """
    # Builtin ``sum`` instead of a local accumulator that shadowed it.
    return sum((a * p[0] + b - p[1]) ** 2 for p in P) / len(P)
def d_f_A(a, b, P):
    """Return the partial derivative of the mean squared error w.r.t. the slope a.

    d/da mean((a*x + b - y) ** 2) = mean(2 * x * (a*x + b - y)).

    Args:
        a: Slope of the candidate line.
        b: Intercept of the candidate line.
        P: Sized iterable of points; ``p[0]`` is x and ``p[1]`` is y.

    Returns:
        The averaged gradient component for a.

    Raises:
        ZeroDivisionError: if P is empty.
    """
    # Builtin ``sum`` instead of a local accumulator that shadowed it.
    return sum(2 * p[0] * (a * p[0] + b - p[1]) for p in P) / len(P)
def d_f_B(a, b, P):
    """Return the partial derivative of the mean squared error w.r.t. the intercept b.

    d/db mean((a*x + b - y) ** 2) = mean(2 * (a*x + b - y)).

    Args:
        a: Slope of the candidate line.
        b: Intercept of the candidate line.
        P: Sized iterable of points; ``p[0]`` is x and ``p[1]`` is y.

    Returns:
        The averaged gradient component for b.

    Raises:
        ZeroDivisionError: if P is empty.
    """
    # Residual written as a*x + b - y to match the mean-squared-error and
    # slope-gradient formulas; builtin ``sum`` replaces the shadowing accumulator.
    return sum(2 * (a * p[0] + b - p[1]) for p in P) / len(P)
# Generate m noisy samples: x uniformly in [-1000, 1000], y = f(x).
# P holds [x, y] pairs; X and Y hold the same coordinates separately for plotting.
P = []
X = []
Y = []
# range(m), not range(1, m): the latter silently produced only m - 1 samples.
for _ in range(m):
    x = random.uniform(-1000, 1000)
    X.append(x)
    y = f(x)
    Y.append(y)
    P.append([x, y])
# Gradient-descent hyperparameters.
# Step size is tiny, presumably because the slope gradient scales with x**2 and
# x is drawn from [-1000, 1000].
# NOTE(review): the intercept gradient has no x factor, so with this same step
# size b moves very little from b_init; consider feature scaling.
learning_rate = 0.0000001
max_loop = 1000    # hard cap on the number of iterations
tolerance = 0.01   # stop once the loss changes by less than this per iteration

# Random starting point: a in [0, 2), b in [8, 10).
a_init = random.uniform(0, 2)
a = a_init
b_init = random.uniform(8, 10)
b = b_init

# Trajectory of (a, b, loss) visited by gradient descent, starting at the guess.
GDX = [a]
GDY = [b]
GDZ = [E(a, b, P)]
# Seed the convergence check with the starting loss. Starting from 0 made the
# first check compare the new loss against zero, which is meaningless.
E_pre = GDZ[0]
# Batch gradient descent on the mean squared error. Each step records the new
# (a, b, loss) in GDX/GDY/GDZ. The loop stops early once the loss changes by
# less than `tolerance` between consecutive iterations, otherwise after
# `max_loop` iterations.
for i in range(max_loop):
    # Evaluate both partial derivatives at the current point before moving,
    # so a and b are updated simultaneously.
    grad_a, grad_b = d_f_A(a, b, P), d_f_B(a, b, P)
    a, b = a - learning_rate * grad_a, b - learning_rate * grad_b
    E_cur = E(a, b, P)
    GDX.append(a)
    GDY.append(b)
    GDZ.append(E_cur)
    if abs(E_cur - E_pre) < tolerance:
        break
    E_pre = E_cur
# Report the initial guess, the fitted parameters and the final loss.
for label, value in (('a的初值为 =', a_init), ('b的初值为 =', b_init)):
    print(label, value)
print('拟合后 a=', a, 'b=', b)
print('E(a,b) =', E(a, b, P))

# Draw the noisy samples, then overlay the fitted line y = a*x + b evaluated on
# the integer grid [-1000, 1000).
plt.scatter(X, Y)
grid_x = np.arange(-1000, 1000)
fitted_y = a * grid_x + b
plt.plot(grid_x, fitted_y)
plt.show()
|
|