blob: ccae9396b34af3415530bcba3cb47778074fcfff [file] [log] [blame]
/* AUTORIGHTS
Copyright (C) 2007 Princeton University
This file is part of Ferret Toolkit.
Ferret Toolkit is free software; you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation; either version 2, or (at your option)
any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program; if not, write to the Free Software Foundation,
Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
*/
#include <cass.h>
#include <values.h>
#include "chunk_cnt.inc"
#include "emd.h"
/*
 * Trivial vector distance: none of the arguments are inspected and the
 * reported distance is always zero.
 */
cass_dist_t dist_trivial_dist (cass_size_t n, void *p1, void *p2, void *p3)
{
    const cass_dist_t zero = 0.0;

    (void) n;
    (void) p1;
    (void) p2;
    (void) p3;

    return zero;
}
/*
 * DIST_SIMPLE_METHODS(type, _class)
 *
 * Expands to two file-local functions for a vector distance that takes no
 * construction parameters:
 *
 *   dist_<type>_construct(_dist, param)
 *       Allocates one cass_vec_dist_t with type_calloc, increments its
 *       reference count, points its __class at _class, stores the object in
 *       *_dist and returns 0.  `param` is not used.
 *       NOTE(review): the result of type_calloc is not checked here; confirm
 *       whether type_calloc can return NULL.
 *
 *   dist_<type>_restore(_dist, fin)
 *       Constructs an object as above, then reads its name from `fin` with
 *       cass_read_pchar.  Returns 0 on success.  A non-zero result from the
 *       constructor is passed straight back.  If the name cannot be read the
 *       object is freed, *_dist is reset to NULL so the caller never holds a
 *       pointer to freed memory, and CASS_ERR_IO is returned.
 */
#define DIST_SIMPLE_METHODS(type,_class) \
static int dist_##type##_construct (void **_dist, const char *param) \
{ \
    cass_vec_dist_t *dist = type_calloc(cass_vec_dist_t, 1); \
    dist->refcnt++; \
    dist->__class = &_class; \
    *_dist = dist; \
    return 0; \
} \
\
static int dist_##type##_restore (void **_dist, CASS_FILE *fin) \
{ \
    cass_vec_dist_t *dist; \
    int err = dist_##type##_construct(_dist, NULL); \
    if (err != 0) return err; \
    dist = *_dist; \
    dist->name = cass_read_pchar(fin); \
    if (dist->name == NULL) \
    { \
        free(dist); \
        *_dist = NULL; /* do not hand freed memory back to the caller */ \
        return CASS_ERR_IO; \
    } \
    return 0; \
}
/*
 * Writes a two-line description of a vector distance to `fout`: first the
 * instance name, then the name stored in its class descriptor.
 * Always returns 0.
 */
static int dist_simple_describe (void *_dist, CASS_FILE *fout)
{
    const cass_vec_dist_t *self = _dist;

    cass_printf(fout, "NAME:\t%s\n", self->name);
    cass_printf(fout, "CLASS:\t%s\n", self->__class->name);

    return 0;
}
/*
 * Checkpoints a vector distance by writing its name to `fout` with
 * cass_write_pchar; the result of that call is returned unchanged.
 */
static int dist_simple_checkpoint (void *_dist, CASS_FILE *fout)
{
    const cass_vec_dist_t *self = _dist;
    int status = cass_write_pchar(self->name, fout);

    return status;
}
/*
 * Drops one reference to a vector distance.  When the reference count
 * reaches zero the name string and the object itself are freed.
 */
static void dist_simple_free (void *_dist)
{
    cass_vec_dist_t *self = _dist;

    self->refcnt -= 1;
    if (self->refcnt != 0)
        return;

    /* free(NULL) is a no-op, so an unset name needs no special case. */
    free(self->name);
    free(self);
}
/* Generates dist_trivial_construct() and dist_trivial_restore(). */
DIST_SIMPLE_METHODS(trivial, vec_dist_trivial)

