Commented npz_converter.h

This commit is contained in:
Lane Schwartz 2016-09-18 16:25:30 +02:00
parent 04fb8734a2
commit cac6318818

View File

@ -3,29 +3,86 @@
#include "cnpy/cnpy.h" #include "cnpy/cnpy.h"
#include "tensor.h" #include "tensor.h"
/**
* Loads model data stored in a npz file,
* enabling it to later be stored in standard Marian data structures.
*
* Note: this class makes use of the 3rd-party class <code>npy</code>.
*/
class NpzConverter { class NpzConverter {
// Private inner classes of the NpzConverter class
private: private:
/**
* Wraps npy data such that the underlying matrix shape and
* matrix data are made accessible.
*/
class NpyMatrixWrapper { class NpyMatrixWrapper {
public: public:
/**
* Constructs a wrapper around an underlying npy data structure,
* enabling the underlying data to be accessed as a matrix.
*
* @param npy the underlying data
*/
NpyMatrixWrapper(const cnpy::NpyArray& npy) NpyMatrixWrapper(const cnpy::NpyArray& npy)
: npy_(npy) {} : npy_(npy) {}
/**
* Returns the total number of elements in the underlying matrix.
*
* @return the total number of elements in the underlying matrix
*/
size_t size() const { size_t size() const {
return size1() * size2(); return size1() * size2();
} }
/**
* Returns a pointer to the raw data that underlies the matrix.
*
* @return a pointer to the raw data that underlies the matrix
*/
float* data() const { float* data() const {
return (float*)npy_.data; return (float*)npy_.data;
} }
/**
* Given the index (i, j) of a matrix element,
* this operator returns the float value from the underlying npz data
* that is stored in the matrix.
*
* XXX: Marcin, is the following correct? Or do I have the row/column labels swapped?
*
* @param i Index of a column in the matrix
* @param j Index of a row in the matrix
*
* @return the float value stored at column i, row j of the matrix
*/
float operator()(size_t i, size_t j) const { float operator()(size_t i, size_t j) const {
return ((float*)npy_.data)[i * size2() + j]; return ((float*)npy_.data)[i * size2() + j];
} }
/**
* Returns the number of columns in the matrix.
*
* XXX: Marcin, is this following correct? Or do I have the row/column labels swapped?
*
* @return the number of columns in the matrix
*/
size_t size1() const { size_t size1() const {
return npy_.shape[0]; return npy_.shape[0];
} }
/**
* Returns the number of rows in the matrix.
*
* XXX: Marcin, is this following correct? Or do I have the row/column labels swapped?
*
* @return the number of rows in the matrix
*/
size_t size2() const { size_t size2() const {
if(npy_.shape.size() == 1) if(npy_.shape.size() == 1)
return 1; return 1;
@ -34,25 +91,50 @@ class NpzConverter {
} }
private: private:
const cnpy::NpyArray& npy_;
};
/** Instance of the underlying (3rd party) data structure. */
const cnpy::NpyArray& npy_;
}; // End of NpyMatrixWrapper class
// Public methods of the NpzConverter class
public: public:
/**
* Constructs an object that reads npz data from a file.
*
* @param file Path to file containing npz data
*/
NpzConverter(const std::string& file) NpzConverter(const std::string& file)
: model_(cnpy::npz_load(file)), : model_(cnpy::npz_load(file)),
destructed_(false) { destructed_(false) {
} }
/**
* Destructs the model that underlies this NpzConverter object,
* if that data has not already been destructed.
*/
~NpzConverter() { ~NpzConverter() {
if(!destructed_) if(!destructed_)
model_.destruct(); model_.destruct();
} }
/**
* Destructs the model that underlies this NpzConverter object,
* and marks that data as having been destructed.
*/
void Destruct() { void Destruct() {
model_.destruct(); model_.destruct();
destructed_ = true; destructed_ = true;
} }
/**
* Loads data corresponding to a search key into the provided vector.
*
* @param key Search key XXX Marcin, what type of thing is "key"? What are we searching for here?
* @param data Container into which data will be loaded XXX Lane, is there a way in Doxygen to mark and inout variable?
* @param shape Shape object into which the number of rows and columns of the vectors will be stored
*/
void Load(const std::string& key, std::vector<float>& data, marian::Shape& shape) const { void Load(const std::string& key, std::vector<float>& data, marian::Shape& shape) const {
auto it = model_.find(key); auto it = model_.find(key);
if(it != model_.end()) { if(it != model_.end()) {
@ -71,7 +153,13 @@ class NpzConverter {
} }
} }
// Private member data of the NpzConverter class
private: private:
/** Underlying npz data */
cnpy::npz_t model_; cnpy::npz_t model_;
/** Indicates whether the underlying data has been destructed. */
bool destructed_; bool destructed_;
};
}; // End of NpzConverter class