00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020 #include "kpttracker.h"
00021 #ifndef KMEANTREE_H
00022 #define KMEANTREE_H
00023
00024 #include <vector>
00025 #include <stdio.h>
00026 #include <map>
00027 #include "sqlite3.h"
00028
00029 namespace kmean_tree {
00030
/// Number of child slots in every tree node (size of node_t::clusters).
static const unsigned nb_branches = 4;

// descriptor_size is the number of floats in one descriptor. It is selected
// at build time by the feature back-end macros below. The rest of this header
// uses descriptor_size unconditionally, so exactly one of these configurations
// has to be active: with none the constant is undeclared, with two it is
// defined twice.

// SURF back-end.
#if defined(WITH_SURF)
static const unsigned descriptor_size = 64;
#endif

// Patch-tagger back-end, using the patch itself as descriptor.
#if defined(WITH_PATCH_TAGGER_DESCRIPTOR) && defined(WITH_PATCH_AS_DESCRIPTOR)
static const unsigned descriptor_size = 256;
#endif

// Patch-tagger back-end, regular descriptor.
#if defined(WITH_PATCH_TAGGER_DESCRIPTOR) && !defined(WITH_PATCH_AS_DESCRIPTOR)
static const unsigned descriptor_size = 128;
#endif

// SiftGPU back-end.
#if defined(WITH_SIFTGPU)
static const unsigned descriptor_size = 128;
#endif
00045
00046 struct descriptor_t {
00047 float descriptor[descriptor_size];
00048 };
00049 typedef std::vector<descriptor_t *> descriptor_array;
00050
00051 class mean_t {
00052 public:
00053 float mean[descriptor_size];
00054
00055 mean_t();
00056
00057 void accumulate(float a, float b, descriptor_t *d);
00058 double distance(descriptor_t *d);
00059 };
00060
00061 class node_t {
00062 public:
00063
00064 mean_t mean;
00065 unsigned id;
00066
00067 node_t *clusters[nb_branches];
00068
00069 descriptor_array data;
00070
00071 node_t();
00072
00073 bool is_leaf() {
00074 for (unsigned i=0; i<nb_branches; i++)
00075 if (clusters[i]) return false;
00076 return true;
00077 }
00078
00079 unsigned count_leafs() {
00080 unsigned total=0;
00081 for (unsigned i=0; i<nb_branches; i++)
00082 if (clusters[i]) total += clusters[i]->count_leafs();
00083 if (total==0) return 1;
00084 return total;
00085 }
00086
00087 void recursive_split(int max_level, int min_elem, int level=0);
00088 bool run_and_split();
00089
00090 bool save(const char *filename);
00091
00092 unsigned get_id(descriptor_t *d, node_t ** node=0, int depth=0);
00093 void print_summary(int depth=0);
00094 void cmp_scores(std::multimap<double,int> &scores, descriptor_t *data);
00095
00096 bool load(FILE *f);
00097 bool save(FILE *f);
00098
00101 bool save_to_database(const char *fn);
00102
00103 void assign_leaf_ids(unsigned *id_ptr);
00104 protected:
00105 void run_kmean(int nb_iter=32);
00106 bool save(sqlite3 *db, sqlite3_stmt *insert_node, sqlite3_stmt *insert_child);
00107 };
00108
00109
00111 node_t * build_from_data(const char *filename, int max_level, int min_elem, int stop);
00112
00114 node_t *load(sqlite3 *db);
00115
00117 node_t * load(const char *filename);
00118
00119 struct descr_file_packet {
00120 long ptr;
00121 descriptor_t d;
00122 };
00123 };
00124
00125 #endif