/* Class descriptor binding the name "trivial" to its callbacks. */
cass_vec_dist_class_t vec_dist_trivial =
{
    /* identity */
    .name       = "trivial",
    .type       = CASS_VEC_DIST_TYPE_TRIVIAL,
    .vec_type   = CASS_ANY,

    /* lifecycle and persistence */
    .construct  = dist_trivial_construct,
    .restore    = dist_trivial_restore,
    .checkpoint = dist_simple_checkpoint,
    .describe   = dist_simple_describe,
    .free       = dist_simple_free,

    /* distance function */
    .dist       = dist_trivial_dist,
};
/* Generates dist_L1_int_construct() and dist_L1_int_restore(). */
DIST_SIMPLE_METHODS(L1_int, vec_dist_L1_int)

/*
 * Adapter around dist_L1_int32_t: forwards the dimension D and the two
 * int32 vectors unchanged and converts the result to cass_dist_t.
 */
cass_dist_t __dist_L1_int32 (cass_size_t D, const int32_t *v1, const int32_t *v2)
{
    cass_dist_t result = (cass_dist_t) dist_L1_int32_t(D, v1, v2);

    return result;
}
/* Class descriptor binding the name "L1_int" to its callbacks. */
cass_vec_dist_class_t vec_dist_L1_int =
{
    /* identity */
    .name       = "L1_int",
    .type       = CASS_VEC_DIST_TYPE_L1,
    .vec_type   = CASS_VEC_INT,

    /* lifecycle and persistence */
    .construct  = dist_L1_int_construct,
    .restore    = dist_L1_int_restore,
    .checkpoint = dist_simple_checkpoint,
    .describe   = dist_simple_describe,
    .free       = dist_simple_free,

    /* distance function */
    .dist       = __dist_L1_int32,
};
/* Generates dist_L2_int_construct() and dist_L2_int_restore(). */
DIST_SIMPLE_METHODS(L2_int, vec_dist_L2_int)

/*
 * Adapter around dist_L2_int32_t: forwards the dimension D and the two
 * int32 vectors unchanged and converts the result to cass_dist_t.
 */
cass_dist_t __dist_L2_int32 (cass_size_t D, const int32_t *v1, const int32_t *v2)
{
    cass_dist_t result = (cass_dist_t) dist_L2_int32_t(D, v1, v2);

    return result;
}
/* Class descriptor binding the name "L2_int" to its callbacks. */
cass_vec_dist_class_t vec_dist_L2_int =
{
    /* identity */
    .name       = "L2_int",
    .type       = CASS_VEC_DIST_TYPE_L2,
    .vec_type   = CASS_VEC_INT,

    /* lifecycle and persistence */
    .construct  = dist_L2_int_construct,
    .restore    = dist_L2_int_restore,
    .checkpoint = dist_simple_checkpoint,
    .describe   = dist_simple_describe,
    .free       = dist_simple_free,

    /* distance function */
    .dist       = __dist_L2_int32,
};
/* Generates dist_L1_float_construct() and dist_L1_float_restore(). */
DIST_SIMPLE_METHODS(L1_float, vec_dist_L1_float)

/* Class descriptor binding the name "L1_float" to its callbacks. */
cass_vec_dist_class_t vec_dist_L1_float =
{
    /* identity */
    .name       = "L1_float",
    .type       = CASS_VEC_DIST_TYPE_L1,
    .vec_type   = CASS_VEC_FLOAT,

    /* lifecycle and persistence */
    .construct  = dist_L1_float_construct,
    .restore    = dist_L1_float_restore,
    .checkpoint = dist_simple_checkpoint,
    .describe   = dist_simple_describe,
    .free       = dist_simple_free,

    /* distance function */
    .dist       = dist_L1_float,
};
DIST_SIMPLE_METHODS(L2_float, vec_dist_L2_float)
cass_vec_dist_class_t vec_dist_L2_float =
{
.name = "L2_float",
.vec_type = CASS_VEC_FLOAT,
.type = CASS_VEC_DIST_TYPE_L1,
.dist = dist_L2_float,
.describe = dist_simple_describe,
.construct = dist_L2_float_construct,
.checkpoint = dist_simple_checkpoint,
.restore = dist_L2_float_restore,
.free = dist_simple_free,
};
/* Generates dist_cos_float_construct() and dist_cos_float_restore(). */
DIST_SIMPLE_METHODS(cos_float, vec_dist_cos_float)

