00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020 #ifdef WIN32
00021 #define _CRT_SECURE_NO_DEPRECATE
00022 #endif
00023
00024 #include <math.h>
00025 #include <assert.h>
00026 #include "kmeantree.h"
00027 #ifdef NDEBUG
00028 #undef NDEBUG
00029 #endif
00030
00031 #include <string.h>
00032 #include <iostream>
00033 #include <math.h>
00034 #include <stdlib.h>
00035 #include <map>
00036 #ifdef _OPENMP
00037 #include <omp.h>
00038 #endif
00039
00040 using namespace kmean_tree;
00041 using namespace std;
00042
#ifdef WIN32
#include <float.h>
/**
 * Stand-in for POSIX drand48() on Windows.
 *
 * @return a pseudo-random value in [0,1). The divisor is RAND_MAX+1 rather
 *         than RAND_MAX so that 1.0 is never returned: callers in this file
 *         compute `drand48()*n` and use the result as an index below n.
 */
static inline double drand48() {
    return (double)rand() / ((double)RAND_MAX + 1.0);
}

/**
 * Stand-in for finite() on Windows.
 *
 * Takes a double so that double arguments (such as the squared distance
 * checked in mean_t::distance) are not narrowed to float before the test.
 *
 * @return non-zero when f is neither infinite nor NaN.
 */
