optimization

This commit is contained in:
Marcin Junczys-Dowmunt 2016-04-26 16:56:05 +02:00
parent 8dd0b6d3a8
commit 21e7c774c3
2 changed files with 17 additions and 10 deletions

View File

@ -2,8 +2,8 @@ cmake_minimum_required(VERSION 3.5.1)
set(CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake) set(CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake)
project(amunn CXX) project(amunn CXX)
SET(CMAKE_CXX_FLAGS " -std=c++11 -g -O0 -funroll-loops -Wno-unused-result -Wno-deprecated") SET(CMAKE_CXX_FLAGS " -std=c++11 -g -O3 -funroll-loops -Wno-unused-result -Wno-deprecated")
LIST(APPEND CUDA_NVCC_FLAGS --default-stream per-thread; -std=c++11; -g; -O0; -arch=sm_35; -lineinfo; --use_fast_math;) LIST(APPEND CUDA_NVCC_FLAGS --default-stream per-thread; -std=c++11; -g; -O3; -arch=sm_35; -lineinfo; --use_fast_math;)
add_definitions(-DCUDA_API_PER_THREAD_DEFAULT_STREAM) add_definitions(-DCUDA_API_PER_THREAD_DEFAULT_STREAM)
SET(CUDA_PROPAGATE_HOST_FLAGS OFF) SET(CUDA_PROPAGATE_HOST_FLAGS OFF)

View File

@ -13,6 +13,12 @@ class Search {
using Matrix = typename Backend::Payload; using Matrix = typename Backend::Payload;
template <typename T>
using DeviceVector = typename Backend::DeviceVector<T>;
template <typename T>
using HostVector = typename Backend::HostVector<T>;
public: public:
Search(size_t threadId) Search(size_t threadId)
: scorers_(God::GetScorers(threadId)) {} : scorers_(God::GetScorers(threadId)) {}
@ -90,30 +96,31 @@ class Search {
Matrix& probs = probsEnsemble[0]; Matrix& probs = probsEnsemble[0];
Matrix costs(probs.Rows(), 1); Matrix costs;
(*costs).Resize((*probs).Rows(), 1);
HostVector<float> vCosts; HostVector<float> vCosts;
for(auto& h : prevHyps) for(auto& h : prevHyps)
vCosts.push_back(h->GetCost()); vCosts.push_back(h->GetCost());
Backend::copy(vCosts.begin(), vCosts.end(), costs.begin()); Backend::copy(vCosts.begin(), vCosts.end(), costs.begin());
Backend::BroadcastVecColumn(weights[0] * Backend::_1 + Backend::_2, Backend::Broadcast(weights[0] * Backend::_1 + Backend::_2,
probs, costs); probs, costs);
for(size_t i = 0; i < probsEnsemble.size(); ++i) for(size_t i = 0; i < probsEnsemble.size(); ++i)
Backend::Element(Backend::_1 + weights[i] * Backend::_2, Backend::Element(Backend::_1 + weights[i] * Backend::_2,
probs, probsEnsemble[i]); probs, probsEnsemble[i]);
Backend::HostVector<unsigned> bestKeys(beamSize); HostVector<unsigned> bestKeys(beamSize);
Backend::HostVector<float> bestCosts(beamSize); HostVector<float> bestCosts(beamSize);
Backend::PartialSortByKey(probs, bestKeys, bestCosts); Backend::PartialSortByKey(probs, bestKeys, bestCosts);
std::vector<Backend::HostVector<float>> breakDowns; std::vector<HostVector<float>> breakDowns;
bool doBreakdown = God::Get<bool>("n-best"); bool doBreakdown = God::Get<bool>("n-best");
if(doBreakdown) { if(doBreakdown) {
breakDowns.push_back(bestCosts); breakDowns.push_back(bestCosts);
for(size_t i = 1; i < probsEnsemble.size(); ++i) { for(size_t i = 1; i < probsEnsemble.size(); ++i) {
HostVector<float> modelCosts(beamSize); HostVector<float> modelCosts(beamSize);
auto it = Backend::make_permutation_iterator(probsEnsemble[i].begin(), keys.begin()); auto it = Backend::make_permutation_iterator(probsEnsemble[i].begin(), bestKeys.begin());
Backend::copy(it, it + beamSize, modelCosts.begin()); Backend::copy(it, it + beamSize, modelCosts.begin());
breakDowns.push_back(modelCosts); breakDowns.push_back(modelCosts);
} }
@ -136,7 +143,7 @@ class Search {
float cost = 0; float cost = 0;
if(j < probsEnsemble.size()) { if(j < probsEnsemble.size()) {
if(prevHyps[hypIndex]->GetCostBreakdown().size() < probsEnsemble.size()) if(prevHyps[hypIndex]->GetCostBreakdown().size() < probsEnsemble.size())
const_cast<HypothesisPtr&>(prevHyps[hypIndex])->GetCostBreakdown().resize(ProbsEnsemble.size(), 0.0); const_cast<HypothesisPtr&>(prevHyps[hypIndex])->GetCostBreakdown().resize(probsEnsemble.size(), 0.0);
cost = breakDowns[j][i] + const_cast<HypothesisPtr&>(prevHyps[hypIndex])->GetCostBreakdown()[j]; cost = breakDowns[j][i] + const_cast<HypothesisPtr&>(prevHyps[hypIndex])->GetCostBreakdown()[j];
} }
sum += weights[j] * cost; sum += weights[j] * cost;