/* Class descriptor binding the name "cosine" to its callbacks. */
cass_vec_dist_class_t vec_dist_cos_float =
{
    /* identity */
    .name       = "cosine",
    .type       = CASS_VEC_DIST_TYPE_COS,
    .vec_type   = CASS_VEC_FLOAT,

    /* lifecycle and persistence */
    .construct  = dist_cos_float_construct,
    .restore    = dist_cos_float_restore,
    .checkpoint = dist_simple_checkpoint,
    .describe   = dist_simple_describe,
    .free       = dist_simple_free,

    /* distance function */
    .dist       = dist_cos_float,
};
/* Generates dist_hamming_construct() and dist_hamming_restore(). */
DIST_SIMPLE_METHODS(hamming, vec_dist_hamming)

/*
 * Adapter around dist_hamming: forwards n and the two chunk arrays
 * unchanged and converts the result to cass_dist_t.
 */
cass_dist_t __dist_hamming (cass_size_t n, const chunk_t *c1, const chunk_t *c2)
{
    cass_dist_t result = (cass_dist_t) dist_hamming(n, c1, c2);

    return result;
}
/* Class descriptor binding the name "hamming" to its callbacks. */
cass_vec_dist_class_t vec_dist_hamming =
{
    /* identity */
    .name       = "hamming",
    .type       = CASS_VEC_DIST_TYPE_HAMMING,
    .vec_type   = CASS_VEC_BIT,

    /* lifecycle and persistence */
    .construct  = dist_hamming_construct,
    .restore    = dist_hamming_restore,
    .checkpoint = dist_simple_checkpoint,
    .describe   = dist_simple_describe,
    .free       = dist_simple_free,

    /* distance function */
    .dist       = __dist_hamming,
};
/*
 * SDIST_SIMPLE_METHODS(type, _class)
 *
 * Expands to two file-local functions for a vector-set distance that takes
 * no construction parameters:
 *
 *   sdist_<type>_construct(_dist, param)
 *       Allocates one cass_vecset_dist_t with type_calloc, increments its
 *       reference count, points its __class at _class, stores the object in
 *       *_dist and returns 0.  `param` is not used.
 *       NOTE(review): the result of type_calloc is not checked here; confirm
 *       whether type_calloc can return NULL.
 *
 *   sdist_<type>_restore(_dist, fin)
 *       Constructs an object as above, then reads its name from `fin` with
 *       cass_read_pchar.  Returns 0 on success.  A non-zero result from the
 *       constructor is passed straight back.  If the name cannot be read the
 *       object is freed, *_dist is reset to NULL so the caller never holds a
 *       pointer to freed memory, and CASS_ERR_IO is returned.
 */
#define SDIST_SIMPLE_METHODS(type,_class) \
static int sdist_##type##_construct (void **_dist, const char *param) \
{ \
    cass_vecset_dist_t *dist = type_calloc(cass_vecset_dist_t, 1); \
    dist->refcnt++; \
    dist->__class = &_class; \
    *_dist = dist; \
    return 0; \
} \
\
static int sdist_##type##_restore (void **_dist, CASS_FILE *fin) \
{ \
    cass_vecset_dist_t *dist; \
    int err = sdist_##type##_construct(_dist, NULL); \
    if (err != 0) return err; \
    dist = *_dist; \
    dist->name = cass_read_pchar(fin); \
    if (dist->name == NULL) \
    { \
        free(dist); \
        *_dist = NULL; /* do not hand freed memory back to the caller */ \
        return CASS_ERR_IO; \
    } \
    return 0; \
}
/*
 * Writes a two-line description of a vector-set distance to `fout`: first
 * the instance name, then the name stored in its class descriptor.
 *
 * This function is shared by every vector-set distance class in this file,
 * so the class line is read from dist->__class rather than being a fixed
 * string; each class therefore reports its own name.
 * Always returns 0.
 */
