鱼C论坛

 找回密码
 立即注册
查看: 2245|回复: 0

[技术交流] 机器学习系列------Swiss Roll数据库

[复制链接]
发表于 2018-6-24 09:43:41 | 显示全部楼层 |阅读模式

马上注册,结交更多好友,享用更多功能^_^

您需要 登录 才可以下载或查看,没有账号?立即注册

x
        Swiss Roll数据库形状是一个三维的蛋糕卷,在这里我们学习降维,给降到两维,首先读取swiss roll数据,然后用kernelPCA算法降维:
  1. from sklearn.decomposition import KernelPCA
  2. from sklearn import datasets

  3. X, color = datasets.samples_generator.make_swiss_roll(n_samples=1500)
  4. rbf_pca=KernelPCA(n_components=2,kernel="rbf",gamma=0.04,fit_inverse_transform=True)
  5. X_reduced=rbf_pca.fit_transform(X)
  6. X_preimage=rbf_pca.inverse_transform(X_reduced)
复制代码

        以上代码里的X_reduced变量就是降成两维的数据,X_preimage是又给还原成3维的数据,接着我们看看还原后和原始数据的误差:
  1. from sklearn.metrics import mean_squared_error

  2. mean_squared_error(X,X_preimage)
复制代码

        输出为:28.43337304287336。还有一种降维的算法如下:
  1. from sklearn.manifold import LocallyLinearEmbedding

  2. lle=LocallyLinearEmbedding(n_components=2,n_neighbors=10)
  3. X_reduced=lle.fit_transform(X)
复制代码

        这种LLE算法跟knn近邻算法有点类似,原理是找低维度的最近的10个点代替,非常的简单高效。这样我们就把复杂的三维图像变成二维了,我们画个简单的图看看:
  1. %matplotlib inline
  2. import matplotlib
  3. import matplotlib.pyplot as plt

  4. y=color.astype("int32")
  5. plt.scatter(X_reduced[y>=10,0],X_reduced[y>=10,1],color="red",alpha=0.2)
  6. plt.scatter(X_reduced[y<10,0],X_reduced[y<10,1],color="blue",alpha=0.2)
复制代码

        输出图像为:
sadasdsad.png

本帖被以下淘专辑推荐:

想知道小甲鱼最近在做啥?请访问 -> ilovefishc.com
回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

小黑屋|手机版|Archiver|鱼C工作室 ( 粤ICP备18085999号-1 | 粤公网安备 44051102000585号)

GMT+8, 2024-3-29 00:35

Powered by Discuz! X3.4

© 2001-2023 Discuz! Team.

快速回复 返回顶部 返回列表