00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020 #ifndef IDCLUSTER_H
00021 #define IDCLUSTER_H
00022
#include <stdio.h>

#include <functional>
#include <map>
#include <set>

#include "vecmap.h"
#include "preallocated.h"
#include "sqlite3.h"
00029
/// Numeric identifier type used for id_cluster::id.
typedef unsigned cluster_id_t;

/**
 * A cluster described by a histogram keyed by unsigned ids.
 *
 * All state is public. Most operations are only declared here and are
 * defined out of line, so their exact semantics are not visible in this
 * header; the notes on them below are hedged accordingly.
 */
class id_cluster {
public:

    /// Creates an empty cluster: empty histogram, total 0, id 0, weighted_sum 0.
    id_cluster() : total(0), id(0), weighted_sum(0) {}
    virtual ~id_cluster() {}

    /// Histogram type: unsigned key -> unsigned value.
    typedef std::map<unsigned, unsigned> uumap;

    uumap histo;          ///< Per-id counters (presumably occurrence counts -- TODO confirm).
    unsigned total;       ///< Denominator used by get_proba(); presumably the sum of all counters.
    cluster_id_t id;      ///< Identifier of this cluster; 0 after default construction.
    float weighted_sum;   ///< NOTE(review): maintained by out-of-line code; meaning not visible here.

    /// Defined out of line. Presumably adjusts the counter of `id` by
    /// `amount` and returns the updated value -- verify against the
    /// implementation.
    unsigned add(unsigned id, int amount=1);

    /// Defined out of line. Presumably a dot product between the two
    /// histograms -- TODO confirm weighting.
    float dotprod(const id_cluster &a) const;

    /// Defined out of line.
    void print() const;

    /// Defined out of line.
    void clear();

    /// Defined out of line. Presumably returns the counter stored for `id`.
    unsigned get_freq(unsigned id);

    /**
     * Relative frequency of `id`: get_freq(id) divided by `total`.
     *
     * @param id  key to look up.
     * @return get_freq(id) / total as a float, or 0 when `total` is 0.
     *         The explicit guard avoids a floating-point division by zero,
     *         which would otherwise hand NaN or infinity to the caller.
     */
    float get_proba(unsigned id) {
        if (total == 0)
            return 0.0f;
        return static_cast<float>(get_freq(id)) / static_cast<float>(total);
    }
};
00056
00057
00058
00059
00060
00061
00062
00063
00064
00065
00066 typedef std::map<id_cluster *, double> cluster_score_map;
00067
00068 class id_cluster_collection {
00069 public:
00070
00071 typedef std::set<id_cluster *> cluster_set;
00072
00073 typedef vecmap<id_cluster *, float> cluster_map;
00074
00075 typedef std::map<unsigned, cluster_map> id2cluster_map;
00076
00077 enum query_flags { QUERY_FREQ=0, QUERY_NORMALIZED_FREQ=1, QUERY_IDF=2, QUERY_MIN_FREQ=4, QUERY_BIN_FREQ = 8, QUERY_IDF_NORMALIZED=3 };
00078 id_cluster_collection(query_flags flags);
00079 virtual ~id_cluster_collection();
00080
00081 void reduce(float threshold);
00082
00083 void add_cluster(id_cluster *c);
00084 void remove_cluster(id_cluster *c);
00085 void update_cluster(id_cluster *c, unsigned id, int amount);
00086
00087
00088 void merge_clusters(id_cluster *a, id_cluster *b);
00089
00090 void cmp_best_clusters();
00091 unsigned get_best_cluster(unsigned id);
00092 id_cluster *get_best_cluster(id_cluster *c, float *dist);
00093
00094
00095 void print();
00096 bool save(const char *fn);
00097 bool save(FILE *f);
00098 bool save(sqlite3 *db, const char *tablename=0);
00099 bool load(const char *fn);
00100 bool load(sqlite3 *db, const char *tablename=0);
00101 void clear();
00102 float idf(unsigned id);
00103 float idf(id2cluster_map::iterator &id_it);
00104
00105 cluster_set clusters;
00106
00108 id2cluster_map id2cluster;
00109 std::map<unsigned, id_cluster *> best_cluster;
00110
00111 protected:
00112
00113 struct cluster_dist_t {
00114 id_cluster *a, *b;
00115 float d;
00116
00117 cluster_dist_t(id_cluster *a, id_cluster *b, float d);
00118 bool operator < (const cluster_dist_t &c) const;
00119 };
00120
00121
00122 std::set<cluster_dist_t> distance_matrix;
00123
00124 public:
00125 void get_scores(id_cluster *c, cluster_score_map &scores, id_cluster **best_c=0, float *_best_s=0);
00126
00127 void set_query_rules(query_flags flags);
00128
00129 protected:
00130 void build_distance_matrix(float threshold);
00131 void add_to_distance_matrix(id_cluster *c, float threshold);
00132 void remove_from_distance_matrix(id_cluster *c);
00133
00134 int version;
00135 public:
00136 int get_version() const { return version; }
00137
00138 private:
00139 query_flags flags;
00140 bool is_idf_normalized;
00141
00142 friend class incremental_query;
00143 };
00144
00145 class incremental_query {
00146 public:
00147
00148 struct ranked_cluster {
00149 id_cluster * const c;
00150 const float score;
00151 ranked_cluster(id_cluster *c, float s) : c(c), score(s) {}
00152
00153 bool operator< (const ranked_cluster &a) const {
00154 if (a.score < score) return true;
00155 if (score < a.score) return false;
00156 return c<a.c;
00157 }
00158 };
00159 typedef std::set<ranked_cluster> ranked_cluster_set;
00160 typedef ranked_cluster_set::iterator iterator;
00161 ranked_cluster_set results;
00162
00164 incremental_query(id_cluster_collection *db);
00165
00167 void modify(unsigned id, int amount=1);
00168
00172 void set(id_cluster *c);
00173
00174 iterator sort_results(unsigned max_results=1);
00175 iterator sort_results_min_ratio(float ratio);
00176
00177 iterator begin() { return results.begin(); }
00178 iterator end() { return results.end(); }
00179
00180 void clear();
00181 id_cluster *get_best(float *score);
00182
00183
00184 cluster_score_map scores;
00185 id_cluster query_cluster;
00186 id_cluster_collection *database;
00187
00188
00189 id_cluster_collection::query_flags get_flags() { return database->flags; }
00190
00191
00192 void set_all_flags(id_cluster_collection::query_flags flag) { clear(); database->set_query_rules(flag); }
00193
00194 int version;
00195 protected:
00196 int flags;
00197 };
00198
00199
00200 #endif
00201