static int sdist_simple_describe (void *_dist, CASS_FILE *fout)
{
    cass_vecset_dist_t *dist = _dist;

    cass_printf(fout, "NAME:\t%s\n", dist->name);
    cass_printf(fout, "CLASS:\t%s\n", dist->__class->name);
    return 0;
}
/*
 * Checkpoints a vector-set distance by writing its name to `fout` with
 * cass_write_pchar; the result of that call is returned unchanged.
 */
static int sdist_simple_checkpoint (void *_dist, CASS_FILE *fout)
{
    const cass_vecset_dist_t *self = _dist;
    int status = cass_write_pchar(self->name, fout);

    return status;
}
/*
 * Drops one reference to a vector-set distance.  When the reference count
 * reaches zero the name string and the object itself are freed.
 */
static void sdist_simple_free (void *_dist)
{
    cass_vecset_dist_t *self = _dist;

    self->refcnt -= 1;
    if (self->refcnt != 0)
        return;

    free(self->name);
    free(self);
}
/* Generates sdist_trivial_construct() and sdist_trivial_restore(). */
SDIST_SIMPLE_METHODS(trivial, vecset_dist_trivial);

/*
 * Trivial vector-set distance: none of the arguments are inspected and the
 * reported distance is always zero.
 */
cass_dist_t sdist_trivial (cass_dataset_t *ds1, cass_vecset_id_t p1, cass_dataset_t *ds2, cass_vecset_id_t p2, cass_vec_dist_t *vec_dist, void *p)
{
    const cass_dist_t zero = 0.0;

    (void) ds1;
    (void) p1;
    (void) ds2;
    (void) p2;
    (void) vec_dist;
    (void) p;

    return zero;
}
/* Class descriptor binding the vector-set distance "trivial" to its callbacks. */
cass_vecset_dist_class_t vecset_dist_trivial =
{
    /* identity */
    .name        = "trivial",
    .type        = CASS_VECSET_DIST_TYPE_TRIVIAL,
    .vecset_type = CASS_ANY,

    /* lifecycle and persistence */
    .construct   = sdist_trivial_construct,
    .restore     = sdist_trivial_restore,
    .checkpoint  = sdist_simple_checkpoint,
    .describe    = sdist_simple_describe,
    .free        = sdist_simple_free,

    /* distance function */
    .dist        = sdist_trivial,
};
/* Generates sdist_single_construct() and sdist_single_restore(). */
SDIST_SIMPLE_METHODS(single, vecset_dist_single);

/*
 * Vector-set distance that compares only the first vector of each set.
 *
 * The first vector of set p1 in ds1 and of set p2 in ds2 is located at
 * byte offset vec_size * start_vecid from the dataset's vec pointer, and
 * the two are passed to the vector distance's dist callback together with
 * ds1->vec_dim.  `p` is not used.
 *
 * The offset is applied through a char pointer, the same byte-wise stepping
 * sdist_emd and sdist_myemd use, so the result does not depend on the
 * declared type of the vec member or on void-pointer arithmetic.
 *
 * Returns whatever the vector distance callback returns.
 */
cass_dist_t sdist_single (cass_dataset_t *ds1, cass_vecset_id_t p1, cass_dataset_t *ds2, cass_vecset_id_t p2, cass_vec_dist_t *vec_dist, void *p)
{
    cass_vec_t *v1, *v2;

    v1 = (cass_vec_t *)((char *)ds1->vec + ds1->vec_size * ds1->vecset[p1].start_vecid);
    v2 = (cass_vec_t *)((char *)ds2->vec + ds2->vec_size * ds2->vecset[p2].start_vecid);

    return vec_dist->__class->dist(ds1->vec_dim, v1->u.data, v2->u.data, vec_dist);
}
/* Class descriptor binding the vector-set distance "single" to its callbacks. */
cass_vecset_dist_class_t vecset_dist_single =
{
    /* identity */
    .name        = "single",
    .type        = CASS_VECSET_DIST_TYPE_SINGLE,
    .vecset_type = CASS_ANY,

    /* lifecycle and persistence */
    .construct   = sdist_single_construct,
    .restore     = sdist_single_restore,
    .checkpoint  = sdist_simple_checkpoint,
    .describe    = sdist_simple_describe,
    .free        = sdist_simple_free,

    /* distance function */
    .dist        = sdist_single,
};
/* Generates sdist_emd_construct() and sdist_emd_restore(). */
SDIST_SIMPLE_METHODS(emd, vecset_dist_emd);

