鱼C论坛

 找回密码
 立即注册
查看: 388|回复: 1

下面感知机代码哪里有错?

[复制链接]
发表于 2018-7-31 16:57:15 | 显示全部楼层 |阅读模式

马上注册,结交更多好友,享用更多功能^_^

您需要 登录 才可以下载或查看,没有账号?立即注册

x
# -*- coding: utf-8 -*-

import random
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

sns.set(style="white",color_codes=True)
perc_data=pd.read_csv("perceptron.csv")
print(perc_data)

def sign(v):
    if v > 0:
        return 1
    else:
        return -1
def training():
    weight=[0,0]
    bias=0
    learning_rate=0.5
    train_num=5000

    for i in range(train_num):
        x1=random.choice(perc_data.x_1)
        x2=random.choice(perc_data.x_2)
        y=random.choic(perc_data.label)
        predict=sign(weight[0]*x1+weight[1]*x2+bias)
        if y*predict<=0:
            weight[0]=weight[0]+learning_rate*y*x1
            weight[1]=weight[1]=learning_rate*y*x2
            bias=bias+learning_rate*y
            print(weight[0],weight[1],bias)
    print("stop training: "),
    print(weight[0], weight[1], bias)
            
    return weight,bias

def main():
   
    weight,bias=training()
    g_data=perc_data[perc_data['label']>0]
    r_data=perc_data[perc_data['label']<0]
    plt.plot(g_data['x_1'],g_data['x_2'],'go')
    plt.plot(r_data['x_1'],r_data['x_2'],'ro')
    plt.axis([0,5,-4,15])
    x=np.linspace(0.0,8,100)
    y=-weight[0]/weight[1]*x-bias/weight[1]
    plt.plot(x,y)
    plt.show()
   
if __name__=="__main__":
    main()
  
想知道小甲鱼最近在做啥?请访问 -> ilovefishc.com
回复

使用道具 举报

发表于 2018-7-31 20:13:58 | 显示全部楼层
weight[0]=weight[0]+learning_rate*y*x1
weight[1]=weight[1]=learning_rate*y*x2
这里有问题吧
想知道小甲鱼最近在做啥?请访问 -> ilovefishc.com
回复 支持 反对

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

小黑屋|手机版|Archiver|鱼C工作室 ( 粤ICP备18085999号-1 | 粤公网安备 44051102000585号)

GMT+8, 2024-10-6 00:31

Powered by Discuz! X3.4

© 2001-2023 Discuz! Team.

快速回复 返回顶部 返回列表