mirror of
https://github.com/marian-nmt/marian.git
synced 2024-11-30 21:39:52 +03:00
more work on backends
This commit is contained in:
parent
2d923f8017
commit
8dd0b6d3a8
54
src/decoder/backend_gpu.h
Normal file
54
src/decoder/backend_gpu.h
Normal file
@ -0,0 +1,54 @@
|
||||
#pragma once
|
||||
|
||||
#include "matrix.h"
|
||||
|
||||
class BackendGPU /* : public Backend */ {
|
||||
public:
|
||||
template <typename T>
|
||||
using DeviceVector = thrust::device_vector<T>;
|
||||
|
||||
template <typename T>
|
||||
using HostVector = thrust::host_vector<T>;
|
||||
|
||||
class Payload : public PayloadBase {
|
||||
public:
|
||||
Payload() {}
|
||||
|
||||
mblas::Matrix& operator*() {
|
||||
return matrix_;
|
||||
}
|
||||
|
||||
private:
|
||||
mblas::Matrix matrix_;
|
||||
};
|
||||
|
||||
|
||||
static void PartialSortByKey(Payload& probs,
|
||||
HostVector<unsigned>& bestKeys,
|
||||
HostVector<float>& bestCosts) {
|
||||
size_t beamSize = bestKeys.size();
|
||||
if(beamSize < 10) {
|
||||
for(size_t i = 0; i < beamSize; ++i) {
|
||||
DeviceVector<float>::iterator iter =
|
||||
thrust::max_element((*probs).begin(), (*probs).end());
|
||||
bestKeys[i] = iter - (*probs).begin();
|
||||
bestCosts[i] = *iter;
|
||||
*iter = std::numeric_limits<float>::lowest();
|
||||
}
|
||||
}
|
||||
else {
|
||||
DeviceVector<unsigned> keys((*probs).size());
|
||||
thrust::sequence(keys.begin(), keys.end());
|
||||
thrust::sort_by_key((*probs).begin(), (*probs).end(),
|
||||
keys.begin(), thrust::greater<float>());
|
||||
|
||||
thrust::copy_n(keys.begin(), beamSize, bestKeys.begin());
|
||||
thrust::copy_n((*probs).begin(), beamSize, bestCosts.begin());
|
||||
}
|
||||
}
|
||||
|
||||
template <class It1, class It2>
|
||||
static void copy(It1 begin, It1 end, It2 out) {
|
||||
thrust::copy(begin, end, out);
|
||||
}
|
||||
};
|
17
src/decoder/sentence.cpp
Normal file
17
src/decoder/sentence.cpp
Normal file
@ -0,0 +1,17 @@
|
||||
#include "sentence.h"
|
||||
#include "god.h"
|
||||
|
||||
Sentence::Sentence(size_t lineNo, const std::string& line)
|
||||
: lineNo_(lineNo), line_(line)
|
||||
{
|
||||
words_.push_back(God::GetSourceVocab(0)(line));
|
||||
}
|
||||
|
||||
const Words& Sentence::GetWords(size_t index) const {
|
||||
return words_[index];
|
||||
}
|
||||
|
||||
size_t Sentence::GetLine() const {
|
||||
return lineNo_;
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user