static inline int finite(double f) {
    return _finite(f);
}
#endif
00053
00054 mean_t::mean_t() {
00055 memset(mean,0,descriptor_size*sizeof(float));
00056 }
00057
00058 node_t::node_t() {
00059 for (unsigned i=0; i<nb_branches; i++)
00060 clusters[i] = 0;
00061 }
00062
#ifdef _OPENMP
// Shared state used by node_t::recursive_split() when built with OpenMP.
// `lock` protects `thread_count`; build_from_data() initialises it before
// the first split.
omp_lock_t lock;
// Bookkeeping counter that recursive_split() raises before opening a nested
// parallel region and lowers again once the region is finished.
int thread_count = 1;
#endif
00067
00068 void node_t::recursive_split(int max_level, int min_elem, int level) {
00069 if (level >= max_level) {
00070 std::cout << "stopping splitting at depth " << level << ", with "
00071 << data.size() << " elements\n";
00072 return;
00073 }
00074
00075 if (is_leaf()) {
00076 if (data.size()>(unsigned)min_elem) {
00077 run_and_split();
00078 } else {
00079 std::cout << "stopping splitting at depth " << level << ", with "
00080 << data.size() << " elements\n";
00081 }
00082 }
00083 #ifdef _OPENMP
00084 omp_set_lock(&lock);
00085 int thread_add=0;
00086 if(thread_count<8) {
00087 omp_set_nested(1);
00088 thread_add=3;
00089 thread_count+=thread_add;
00090 } else {
00091 omp_set_nested(0);
00092 }
00093 omp_unset_lock(&lock);
00094 #pragma omp parallel for
00095 #endif
00096 for (int i=0; (unsigned)i<nb_branches; i++)
00097 if (clusters[i])
00098 clusters[i]->recursive_split(max_level, min_elem, level+1);
00099 #ifdef _OPENMP
00100 omp_set_lock(&lock);
00101 thread_count-=thread_add;
00102 omp_unset_lock(&lock);
00103 #endif
00104 }
00105
00106 unsigned node_t::get_id(descriptor_t *descr, node_t **node, int depth)
00107 {
00108 if (is_leaf()) {
00109 if (node) *node = this;
00110 return id;
00111 }
00112
00113
00114 std::multimap<double,int> scores;
00115 cmp_scores(scores, descr);
00116 std::multimap<double,int>::iterator it = scores.begin();
00117
00118 return clusters[it->second]->get_id(descr, node, depth+1);
00119 }
00120
00121 bool node_t::run_and_split() {
00122 assert(is_leaf());
00123
00124
00125 for (unsigned i=0; i<nb_branches; i++) {
00126 clusters[i] = new node_t;
00127 }
00128
00129 run_kmean();
00130
00131
00132 data.clear();
00133 return true;
00134 }
00135
00136 void node_t::assign_leaf_ids(unsigned *id_ptr)
00137 {
00138 unsigned *ptr=id_ptr;
00139 if (id_ptr==0) {
00140 ptr=&id;
00141 }
00142 id=0;
00143 if (is_leaf())
00144 id = ++(*ptr);
00145 else {
00146 for (unsigned i=0; i<nb_branches; i++)
00147 if (clusters[i])
00148 clusters[i]->assign_leaf_ids(ptr);
00149 }
00150 }
00151
00152 void node_t::cmp_scores(std::multimap<double,int> &scores, descriptor_t *data)
00153 {
00154 scores.clear();
00155 for (unsigned i=0; i<nb_branches; i++) {
00156 if (clusters[i]) {
00157 double d = clusters[i]->mean.distance(data);
00158
00159 scores.insert(std::pair<double,int>(d,i));
00160 }
00161 }
00162 }
00163
00164
00165 void node_t::run_kmean(int nb_iter)
00166 {
00167 bool online_kmean = false;
00168 int n = data.size();
00169 int k = nb_branches;
00170
00171 assert(n>k);
00172 assert(k>=2);
00173
00174 std::vector<float> counter(k, 0);
00175 std::vector<mean_t> new_mean(k);
00176
00177 std::cout << "(starting k-mean:" ;
00178 std::cout << " total: "<< n << ")" << std::endl;
00179
00180
00181 for (int i=0; i<k; i++) {
00182
00183 int nb_init=1;
00184 unsigned d = drand48()*n;
00185 clusters[i]->mean.accumulate(0, 1.0f/nb_init, data[d]);
00186 for (int k=1; k<nb_init; k++) {
00187 d = drand48()*n;
00188 clusters[i]->mean.accumulate(1, 1.0f/nb_init, data[d]);
00189 }
00190 }
00191
00192 for (int iter=0;iter<nb_iter;iter++) {
00193
00194
00195 for (int i=0; i<k; i++) {
00196 counter[i]=0;
00197 }
00198
00199
00200 for (int j=0; j<n; j++) {
00201
00202
00203 std::multimap<double,int> scores;
00204 cmp_scores(scores, data[j]);
00205
00206 std::multimap<double,int>::iterator b = scores.begin();
00207 std::multimap<double,int>::iterator b2 = b;
00208 ++b2;
00209
00210 int best_m = b->second;
00211
00212 if (iter==nb_iter-1) {
00213
00214 clusters[best_m]->data.push_back(data[j]);
00215 if (0 && b2->first/b->first > .98) {
00216
00217 clusters[b2->second]->data.push_back(data[j]);
00218 }
00219 }
00220
00221
00222 counter[best_m] += 1;
00223 if (online_kmean) {
00224 clusters[best_m]->mean.accumulate(.95, .05, data[j]);
00225 } else {
00226 new_mean[best_m].accumulate(
00227 (counter[best_m]-1.0f)/counter[best_m],
00228 1.0f/counter[best_m],
00229 data[j]);
00230 }
00231 }
00232
00233
00234
00235
00236
00237
00238
00239
00240
00241
00242 if (!online_kmean)
00243 for (int i=0; i<k; i++) {
00244 for (unsigned j=0; j<descriptor_size; j++)
00245 clusters[i]->mean.mean[j] = new_mean[i].mean[j];
00246 }
00247 }
00248 for (int i=0; i<k; i++) {
00249 if (counter[i]==0) {
00250 if (clusters[i]->data.size()>0) {
00251 std::cout << "Warning, killing a cluster with " << clusters[i]->data.size()
00252 << " elements in it!\n";
00253 }
00254 delete clusters[i];
00255 clusters[i]=0;
00256 }
00257 else {
00258 std::cout << " C" << i << ": " << clusters[i]->data.size();
00259
00260 }
00261 }
00262 std::cout << std::endl;
00263 }
00264
00265 void mean_t::accumulate(float a, float b, descriptor_t *d)
00266 {
00267 if (a==0) {
00268 for (unsigned i=0; i<descriptor_size; i++) {
00269 mean[i] = b*d->descriptor[i];
00270
00271 }
00272 } else {
00273 for (unsigned i=0; i<descriptor_size; i++) {
00274 mean[i] = mean[i]*a + b*d->descriptor[i];
00275
00276 }
00277 }
00278
00279
00280
00281
00282
00283
00284
00285
00286
00287
00288
00289 }
00290
00291 double mean_t::distance(descriptor_t *d) {
00292 double dist=0;
00293 for (unsigned i=0; i<descriptor_size; i++) {
00294
00295
00296
00297 double di = mean[i]-(float)d->descriptor[i];
00298 dist += di*di;
00299 }
00300 assert(finite(dist));
00301 return dist;
00302 }
00303
00304 bool node_t::save(const char *filename)
00305 {
00306
00307
00308
00309
00310
00311
00312
00313
00314
00315
00316
00317 FILE *f = fopen(filename,"wb");
00318 if (!f) {
00319 perror(filename);
00320 return false;
00321 }
00322 save(f);
00323 fclose(f);
00324 return true;
00325 }
00326
00327 bool node_t::save_to_database(const char *fn)
00328 {
00329 char *errmsg=0;
00330 sqlite3 *sql3;
00331 int rc = sqlite3_open(fn, &sql3);
00332 if (rc) {
00333 cerr << "Can't open database: " << sqlite3_errmsg(sql3) << endl;
00334 sqlite3_close(sql3);
00335 return false;
00336 }
00337
00338
00339 rc=sqlite3_exec(sql3,
00340 "create table if not exists tree_nodes ("
00341 "ptr integer primary key,"
00342 "id integer,"
00343 "mean blob);\n"
00344 "create table if not exists tree_structure ("
00345 "parent integer,"
00346 "child integer);\n"
00347 "delete from tree_nodes;\n"
00348 "delete from tree_structure;",
00349 0,0, &errmsg);
00350 if (rc) {
00351 cerr << fn << ": can't prepare tables: " << errmsg << endl;
00352 sqlite3_free(errmsg);
00353 sqlite3_close(sql3);
00354 return false;
00355 }
00356
00357 sqlite3_stmt *insert_node=0;
00358 sqlite3_stmt *insert_child=0;
00359 rc = sqlite3_prepare_v2(sql3, "insert into tree_nodes (ptr,id,mean) values (?,?,?)", -1, &insert_node, 0);
00360 assert(rc==0);
00361 rc = sqlite3_prepare_v2(sql3, "insert into tree_structure (parent,child) values (?,?)", -1, &insert_child, 0);
00362 assert(rc==0);
00363
00364 sqlite3_exec(sql3,"begin",0,0,0);
00365 if (save(sql3, insert_node, insert_child)) {
00366 sqlite3_bind_int64(insert_child, 1, 0);
00367 sqlite3_bind_int64(insert_child, 2, (sqlite3_int64) this);
00368 sqlite3_step(insert_child);
00369
00370 sqlite3_exec(sql3,"commit",0,0,0);
00371 } else{
00372 sqlite3_exec(sql3,"rollback",0,0,0);
00373 }
00374 sqlite3_finalize(insert_node);
00375 sqlite3_finalize(insert_child);
00376
00377 sqlite3_close(sql3);
00378 return true;
00379 }
00380
00381
00382 node_t *kmean_tree::load(sqlite3 *db)
00383 {
00384 if (db==0) return 0;
00385
00386 std::map<sqlite3_int64,node_t *> nodes;
00387 std::map<sqlite3_int64,node_t *>::iterator it;
00388
00389 char *errmsg=0;
00390
00391
00392 const char *query="select ptr,id,mean from tree_nodes";
00393 sqlite3_stmt *stmt=0;
00394 const char *tail=0;
00395 int rc = sqlite3_prepare_v2(db, query, -1, &stmt, &tail);
00396
00397 if (rc != SQLITE_OK) {
00398 cerr << "Error: " << sqlite3_errmsg(db) << endl;
00399 cerr << "While compiling: " << query << endl;
00400 return false;
00401 }
00402
00403 while (sqlite3_step(stmt) == SQLITE_ROW) {
00404 sqlite3_int64 ptr = sqlite3_column_int64(stmt, 0);
00405 node_t *node = new node_t;
00406 node->id = sqlite3_column_int(stmt, 1);
00407 nodes[ptr] = node;
00408 int sz = sqlite3_column_bytes(stmt,2);
00409 assert(sz == sizeof(node->mean));
00410 memcpy(&node->mean, sqlite3_column_blob(stmt, 2), sizeof(node->mean));
00411 }
00412 sqlite3_finalize(stmt);
00413
00414 query = "select parent,child from tree_structure";
00415 rc = sqlite3_prepare_v2(db, query, -1, &stmt, &tail);
00416 if (rc != SQLITE_OK) {
00417 cerr << "Error: " << sqlite3_errmsg(db) << endl;
00418 cerr << "While compiling: " << query << endl;
00419 return false;
00420 }
00421
00422
00423 node_t *root=0;
00424 while (sqlite3_step(stmt) == SQLITE_ROW) {
00425 sqlite3_int64 parent_id = sqlite3_column_int64(stmt, 0);
00426 sqlite3_int64 child_id = sqlite3_column_int64(stmt, 1);
00427
00428 it = nodes.find(child_id);
00429 assert(it!=nodes.end());
00430 node_t *child = it->second;
00431
00432 if (parent_id==0) {
00433 root = child;
00434 } else {
00435 it = nodes.find(parent_id);
00436 assert(it!=nodes.end());
00437 node_t *parent = it->second;
00438 bool written=false;
00439 for (int i=0; i<nb_branches; ++i) {
00440 if (parent->clusters[i] ==0) {
00441 parent->clusters[i] = child;
00442 written=true;
00443 break;
00444 }
00445 }
00446 assert(written);
00447 }
00448 }
00449 for (it = nodes.begin(); it!=nodes.end(); ++it) {
00450 if (it->second->id == 0 && it->second->clusters[0]==0) {
00451 cerr << "entry " << it->first << " has id=0 and no children!\n";
00452 }
00453 }
00454
00455 sqlite3_finalize(stmt);
00456 assert(root!=0);
00457 return root;
00458 }
00459
00460 node_t *kmean_tree::load(const char *filename)
00461 {
00462
00463
00464
00465
00466
00467
00468 FILE *f = fopen(filename,"rb");
00469 if (!f) {
00470 perror(filename);
00471 return false;
00472 }
00473 node_t *tree = new node_t;
00474 if (!tree->load(f)) {
00475 std::cerr << "loading failed\n";
00476 delete tree;
00477 tree = 0;
00478 }
00479 fclose(f);
00480 return tree;
00481 }
00482
00483 #include <sys/types.h>
00484 #include <sys/stat.h>
00485 #ifndef WIN32
00486 #include <sys/mman.h>
00487 #endif
00488 #include <fcntl.h>
00489
00490 #include <string>
00491 using namespace std;
00492 void save_node_images(string prefix, kmean_tree::node_t *node);
00493
00494 node_t *kmean_tree::build_from_data(const char *filename, int max_level, int min_elem, int stop)
00495 {
00496 #ifndef WIN32
00497 int fd = open(filename,O_RDONLY);
00498 if (fd<0) {
00499 perror(filename);
00500 return 0;
00501 }
00502
00503
00504 struct stat statbuf;
00505 if (fstat (fd,&statbuf) < 0) {
00506 perror(filename);
00507 close(fd);
00508 return 0;
00509 }
00510
00511 descr_file_packet *data = (descr_file_packet *) mmap(0,statbuf.st_size, PROT_READ, MAP_SHARED, fd, 0);
00512 if (data==(descr_file_packet *)-1) {
00513 perror("mmap");
00514 close(fd);
00515 return 0;
00516 }
00517 int ndata = statbuf.st_size / sizeof(descr_file_packet);
00518 #else
00519 FILE *file = fopen(filename, "rb");
00520 if (!file) return 0;
00521
00522 fseek(file, 0, SEEK_END);
00523 long size = ftell(file);
00524 int ndata = size/ sizeof(descr_file_packet);
00525 descr_file_packet *data = new descr_file_packet[ndata];
00526 fseek(file,0,SEEK_SET);
00527 if (fread(data, sizeof(descr_file_packet), ndata, file) != ndata) {
00528 cerr << filename << ": can not read data\n";
00529 fclose(file);
00530 return 0;
00531 }
00532 fclose(file);
00533 #endif
00534
00535 node_t *root = new node_t;
00536
00537
00538 if (stop>0 && stop<ndata) ndata=stop;
00539
00540 root->data.reserve(ndata);
00541 bool ok=true;
00542 for (int i=0; i<ndata; i++) {
00543 for (unsigned j=0; j<descriptor_size; j++) {
00544 if (!finite(data[i].d.descriptor[j])) {
00545 std::cout << "descriptor " << i << " has a problem in coord " << j << std::endl;
00546 ok=false;
00547 }
00548 }
00549 root->data.push_back(&data[i].d);
00550 }
00551 if (!ok) return 0;
00552
00553 std::cout << "Data loaded. Starting k-mean."<<std::endl;
00554 #ifdef _OPENMP
00555 omp_init_lock(&lock);
00556 #endif
00557 root->recursive_split(max_level, min_elem);
00558
00559 unsigned n = 0;
00560 root->assign_leaf_ids(&n);
00561 cout << "Tree has " << n << " leafs.\n";
00562
00563
00564
00565 #ifdef WIN32
00566 delete[] data;
00567 #else
00568 munmap(data,statbuf.st_size);
00569 close(fd);
00570 #endif
00571 return root;
00572
00573 }
00574
00575
00576 void node_t::print_summary(int depth) {
00577
00578 if (is_leaf()) return;
00579 for (unsigned i=0; i<nb_branches; i++) {
00580 for (int j=0;j<depth;j++) std::cout<< ".";
00581 std::cout << "C" << i <<": ";
00582 for (int j=0; j<16; j++) std::cout << " " << clusters[i]->mean.mean[j];
00583 std::cout<<std::endl;
00584 if (clusters[i])
00585 clusters[i]->print_summary(depth+1);
00586 }
00587 }
00588
00589 bool node_t::load(FILE *f)
00590 {
00591 unsigned char check, child;
00592 if (fread(&check, 1, 1, f) != 1) return false;
00593 if (check!= 0xab) {
00594 std::cerr << "check failed\n";
00595 return false;
00596 }
00597
00598 if ((fread(&id, sizeof(id), 1, f) != 1)
00599 || (fread(&mean, sizeof(mean), 1, f) != 1)
00600 || (fread(&child,1,1,f) != 1))
00601 {
00602 std::cerr << "error: can't read tree data!\n";
00603 return false;
00604 }
00605
00606
00607 if (child==0 && id==0) {
00608 std::cerr << "child has no id!\n";
00609 return false;
00610 }
00611 if (child>nb_branches) {
00612 std::cerr<< "error: too many branches!\n";
00613 return false;
00614 }
00615 for (unsigned i=0;i<child;i++) {
00616 clusters[i] = new node_t;
00617 if (!clusters[i]->load(f)) return false;
00618 }
00619 return true;
00620 }
00621
00622 bool node_t::save(sqlite3 *db, sqlite3_stmt *insert_node, sqlite3_stmt *insert_child)
00623 {
00624 sqlite3_bind_int64(insert_node, 1, (sqlite3_int64)this);
00625 sqlite3_bind_int(insert_node, 2, id);
00626 sqlite3_bind_blob(insert_node, 3, &mean, sizeof(mean), SQLITE_STATIC);
00627 if(sqlite3_step(insert_node)!=SQLITE_DONE) {
00628 printf("%s:%d: error: %s\n", __FILE__, __LINE__, sqlite3_errmsg(db));
00629 return false;
00630 }
00631 sqlite3_reset(insert_node);
00632 for (unsigned i=0; i<nb_branches; i++) {
00633 if (clusters[i]) {
00634 sqlite3_bind_int64(insert_child, 1, (sqlite3_int64)this);
00635 sqlite3_bind_int64(insert_child, 2, (sqlite3_int64)clusters[i]);
00636 if(sqlite3_step(insert_child)!=SQLITE_DONE) {
00637 printf("%s:%d: error: %s\n", __FILE__, __LINE__, sqlite3_errmsg(db));
00638 return false;
00639 }
00640 sqlite3_reset(insert_child);
00641 }
00642 }
00643 for (unsigned i=0; i<nb_branches; i++) {
00644 if (clusters[i])
00645 if (!clusters[i]->save(db,insert_node,insert_child)) return false;
00646 }
00647 return true;
00648 }
00649
00650 bool node_t::save(FILE *f)
00651 {
00652 unsigned char check=0xab, child = 0;
00653 for (unsigned i=0;i<nb_branches;i++) if (clusters[i]) ++child;
00654
00655 if (fwrite(&check,sizeof(check),1,f)!=1) return false;
00656 if (fwrite(&id,sizeof(id), 1, f) != 1) return false;
00657 if (fwrite(&mean,sizeof(mean), 1, f) != 1) return false;
00658 if (fwrite(&child,sizeof(child),1,f)!=1) return false;
00659 if (child == 0 && id==0) {
00660 std::cerr << "Warning: writing a tree file that contains a leaf with id=0!\n";
00661 }
00662 if (child) {
00663 for (unsigned i=0; i<nb_branches; i++) {
00664 if (clusters[i]) {
00665 if (!clusters[i]->save(f)) return false;
00666 }
00667 }
00668 }
00669 return true;
00670 }
00671
00672
00673
00674 #include <highgui.h>
00675 #include <sstream>
/**
 * Formats a value the way `operator<<` prints it on an output stream.
 *
 * @param arg value to format.
 * @return the text produced by streaming arg.
 */
