blob: ab5e58d0af4e9345bb930abfa07f3bbaebf00e1e [file] [log] [blame]
/* AUTORIGHTS
Copyright (C) 2007 Princeton University
This file is part of Ferret Toolkit.
Ferret Toolkit is free software; you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation; either version 2, or (at your option)
any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program; if not, write to the Free Software Foundation,
Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
*/
#include <string.h>
#include <cass.h>
#define MEM_OVERHEAD 3/2
static inline int cass_read_vecset (cass_vecset_t *vecset, size_t nmemb, CASS_FILE *in) {
int n = fread(vecset, sizeof(cass_vecset_t), nmemb, in);
if (!isLittleEndian()) {
int i;
assert(sizeof(vecset[0].start_vecid) == sizeof(uint32_t));
for (i = 0; i < n; i++) {
vecset[i].num_regions = bswap_int32(vecset[i].num_regions);
vecset[i].start_vecid = bswap_int32(vecset[i].start_vecid);
}
}
return n;
}
int cass_dataset_init (cass_dataset_t *ds, cass_size_t vec_size, cass_size_t vec_dim, uint32_t flags)
{
memset(ds, 0, sizeof *ds);
ds->vec_size = vec_size;
ds->vec_dim = vec_dim;
ds->flags = flags;
return 0;
}
int cass_dataset_release (cass_dataset_t *ds)
{
if (ds->vec != NULL) free(ds->vec);
if (ds->vec != NULL) free(ds->vecset);
ds->vec = NULL;
ds->vecset = NULL;
ds->max_vec = ds->max_vecset = 0;
ds->loaded = 0;
return 0;
}
#define MEM_1G (1024*1024*1024)
cass_size_t grow (cass_size_t from, cass_size_t to)
{
if (from >= to) return from;
if (from == 0) from++;
while ((from < to) && (from < MEM_1G)) from *= 2;
while (from < to) from += MEM_1G / 2;
return from;
}
int cass_dataset_grow (cass_dataset_t *ds, cass_size_t num_vecset, cass_size_t num_vec)
{
if ((ds->flags & CASS_DATASET_VEC) && (ds->max_vec < num_vec))
{
ds->max_vec = grow(ds->max_vec, num_vec);
ds->vec = (cass_vec_t *)realloc(ds->vec, ds->vec_size * ds->max_vec);
if (ds->vec == NULL) return CASS_ERR_OUTOFMEM;
}
if ((ds->flags & CASS_DATASET_VECSET) && (ds->max_vecset < num_vecset))
{
ds->max_vecset = grow(ds->max_vecset, num_vecset);
ds->vecset = (cass_vecset_t *)realloc(ds->vecset, sizeof (cass_vecset_t) * ds->max_vecset);
if (ds->vecset == NULL) return CASS_ERR_OUTOFMEM;
}
return 0;
}
/* if src->flags & CASS_DATASET_VECSET == 0 then start & num are for vectors,
* otherwise they are for vecsets. */
int cass_dataset_merge (cass_dataset_t *ds, const cass_dataset_t *src, cass_vecset_id_t start, cass_vecset_id_t num, cass_dataset_map_t map, void *map_param)
{
void *vec;
void *src_vec;
int i;
cass_vec_id_t start_vec, num_vec;
int parent_delta = 0;
start_vec = num_vec = 0;
assert(ds->loaded);
assert(src->loaded);
if (ds->flags & CASS_DATASET_VECSET)
/* copy the vecset data */
{
assert(src->flags & CASS_DATASET_VECSET);
parent_delta = ds->num_vecset - start;
if (ds->num_vecset + num > ds->max_vecset)
{
ds->max_vecset = grow(ds->max_vecset, ds->num_vecset + num);
ds->vecset = (cass_vecset_t *)realloc(ds->vecset, sizeof (cass_vecset_t) * ds->max_vecset);
if (ds->vecset == NULL) return CASS_ERR_OUTOFMEM;
}
num_vec = 0;
for (i = 0; i < num; i++)
{
ds->vecset[ds->num_vecset].num_regions = src->vecset[start + i].num_regions;
ds->vecset[ds->num_vecset].start_vecid = ds->num_vec + num_vec;
ds->num_vecset++;
num_vec += src->vecset[start + i].num_regions;
}
}
else
{
ds->num_vecset += num;
num_vec = 0;
for (i = 0; i < num; i++) num_vec += src->vecset[start + i].num_regions;
}
/* deal with vecs */
if (ds->flags & CASS_DATASET_VEC)
{
if (src->flags & CASS_DATASET_VECSET)
{
start_vec = src->vecset[start].start_vecid;
}
else
{
num_vec = num;
start_vec = start;
parent_delta = 0;
}
if (ds->num_vec + num_vec > ds->max_vec)
{
ds->max_vec = grow(ds->max_vec, ds->num_vec + num_vec);
ds->vec = (cass_vec_t *)realloc(ds->vec, ds->vec_size * ds->max_vec);
if (ds->vec == NULL) return CASS_ERR_OUTOFMEM;
}
vec = ds->vec + ds->num_vec * ds->vec_size;
src_vec = src->vec + start_vec * src->vec_size;
if (map == NULL)
{
map = memcpy;
map_param = (void *)ds->vec_size;
}
for (i = 0; i < num_vec; i++)
{
((cass_vec_t *)vec)->weight = ((cass_vec_t *)src_vec)->weight;
((cass_vec_t *)vec)->parent = ((cass_vec_t *)src_vec)->parent + parent_delta;
map(((cass_vec_t *)vec)->u.data, ((cass_vec_t *)src_vec)->u.data, map_param);
vec += ds->vec_size;
src_vec += src->vec_size;
ds->num_vec++;
}
}
else
{
ds->num_vec += num_vec;
}
return 0;
}
int cass_dataset_checkpoint (cass_dataset_t *ds, CASS_FILE *out)
{
if (cass_write_uint32(&ds->flags, 1, out) != 1) return CASS_ERR_IO;
if (cass_write_size(&ds->vec_size, 1, out) != 1) return CASS_ERR_IO;
if (cass_write_size(&ds->vec_dim, 1, out) != 1) return CASS_ERR_IO;
if (cass_write_size(&ds->num_vec, 1, out) != 1) return CASS_ERR_IO;
if (cass_write_size(&ds->num_vecset, 1, out) != 1) return CASS_ERR_IO;
return 0;
}
int cass_dataset_restore (cass_dataset_t *ds, CASS_FILE *in)
{
memset(ds, 0, sizeof *ds);
if (cass_read_uint32(&ds->flags, 1, in) != 1) return CASS_ERR_IO;
if (cass_read_size(&ds->vec_size, 1, in) != 1) return CASS_ERR_IO;
if (cass_read_size(&ds->vec_dim, 1, in) != 1) return CASS_ERR_IO;
if (cass_read_size(&ds->num_vec, 1, in) != 1) return CASS_ERR_IO;
if (cass_read_size(&ds->num_vecset, 1, in) != 1) return CASS_ERR_IO;
return 0;
}
int cass_dataset_load (cass_dataset_t *ds, CASS_FILE *in, cass_vec_type_t vec_type)
{
int ret;
assert(!ds->loaded);
assert(ds->vec == NULL);
assert(ds->vecset == NULL);
if (ds->flags & CASS_DATASET_VECSET)
{
ret = CASS_ERR_OUTOFMEM;
ds->max_vecset = ds->num_vecset * MEM_OVERHEAD;
ds->vecset = (cass_vecset_t *)malloc(sizeof(cass_vecset_t) * ds->max_vecset);
if (ds->vecset == NULL) goto err;
ret = CASS_ERR_IO;
if (cass_read_vecset(ds->vecset, ds->num_vecset, in) != ds->num_vecset) goto err;
}
if (ds->flags & CASS_DATASET_VEC)
{
ds->max_vec = ds->num_vec;// * MEM_OVERHEAD;
ret = CASS_ERR_OUTOFMEM;
ds->vec = (cass_vec_t *)malloc(ds->vec_size * ds->max_vec);
if (ds->vec == NULL) goto err;
ret = CASS_ERR_IO;
if (isLittleEndian()) {
if (cass_read(ds->vec, ds->vec_size, ds->num_vec, in) != ds->num_vec) goto err;
}
else {
if (vec_type == CASS_VEC_INT) {
unsigned cnt = ds->vec_size / sizeof(int32_t) * ds->num_vec;
if (cass_read_int32(ds->vec, cnt, in) != cnt) goto err;
}
else if (vec_type == CASS_VEC_FLOAT) {
unsigned cnt = ds->vec_size / sizeof(float) * ds->num_vec;
if (cass_read_float(ds->vec, cnt, in) != cnt) goto err;
}
else if (vec_type == CASS_VEC_BIT) {
if (cass_read(ds->vec, ds->vec_size, ds->num_vec, in) != ds->num_vec) goto err;
}
else assert(0);
}
}
ds->loaded = 1;
return 0;
err:
cass_dataset_release(ds);
return ret;
}
int cass_dataset_dump (cass_dataset_t *ds, CASS_FILE *out)
{
assert(ds->loaded);
if (ds->flags & CASS_DATASET_VECSET)
{
if (cass_write(ds->vecset, sizeof (cass_vecset_t), ds->num_vecset, out) != ds->num_vecset) return CASS_ERR_IO;
}
if (ds->flags & CASS_DATASET_VEC)
{
if (cass_write(ds->vec, ds->vec_size, ds->num_vec, out) != ds->num_vec) return CASS_ERR_IO;
}
return 0;
}
cass_vecset_id_t cass_dataset_vec2vecset (cass_dataset_t *ds, cass_vec_id_t id)
{
cass_vecset_id_t l, r, m;
l = 0;
r = ds->num_vecset - 1;
assert(ds->flags & CASS_DATASET_VECSET);
assert(id < ds->num_vec);
if (id >= ds->vecset[r].start_vecid) return r;
for (;;)
{
if (r == l + 1) return l;
m = (l + r) / 2;
if (id < ds->vecset[m].start_vecid) r = m;
if (id >= ds->vecset[m].start_vecid) l = m;
}
assert(0);
}