mirror of
https://github.com/marian-nmt/marian.git
synced 2024-09-17 09:47:34 +03:00
Commented npz_converter.h
This commit is contained in:
parent
04fb8734a2
commit
cac6318818
@ -3,29 +3,86 @@
|
|||||||
#include "cnpy/cnpy.h"
|
#include "cnpy/cnpy.h"
|
||||||
#include "tensor.h"
|
#include "tensor.h"
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Loads model data stored in a npz file,
|
||||||
|
* enabling it to later be stored in standard Marian data structures.
|
||||||
|
*
|
||||||
|
* Note: this class makes use of the 3rd-party class <code>npy</code>.
|
||||||
|
*/
|
||||||
class NpzConverter {
|
class NpzConverter {
|
||||||
|
|
||||||
|
// Private inner classes of the NpzConverter class
|
||||||
private:
|
private:
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Wraps npy data such that the underlying matrix shape and
|
||||||
|
* matrix data are made accessible.
|
||||||
|
*/
|
||||||
class NpyMatrixWrapper {
|
class NpyMatrixWrapper {
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Constructs a wrapper around an underlying npy data structure,
|
||||||
|
* enabling the underlying data to be accessed as a matrix.
|
||||||
|
*
|
||||||
|
* @param npy the underlying data
|
||||||
|
*/
|
||||||
NpyMatrixWrapper(const cnpy::NpyArray& npy)
|
NpyMatrixWrapper(const cnpy::NpyArray& npy)
|
||||||
: npy_(npy) {}
|
: npy_(npy) {}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns the total number of elements in the underlying matrix.
|
||||||
|
*
|
||||||
|
* @return the total number of elements in the underlying matrix
|
||||||
|
*/
|
||||||
size_t size() const {
|
size_t size() const {
|
||||||
return size1() * size2();
|
return size1() * size2();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns a pointer to the raw data that underlies the matrix.
|
||||||
|
*
|
||||||
|
* @return a pointer to the raw data that underlies the matrix
|
||||||
|
*/
|
||||||
float* data() const {
|
float* data() const {
|
||||||
return (float*)npy_.data;
|
return (float*)npy_.data;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Given the index (i, j) of a matrix element,
|
||||||
|
* this operator returns the float value from the underlying npz data
|
||||||
|
* that is stored in the matrix.
|
||||||
|
*
|
||||||
|
* XXX: Marcin, is the following correct? Or do I have the row/column labels swapped?
|
||||||
|
*
|
||||||
|
* @param i Index of a column in the matrix
|
||||||
|
* @param j Index of a row in the matrix
|
||||||
|
*
|
||||||
|
* @return the float value stored at column i, row j of the matrix
|
||||||
|
*/
|
||||||
float operator()(size_t i, size_t j) const {
|
float operator()(size_t i, size_t j) const {
|
||||||
return ((float*)npy_.data)[i * size2() + j];
|
return ((float*)npy_.data)[i * size2() + j];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns the number of columns in the matrix.
|
||||||
|
*
|
||||||
|
* XXX: Marcin, is this following correct? Or do I have the row/column labels swapped?
|
||||||
|
*
|
||||||
|
* @return the number of columns in the matrix
|
||||||
|
*/
|
||||||
size_t size1() const {
|
size_t size1() const {
|
||||||
return npy_.shape[0];
|
return npy_.shape[0];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns the number of rows in the matrix.
|
||||||
|
*
|
||||||
|
* XXX: Marcin, is this following correct? Or do I have the row/column labels swapped?
|
||||||
|
*
|
||||||
|
* @return the number of rows in the matrix
|
||||||
|
*/
|
||||||
size_t size2() const {
|
size_t size2() const {
|
||||||
if(npy_.shape.size() == 1)
|
if(npy_.shape.size() == 1)
|
||||||
return 1;
|
return 1;
|
||||||
@ -34,25 +91,50 @@ class NpzConverter {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
const cnpy::NpyArray& npy_;
|
|
||||||
};
|
|
||||||
|
|
||||||
|
/** Instance of the underlying (3rd party) data structure. */
|
||||||
|
const cnpy::NpyArray& npy_;
|
||||||
|
|
||||||
|
}; // End of NpyMatrixWrapper class
|
||||||
|
|
||||||
|
// Public methods of the NpzConverter class
|
||||||
public:
|
public:
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Constructs an object that reads npz data from a file.
|
||||||
|
*
|
||||||
|
* @param file Path to file containing npz data
|
||||||
|
*/
|
||||||
NpzConverter(const std::string& file)
|
NpzConverter(const std::string& file)
|
||||||
: model_(cnpy::npz_load(file)),
|
: model_(cnpy::npz_load(file)),
|
||||||
destructed_(false) {
|
destructed_(false) {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Destructs the model that underlies this NpzConverter object,
|
||||||
|
* if that data has not already been destructed.
|
||||||
|
*/
|
||||||
~NpzConverter() {
|
~NpzConverter() {
|
||||||
if(!destructed_)
|
if(!destructed_)
|
||||||
model_.destruct();
|
model_.destruct();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Destructs the model that underlies this NpzConverter object,
|
||||||
|
* and marks that data as having been destructed.
|
||||||
|
*/
|
||||||
void Destruct() {
|
void Destruct() {
|
||||||
model_.destruct();
|
model_.destruct();
|
||||||
destructed_ = true;
|
destructed_ = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Loads data corresponding to a search key into the provided vector.
|
||||||
|
*
|
||||||
|
* @param key Search key XXX Marcin, what type of thing is "key"? What are we searching for here?
|
||||||
|
* @param data Container into which data will be loaded XXX Lane, is there a way in Doxygen to mark and inout variable?
|
||||||
|
* @param shape Shape object into which the number of rows and columns of the vectors will be stored
|
||||||
|
*/
|
||||||
void Load(const std::string& key, std::vector<float>& data, marian::Shape& shape) const {
|
void Load(const std::string& key, std::vector<float>& data, marian::Shape& shape) const {
|
||||||
auto it = model_.find(key);
|
auto it = model_.find(key);
|
||||||
if(it != model_.end()) {
|
if(it != model_.end()) {
|
||||||
@ -71,7 +153,13 @@ class NpzConverter {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Private member data of the NpzConverter class
|
||||||
private:
|
private:
|
||||||
|
|
||||||
|
/** Underlying npz data */
|
||||||
cnpy::npz_t model_;
|
cnpy::npz_t model_;
|
||||||
|
|
||||||
|
/** Indicates whether the underlying data has been destructed. */
|
||||||
bool destructed_;
|
bool destructed_;
|
||||||
};
|
|
||||||
|
}; // End of NpzConverter class
|
||||||
|
Loading…
Reference in New Issue
Block a user