00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024 #if defined(_MSC_VER)
00025 #pragma once
00026 #endif
00027
00028 #ifndef PBRT_CORE_KDTREE_H
00029 #define PBRT_CORE_KDTREE_H
00030
00031
00032 #include "pbrt.h"
00033 #include "geometry.h"
00034
00035
// One node of the flattened kd-tree.
//
// Encoding:
//  - splitAxis == 3 marks a leaf; 0..2 is the axis an interior node splits on,
//    with splitPos the coordinate of the splitting plane.
//  - hasLeftChild says whether a left subtree exists. Its index is implicit:
//    the tree builder places it directly after its parent (nodeNum + 1).
//  - rightChild holds the index of the right subtree. The all-ones 29-bit
//    value (1<<29)-1 is the "no right child" sentinel.
struct KdNode {
    // Set up an interior node splitting at coordinate p on axis a, with no
    // children recorded yet.
    void init(float p, uint32_t a) {
        initLeaf();          // clears the child links
        splitPos = p;
        splitAxis = a;       // overrides the leaf marker written by initLeaf()
    }
    // Set up a leaf node with no children. splitPos is left untouched.
    void initLeaf() {
        hasLeftChild = 0;
        rightChild = (1u << 29) - 1u;
        splitAxis = 3;
    }

    float splitPos;
    uint32_t splitAxis:2;
    uint32_t hasLeftChild:1, rightChild:29;
};
00053
00054
00055 template <typename NodeData> class KdTree {
00056 public:
00057
00058 KdTree(const vector<NodeData> &data);
00059 ~KdTree() {
00060 FreeAligned(nodes);
00061 FreeAligned(nodeData);
00062 }
00063 template <typename LookupProc> void Lookup(const Point &p,
00064 LookupProc &process, float &maxDistSquared) const;
00065 private:
00066
00067 void recursiveBuild(uint32_t nodeNum, int start, int end,
00068 const NodeData **buildNodes);
00069 template <typename LookupProc> void privateLookup(uint32_t nodeNum,
00070 const Point &p, LookupProc &process, float &maxDistSquared) const;
00071
00072
00073 KdNode *nodes;
00074 NodeData *nodeData;
00075 uint32_t nNodes, nextFreeNode;
00076 };
00077
00078
// Comparator for std::nth_element during tree construction. It orders element
// pointers by their coordinate on `axis`. Equal coordinates fall back to
// comparing the pointers themselves, which all point into the same input array.
template <typename NodeData> struct CompareNode {
    CompareNode(int a) : axis(a) { }
    int axis;
    bool operator()(const NodeData *d1, const NodeData *d2) const {
        if (d1->p[axis] != d2->p[axis])
            return d1->p[axis] < d2->p[axis];
        return d1 < d2;   // tie-break on address
    }
};
00087
00088
00089
00090
00091 template <typename NodeData>
00092 KdTree<NodeData>::KdTree(const vector<NodeData> &d) {
00093 nNodes = d.size();
00094 nextFreeNode = 1;
00095 nodes = AllocAligned<KdNode>(nNodes);
00096 nodeData = AllocAligned<NodeData>(nNodes);
00097 vector<const NodeData *> buildNodes(nNodes, NULL);
00098 for (uint32_t i = 0; i < nNodes; ++i)
00099 buildNodes[i] = &d[i];
00100
00101 recursiveBuild(0, 0, nNodes, &buildNodes[0]);
00102 }
00103
00104
00105 template <typename NodeData> void
00106 KdTree<NodeData>::recursiveBuild(uint32_t nodeNum, int start, int end,
00107 const NodeData **buildNodes) {
00108
00109 if (start + 1 == end) {
00110 nodes[nodeNum].initLeaf();
00111 nodeData[nodeNum] = *buildNodes[start];
00112 return;
00113 }
00114
00115
00116
00117
00118 BBox bound;
00119 for (int i = start; i < end; ++i)
00120 bound = Union(bound, buildNodes[i]->p);
00121 int splitAxis = bound.MaximumExtent();
00122 int splitPos = (start+end)/2;
00123 std::nth_element(&buildNodes[start], &buildNodes[splitPos],
00124 &buildNodes[end], CompareNode<NodeData>(splitAxis));
00125
00126
00127 nodes[nodeNum].init(buildNodes[splitPos]->p[splitAxis], splitAxis);
00128 nodeData[nodeNum] = *buildNodes[splitPos];
00129 if (start < splitPos) {
00130 nodes[nodeNum].hasLeftChild = 1;
00131 uint32_t childNum = nextFreeNode++;
00132 recursiveBuild(childNum, start, splitPos, buildNodes);
00133 }
00134 if (splitPos+1 < end) {
00135 nodes[nodeNum].rightChild = nextFreeNode++;
00136 recursiveBuild(nodes[nodeNum].rightChild, splitPos+1,
00137 end, buildNodes);
00138 }
00139 }
00140
00141
00142 template <typename NodeData> template <typename LookupProc>
00143 void KdTree<NodeData>::Lookup(const Point &p, LookupProc &proc,
00144 float &maxDistSquared) const {
00145 privateLookup(0, p, proc, maxDistSquared);
00146 }
00147
00148
00149 template <typename NodeData> template <typename LookupProc>
00150 void KdTree<NodeData>::privateLookup(uint32_t nodeNum, const Point &p,
00151 LookupProc &process, float &maxDistSquared) const {
00152 KdNode *node = &nodes[nodeNum];
00153
00154 int axis = node->splitAxis;
00155 if (axis != 3) {
00156 float dist2 = (p[axis] - node->splitPos) * (p[axis] - node->splitPos);
00157 if (p[axis] <= node->splitPos) {
00158 if (node->hasLeftChild)
00159 privateLookup(nodeNum+1, p, process, maxDistSquared);
00160 if (dist2 < maxDistSquared && node->rightChild < nNodes)
00161 privateLookup(node->rightChild, p, process, maxDistSquared);
00162 }
00163 else {
00164 if (node->rightChild < nNodes)
00165 privateLookup(node->rightChild, p, process, maxDistSquared);
00166 if (dist2 < maxDistSquared && node->hasLeftChild)
00167 privateLookup(nodeNum+1, p, process, maxDistSquared);
00168 }
00169 }
00170
00171
00172 float dist2 = DistanceSquared(nodeData[nodeNum].p, p);
00173 if (dist2 < maxDistSquared)
00174 process(p, nodeData[nodeNum], dist2, maxDistSquared);
00175 }
00176
00177
00178
00179 #endif // PBRT_CORE_KDTREE_H