00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024 #ifndef PBRT_KDTREE_H
00025 #define PBRT_KDTREE_H
00026
00027 #include "pbrt.h"
00028 #include "geometry.h"
00029
00030 struct KdNode {
00031 void init(float p, u_int a) {
00032 splitPos = p;
00033 splitAxis = a;
00034 rightChild = ~(0U);
00035 hasLeftChild = 0;
00036 }
00037 void initLeaf() {
00038 splitAxis = 3;
00039 rightChild = ~(0U);
00040 hasLeftChild = 0;
00041 }
00042
00043 float splitPos;
00044 u_int splitAxis:2;
00045 u_int hasLeftChild:1;
00046 u_int rightChild:29;
00047 };
00048 template <class NodeData, class LookupProc> class KdTree {
00049 public:
00050
00051 KdTree(const vector<NodeData> &data);
00052 ~KdTree() {
00053 FreeAligned(nodes);
00054 delete[] nodeData;
00055 }
00056 void recursiveBuild(u_int nodeNum, int start, int end,
00057 vector<const NodeData *> &buildNodes);
00058 void Lookup(const Point &p, const LookupProc &process,
00059 float &maxDistSquared) const;
00060 private:
00061
00062 void privateLookup(u_int nodeNum, const Point &p,
00063 const LookupProc &process, float &maxDistSquared) const;
00064
00065 KdNode *nodes;
00066 NodeData *nodeData;
00067 u_int nNodes, nextFreeNode;
00068 };
// Comparator on NodeData pointers for partitioning along one axis.
// Elements are ordered by p[axis]. Elements with equal coordinates are
// ordered by address, so distinct elements never compare equivalent.
template<class NodeData> struct CompareNode {
    CompareNode(int a) : axis(a) {}

    // The coordinate index used for comparisons.
    int axis;

    bool operator()(const NodeData *d1, const NodeData *d2) const {
        if (d1->p[axis] == d2->p[axis])
            return d1 < d2;
        return d1->p[axis] < d2->p[axis];
    }
};
00078
00079 template <class NodeData, class LookupProc>
00080 KdTree<NodeData,
00081 LookupProc>::KdTree(const vector<NodeData> &d) {
00082 nNodes = d.size();
00083 nextFreeNode = 1;
00084 nodes = (KdNode *)AllocAligned(nNodes * sizeof(KdNode));
00085 nodeData = new NodeData[nNodes];
00086 vector<const NodeData *> buildNodes;
00087 for (u_int i = 0; i < nNodes; ++i)
00088 buildNodes.push_back(&d[i]);
00089
00090 recursiveBuild(0, 0, nNodes, buildNodes);
00091 }
00092 template <class NodeData, class LookupProc> void
00093 KdTree<NodeData, LookupProc>::recursiveBuild(u_int nodeNum,
00094 int start, int end,
00095 vector<const NodeData *> &buildNodes) {
00096
00097 if (start + 1 == end) {
00098 nodes[nodeNum].initLeaf();
00099 nodeData[nodeNum] = *buildNodes[start];
00100 return;
00101 }
00102
00103
00104 BBox bound;
00105 for (int i = start; i < end; ++i)
00106 bound = Union(bound, buildNodes[i]->p);
00107 int splitAxis = bound.MaximumExtent();
00108 int splitPos = (start+end)/2;
00109 std::nth_element(&buildNodes[start], &buildNodes[splitPos],
00110 &buildNodes[end-1]+1, CompareNode<NodeData>(splitAxis));
00111
00112 nodes[nodeNum].init(buildNodes[splitPos]->p[splitAxis],
00113 splitAxis);
00114 nodeData[nodeNum] = *buildNodes[splitPos];
00115 if (start < splitPos) {
00116 nodes[nodeNum].hasLeftChild = 1;
00117 u_int childNum = nextFreeNode++;
00118 recursiveBuild(childNum, start, splitPos, buildNodes);
00119 }
00120 if (splitPos+1 < end) {
00121 nodes[nodeNum].rightChild = nextFreeNode++;
00122 recursiveBuild(nodes[nodeNum].rightChild, splitPos+1,
00123 end, buildNodes);
00124 }
00125 }
00126 template <class NodeData, class LookupProc> void
00127 KdTree<NodeData, LookupProc>::Lookup(const Point &p,
00128 const LookupProc &proc,
00129 float &maxDistSquared) const {
00130 privateLookup(0, p, proc, maxDistSquared);
00131 }
00132 template <class NodeData, class LookupProc> void
00133 KdTree<NodeData, LookupProc>::privateLookup(u_int nodeNum,
00134 const Point &p, const LookupProc &process,
00135 float &maxDistSquared) const {
00136 KdNode *node = &nodes[nodeNum];
00137
00138 int axis = node->splitAxis;
00139 if (axis != 3) {
00140 float dist2 = (p[axis] - node->splitPos) *
00141 (p[axis] - node->splitPos);
00142 if (p[axis] <= node->splitPos) {
00143 if (node->hasLeftChild)
00144 privateLookup(nodeNum+1, p,
00145 process, maxDistSquared);
00146 if (dist2 < maxDistSquared &&
00147 node->rightChild < nNodes)
00148 privateLookup(node->rightChild,
00149 p,
00150 process,
00151 maxDistSquared);
00152 }
00153 else {
00154 if (node->rightChild < nNodes)
00155 privateLookup(node->rightChild,
00156 p,
00157 process,
00158 maxDistSquared);
00159 if (dist2 < maxDistSquared && node->hasLeftChild)
00160 privateLookup(nodeNum+1,
00161 p,
00162 process,
00163 maxDistSquared);
00164 }
00165 }
00166
00167 float dist2 = DistanceSquared(nodeData[nodeNum].p, p);
00168 if (dist2 < maxDistSquared)
00169 process(nodeData[nodeNum], dist2, maxDistSquared);
00170 }
00171 #endif // PBRT_KDTREE_H