template < class T >
std::string ToString(const T &arg)
{
    std::ostringstream buffer;
    buffer << arg;
    const std::string text = buffer.str();
    return text;
}
00685 void save_patch(const string &name, float *f)
00686 {
00687 CvMat mat;
00688 cvInitMatHeader(&mat, 16, 16, CV_32FC1, f);
00689 double min,max;
00690 cvMinMaxLoc(&mat, &min, &max);
00691 CvMat *m2 = cvCreateMat(16,16,CV_8UC1);
00692 cvCvtScale(&mat, m2, 255/(max-min), -min*255/(max-min));
00693 cvSaveImage(name.c_str(), m2);
00694 cvReleaseMat(&m2);
00695 cout << name << ": " << min << ", " << max << endl;
00696 }
00697
00698 void save_node_images(string prefix, kmean_tree::node_t *node)
00699 {
00700 if (!node) return;
00701
00702 for (unsigned i=0; i<3; i++) {
00703 if (i<node->data.size()) {
00704 save_patch(prefix + "D" + ToString(i) + ".png", node->data[i]->descriptor);
00705 }
00706 }
00707
00708 string name = prefix + ".png";
00709
00710 save_patch(name,node->mean.mean);
00711
00712 for (unsigned i=0;i<kmean_tree::nb_branches;i++) {
00713 save_node_images(prefix + "_" + ToString(i), node->clusters[i]);
00714 }
00715 }
00716