/*
 * Vector-set distance computed by emd().
 *
 * For each of the two vector sets a signature_t is built on the stack: one
 * entry per region, holding the vector's float_data pointer as the feature
 * and the vector's weight as the weight.  Both signatures are then handed
 * to emd() with the vector distance's dist callback as ground distance.
 * `p` is not used.  Asserts that both datasets have the same vec_dim.
 */
cass_dist_t sdist_emd (cass_dataset_t *ds1, cass_vecset_id_t p1, cass_dataset_t *ds2, cass_vecset_id_t p2, cass_vec_dist_t *vec_dist, void *p)
{
    cass_vecset_t *set1, *set2;
    signature_t lhs, rhs;
    char *cursor;
    int k;

    assert(ds1->vec_dim == ds2->vec_dim);

    set1 = &ds1->vecset[p1];
    set2 = &ds2->vecset[p2];

    lhs.n = set1->num_regions;
    rhs.n = set2->num_regions;

    /* Stack storage for the per-region feature pointers and weights. */
    lhs.Features = alloca(lhs.n * sizeof *lhs.Features);
    lhs.Weights = alloca(lhs.n * sizeof *lhs.Weights);
    rhs.Features = alloca(rhs.n * sizeof *rhs.Features);
    rhs.Weights = alloca(rhs.n * sizeof *rhs.Weights);

    /* Vectors are vec_size bytes apart, so step through them byte-wise. */
    cursor = (char *)ds1->vec + ds1->vec_size * set1->start_vecid;
    for (k = 0; k < lhs.n; k++, cursor += ds1->vec_size)
    {
        cass_vec_t *v = (cass_vec_t *)cursor;

        lhs.Features[k] = v->u.float_data;
        lhs.Weights[k] = v->weight;
    }

    cursor = (char *)ds2->vec + ds2->vec_size * set2->start_vecid;
    for (k = 0; k < rhs.n; k++, cursor += ds2->vec_size)
    {
        cass_vec_t *v = (cass_vec_t *)cursor;

        rhs.Features[k] = v->u.float_data;
        rhs.Weights[k] = v->weight;
    }

    return emd(&lhs, &rhs, vec_dist->__class->dist, ds1->vec_dim, vec_dist, NULL, NULL);
}
/* Class descriptor binding the vector-set distance "emd" to its callbacks. */
cass_vecset_dist_class_t vecset_dist_emd =
{
    /* identity */
    .name        = "emd",
    .type        = CASS_VECSET_DIST_TYPE_EMD,
    .vecset_type = CASS_ANY,

    /* lifecycle and persistence */
    .construct   = sdist_emd_construct,
    .restore     = sdist_emd_restore,
    .checkpoint  = sdist_simple_checkpoint,
    .describe    = sdist_simple_describe,
    .free        = sdist_simple_free,

    /* distance function */
    .dist        = sdist_emd,
};
/* Generates sdist_myemd_construct() and sdist_myemd_restore(). */
SDIST_SIMPLE_METHODS(myemd, vecset_dist_myemd);

/* Defined elsewhere; presumably a transportation-problem solver (TODO confirm). */
extern float tp_solve (int nrow, float *row, int ncol, float *col, float **cost);

/* A weight surplus must exceed this before a slack row/column is added. */
#define ABS_ERR 0.0

/*
 * Vector-set distance computed with tp_solve().
 *
 * The weights of set p1 (ds1) become the row totals and the weights of set
 * p2 (ds2) the column totals; cost[i][j] is the vector distance between
 * region i of the first set and region j of the second.  If the two weight
 * sums differ, the surplus is placed in one extra slack row or column so
 * that both sides sum to the same total.  The value returned by tp_solve is
 * divided by the smaller of the two weight sums.
 *
 * The cost matrix is allocated with one spare row and one spare column for
 * that slack entry.  Those spare cells are set to zero explicitly, so the
 * slack carries no cost regardless of whether type_matrix_alloc clears the
 * memory it returns.
 *
 * `p` is not used.  Asserts that both datasets have the same vec_dim.
 * NOTE(review): if the smaller weight sum is zero the final division is by
 * zero; confirm whether callers can pass empty or zero-weight sets.
 */
