mirror of
https://github.com/marian-nmt/marian.git
synced 2024-12-01 05:50:03 +03:00
optimization
This commit is contained in:
parent
8dd0b6d3a8
commit
21e7c774c3
@ -2,8 +2,8 @@ cmake_minimum_required(VERSION 3.5.1)
|
|||||||
set(CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake)
|
set(CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake)
|
||||||
|
|
||||||
project(amunn CXX)
|
project(amunn CXX)
|
||||||
SET(CMAKE_CXX_FLAGS " -std=c++11 -g -O0 -funroll-loops -Wno-unused-result -Wno-deprecated")
|
SET(CMAKE_CXX_FLAGS " -std=c++11 -g -O3 -funroll-loops -Wno-unused-result -Wno-deprecated")
|
||||||
LIST(APPEND CUDA_NVCC_FLAGS --default-stream per-thread; -std=c++11; -g; -O0; -arch=sm_35; -lineinfo; --use_fast_math;)
|
LIST(APPEND CUDA_NVCC_FLAGS --default-stream per-thread; -std=c++11; -g; -O3; -arch=sm_35; -lineinfo; --use_fast_math;)
|
||||||
add_definitions(-DCUDA_API_PER_THREAD_DEFAULT_STREAM)
|
add_definitions(-DCUDA_API_PER_THREAD_DEFAULT_STREAM)
|
||||||
SET(CUDA_PROPAGATE_HOST_FLAGS OFF)
|
SET(CUDA_PROPAGATE_HOST_FLAGS OFF)
|
||||||
|
|
||||||
|
@ -13,6 +13,12 @@ class Search {
|
|||||||
|
|
||||||
using Matrix = typename Backend::Payload;
|
using Matrix = typename Backend::Payload;
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
using DeviceVector = typename Backend::DeviceVector<T>;
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
using HostVector = typename Backend::HostVector<T>;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
Search(size_t threadId)
|
Search(size_t threadId)
|
||||||
: scorers_(God::GetScorers(threadId)) {}
|
: scorers_(God::GetScorers(threadId)) {}
|
||||||
@ -90,30 +96,31 @@ class Search {
|
|||||||
|
|
||||||
Matrix& probs = probsEnsemble[0];
|
Matrix& probs = probsEnsemble[0];
|
||||||
|
|
||||||
Matrix costs(probs.Rows(), 1);
|
Matrix costs;
|
||||||
|
(*costs).Resize((*probs).Rows(), 1);
|
||||||
HostVector<float> vCosts;
|
HostVector<float> vCosts;
|
||||||
for(auto& h : prevHyps)
|
for(auto& h : prevHyps)
|
||||||
vCosts.push_back(h->GetCost());
|
vCosts.push_back(h->GetCost());
|
||||||
Backend::copy(vCosts.begin(), vCosts.end(), costs.begin());
|
Backend::copy(vCosts.begin(), vCosts.end(), costs.begin());
|
||||||
|
|
||||||
Backend::BroadcastVecColumn(weights[0] * Backend::_1 + Backend::_2,
|
Backend::Broadcast(weights[0] * Backend::_1 + Backend::_2,
|
||||||
probs, costs);
|
probs, costs);
|
||||||
for(size_t i = 0; i < probsEnsemble.size(); ++i)
|
for(size_t i = 0; i < probsEnsemble.size(); ++i)
|
||||||
Backend::Element(Backend::_1 + weights[i] * Backend::_2,
|
Backend::Element(Backend::_1 + weights[i] * Backend::_2,
|
||||||
probs, probsEnsemble[i]);
|
probs, probsEnsemble[i]);
|
||||||
|
|
||||||
Backend::HostVector<unsigned> bestKeys(beamSize);
|
HostVector<unsigned> bestKeys(beamSize);
|
||||||
Backend::HostVector<float> bestCosts(beamSize);
|
HostVector<float> bestCosts(beamSize);
|
||||||
|
|
||||||
Backend::PartialSortByKey(probs, bestKeys, bestCosts);
|
Backend::PartialSortByKey(probs, bestKeys, bestCosts);
|
||||||
|
|
||||||
std::vector<Backend::HostVector<float>> breakDowns;
|
std::vector<HostVector<float>> breakDowns;
|
||||||
bool doBreakdown = God::Get<bool>("n-best");
|
bool doBreakdown = God::Get<bool>("n-best");
|
||||||
if(doBreakdown) {
|
if(doBreakdown) {
|
||||||
breakDowns.push_back(bestCosts);
|
breakDowns.push_back(bestCosts);
|
||||||
for(size_t i = 1; i < probsEnsemble.size(); ++i) {
|
for(size_t i = 1; i < probsEnsemble.size(); ++i) {
|
||||||
HostVector<float> modelCosts(beamSize);
|
HostVector<float> modelCosts(beamSize);
|
||||||
auto it = Backend::make_permutation_iterator(probsEnsemble[i].begin(), keys.begin());
|
auto it = Backend::make_permutation_iterator(probsEnsemble[i].begin(), bestKeys.begin());
|
||||||
Backend::copy(it, it + beamSize, modelCosts.begin());
|
Backend::copy(it, it + beamSize, modelCosts.begin());
|
||||||
breakDowns.push_back(modelCosts);
|
breakDowns.push_back(modelCosts);
|
||||||
}
|
}
|
||||||
@ -136,7 +143,7 @@ class Search {
|
|||||||
float cost = 0;
|
float cost = 0;
|
||||||
if(j < probsEnsemble.size()) {
|
if(j < probsEnsemble.size()) {
|
||||||
if(prevHyps[hypIndex]->GetCostBreakdown().size() < probsEnsemble.size())
|
if(prevHyps[hypIndex]->GetCostBreakdown().size() < probsEnsemble.size())
|
||||||
const_cast<HypothesisPtr&>(prevHyps[hypIndex])->GetCostBreakdown().resize(ProbsEnsemble.size(), 0.0);
|
const_cast<HypothesisPtr&>(prevHyps[hypIndex])->GetCostBreakdown().resize(probsEnsemble.size(), 0.0);
|
||||||
cost = breakDowns[j][i] + const_cast<HypothesisPtr&>(prevHyps[hypIndex])->GetCostBreakdown()[j];
|
cost = breakDowns[j][i] + const_cast<HypothesisPtr&>(prevHyps[hypIndex])->GetCostBreakdown()[j];
|
||||||
}
|
}
|
||||||
sum += weights[j] * cost;
|
sum += weights[j] * cost;
|
||||||
|
Loading…
Reference in New Issue
Block a user