diff --git a/src/expression_graph.cu b/src/expression_graph.cu index b85c7223..cc0d9fb9 100644 --- a/src/expression_graph.cu +++ b/src/expression_graph.cu @@ -46,9 +46,9 @@ void temp() std::vector hVec({1,2, 4,3, 7,9, 7,3}); thrust::device_vector dVec(8); thrust::copy(hVec.begin(), hVec.end(), dVec.begin()); - float *data = hVec.data(); + float *data = thrust::raw_pointer_cast(dVec.data()); - gSoftMax<<<2, 4>>>(data, 2, 4); + gSoftMax<<<4, 2, sizeof(float)>>>(data, 4, 2); std::vector hVec2(8); thrust::copy(dVec.begin(), dVec.end(), hVec2.begin()); @@ -66,6 +66,7 @@ ExpressionGraph::ExpressionGraph(int cudaDevice) cudaSetDevice(0); temp(); + exit(0); } }