cass_dist_t sdist_myemd (cass_dataset_t *ds1, cass_vecset_id_t p1, cass_dataset_t *ds2, cass_vecset_id_t p2, cass_vec_dist_t *vec_dist, void *p)
{
    cass_vecset_t *vecset1, *vecset2;
    cass_vec_t *vec1, *vec2;
    float *row, *col;
    float **cost;
    float d, srow, scol, ss;
    int nrow, ncol;
    int i, j;

    assert(ds1->vec_dim == ds2->vec_dim);

    vecset1 = &ds1->vecset[p1];
    vecset2 = &ds2->vecset[p2];
    nrow = vecset1->num_regions;
    ncol = vecset2->num_regions;

    /* One spare slot each for a possible slack entry. */
    row = (float *)alloca((nrow + 1) * sizeof (float));
    col = (float *)alloca((ncol + 1) * sizeof (float));
    srow = scol = 0;

    /* Collect the weights; vectors are vec_size bytes apart. */
    vec1 = (cass_vec_t *)((char *)ds1->vec + ds1->vec_size * vecset1->start_vecid);
    for (i = 0; i < nrow; i++)
    {
        row[i] = vec1->weight;
        srow += row[i];
        vec1 = (cass_vec_t *)((char *)vec1 + ds1->vec_size);
    }

    vec2 = (cass_vec_t *)((char *)ds2->vec + ds2->vec_size * vecset2->start_vecid);
    for (i = 0; i < ncol; i++)
    {
        col[i] = vec2->weight;
        scol += col[i];
        vec2 = (cass_vec_t *)((char *)vec2 + ds2->vec_size);
    }

    cost = type_matrix_alloc(float, nrow + 1, ncol + 1);

    /* Zero the spare column and spare row used by the slack entry. */
    for (i = 0; i <= nrow; i++)
    {
        cost[i][ncol] = 0;
    }
    for (j = 0; j < ncol; j++)
    {
        cost[nrow][j] = 0;
    }

    /* Pairwise ground distances between the regions of the two sets. */
    vec1 = (cass_vec_t *)((char *)ds1->vec + ds1->vec_size * vecset1->start_vecid);
    for (i = 0; i < nrow; i++)
    {
        vec2 = (cass_vec_t *)((char *)ds2->vec + ds2->vec_size * vecset2->start_vecid);
        for (j = 0; j < ncol; j++)
        {
            cost[i][j] = vec_dist->__class->dist(ds1->vec_dim, vec1->u.float_data, vec2->u.float_data, vec_dist);
            vec2 = (cass_vec_t *)((char *)vec2 + ds2->vec_size);
        }
        vec1 = (cass_vec_t *)((char *)vec1 + ds1->vec_size);
    }

    /* Balance the totals: the heavier side's surplus goes to a slack entry. */
    if (srow > scol)
    {
        ss = scol;
        col[ncol] = srow - scol;
        if (col[ncol] > ABS_ERR) ncol++;
    }
    else
    {
        ss = srow;
        row[nrow] = scol - srow;
        if (row[nrow] > ABS_ERR) nrow++;
    }

    d = tp_solve(nrow, row, ncol, col, cost);
    matrix_free(cost);

    return d / ss;
}
/* Class descriptor binding the vector-set distance "myemd" to its callbacks. */
cass_vecset_dist_class_t vecset_dist_myemd =
{
    /* identity */
    .name        = "myemd",
    .type        = CASS_VECSET_DIST_TYPE_EMD,
    .vecset_type = CASS_ANY,

    /* lifecycle and persistence */
    .construct   = sdist_myemd_construct,
    .restore     = sdist_myemd_restore,
    .checkpoint  = sdist_simple_checkpoint,
    .describe    = sdist_simple_describe,
    .free        = sdist_simple_free,

    /* distance function */
    .dist        = sdist_myemd,
};