鱼C论坛

 找回密码
 立即注册
查看: 780|回复: 1

python中对zip函数进行计算以及kd树搜索

[复制链接]
发表于 2019-2-23 10:43:39 | 显示全部楼层 |阅读模式

马上注册,结交更多好友,享用更多功能^_^

您需要 登录 才可以下载或查看,没有账号?立即注册

x
首先对zip()后的结果进行计算
from math import sqrt
p = [1,2]
t = [2, 3]
a = sqrt(sum(p1-p2)**2 for p1, p2 in zip(p, t))
结果报错:
Traceback (most recent call last):
  File "<input>", line 1, in <module>
TypeError: must be real number, not generator
但是我能够打印出结果:
for x, y in zip(p, t):
    print(x)
    print(y)
结果:
1
2
2
3

还有就是kd树以及搜索最近邻点
from math import sqrt
from collections import namedtuple

class KdNode(object):
    def __init__(self, dom_elt, split, left, right):
        self.dom_elt = dom_elt
        self.split = split
        self.left = left
        self.right = right

class KdTree(object):
    def __init__(self, data):
        k = len(data[0])

        def CreateNode(split, data_set):
            if not data_set:
                return None
            data_set.sort(key = lambda x:x[split])
            split_pos = len(data_set) // 2
            median = data_set[split_pos]
            split_next = (split + 1) % k

            return KdNode(median, split,\
                          CreateNode(split_next, data_set[:split_pos]),\
                    CreateNode(split_next, data_set[split_pos + 1 :]))

        self.root = CreateNode(0, data)

def preorder(root):
    print(root.dom_elt)
    if root.left:
        preorder(root.left)
    if root.right:
        preorder(root.right)

result = namedtuple('Result_tuple', 'nearest_point nearest_dist nearest_visited')
def find_nearest(tree, point):
    k = len(point)

    def trave(kd_node, target, max_dist):
        if kd_node is None:
            return result([0]*k, float('inf'), 0)

        node_visited = 1

        s = kd_node.split
        pivot = kd_node.dom_elt

        if target[s] <= pivot[s]:
            nearest_node = kd_node.left
            further_node = kd_node.right
        else:
            nearest_node = kd_node.right
            further_node = kd_node.left

        temp1 = trave(nearest_node, target, max_dist)

        nearest = temp1.nearest_point
        dist = temp1.nearest_dist

        node_visited += temp1.nearest_visited

        if dist < max_dist:
            max_dist = dist

        temp_dist = abs(pivot[s] - target[s])
        if max_dist < temp_dist:
            return result(nearest, dist, node_visited)

        temp_dist = sqrt(sum(p1-p2)**2 for p1, p2 in list(zip(pivot, target)))

        if temp_dist < dist:
            nearest = pivot
            dist = temp_dist
            max_dist = dist

        temp2 = trave(further_node, target, max_dist)
        node_visited +=temp2.nearest_visited
        if temp2.nearest_dist < dist:
            nearest = temp2.nearest_point
            dist = temp2.nearest_dist

        return result(nearest, dist, node_visited)

    return trave(tree.root, point, float('inf'))


data = [[2,3], [5,4], [9,6], [4,7], [8,1], [7,2]]
kd = KdTree(data)
find_nearest(kd, [3,4.5])

find_nearest()那一块里总有一些不对。。。又说不上来
想知道小甲鱼最近在做啥?请访问 -> ilovefishc.com
回复

使用道具 举报

 楼主| 发表于 2019-2-23 11:03:18 | 显示全部楼层
改了一下代码。。。能跑通了
不能直接对生成器求和。
from math import sqrt
p = [1,2]
t = [2, 3]
a = sqrt(sum(list((p1-p2)**2 for p1, p2 in zip(p, t))))
这样就能跑通了。。kd树也跑通了
想知道小甲鱼最近在做啥?请访问 -> ilovefishc.com
回复 支持 反对

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

小黑屋|手机版|Archiver|鱼C工作室 ( 粤ICP备18085999号-1 | 粤公网安备 44051102000585号)

GMT+8, 2024-4-18 11:46

Powered by Discuz! X3.4

© 2001-2023 Discuz! Team.

快速回复 返回顶部 返回列表