Skip to content

Commit

Permalink
add CUDA roundtrip tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jorisvandenbossche committed Feb 15, 2024
1 parent 1f79861 commit 02cff8c
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 1 deletion.
13 changes: 12 additions & 1 deletion cpp/src/arrow/c/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,18 @@
# specific language governing permissions and limitations
# under the License.

add_arrow_test(bridge_test PREFIX "arrow-c")
if(ARROW_CUDA)
if(ARROW_BUILD_SHARED)
set(ARROW_CUDA_LIBRARY arrow_cuda_shared)
else()
set(ARROW_CUDA_LIBRARY arrow_cuda_static)
endif()
set(ARROW_CUDA_TEST_LINK_LIBS ${ARROW_CUDA_LIBRARY} ${ARROW_TEST_LINK_LIBS})
add_arrow_test(bridge_test PREFIX "arrow-c" STATIC_LINK_LIBS ${ARROW_CUDA_TEST_LINK_LIBS})
else()
add_arrow_test(bridge_test PREFIX "arrow-c")
endif()

add_arrow_test(dlpack_test)

add_arrow_benchmark(bridge_benchmark)
Expand Down
115 changes: 115 additions & 0 deletions cpp/src/arrow/c/bridge_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@
#include "arrow/compute/api_vector.h"
#endif

#ifdef ARROW_CUDA
#include "arrow/gpu/cuda_api.h"
#endif

namespace arrow {

using internal::ArrayExportGuard;
Expand Down Expand Up @@ -4330,6 +4334,117 @@ TEST_F(TestDeviceArrayRoundtrip, Struct) {
TestWithJSON(mm, type, R"([[4, null], null, [5, "foo"]])");
}

#ifdef ARROW_CUDA

class TestCUDADeviceArrayRoundtrip : public ::testing::Test {
public:
using ArrayFactory = std::function<Result<std::shared_ptr<Array>>()>;

static Result<std::shared_ptr<MemoryManager>> DeviceMapper(ArrowDeviceType type,
int64_t id) {
if (type != ARROW_DEVICE_CUDA) {
return Status::NotImplemented("should only be CUDA device");
}

ARROW_ASSIGN_OR_RAISE(auto manager, cuda::CudaDeviceManager::Instance());
ARROW_ASSIGN_OR_RAISE(auto device, manager->GetDevice(id));
return device->default_memory_manager();
}

static ArrayFactory JSONArrayFactory(std::shared_ptr<DataType> type, const char* json) {
return [=]() { return ArrayFromJSON(type, json); };
}

template <typename ArrayFactory>
void TestWithArrayFactory(ArrayFactory&& factory) {
TestWithArrayFactory(factory, factory);
}

template <typename ArrayFactory, typename ExpectedArrayFactory>
void TestWithArrayFactory(ArrayFactory&& factory,
ExpectedArrayFactory&& factory_expected) {
ASSERT_OK_AND_ASSIGN(auto manager, cuda::CudaDeviceManager::Instance());
ASSERT_OK_AND_ASSIGN(auto device, manager->GetDevice(0));
auto mm = device->default_memory_manager();

std::shared_ptr<Array> array;
std::shared_ptr<Array> device_array;
ASSERT_OK_AND_ASSIGN(array, factory());
ASSERT_OK_AND_ASSIGN(device_array, array->CopyTo(mm));

struct ArrowDeviceArray c_array {};
struct ArrowSchema c_schema {};
ArrayExportGuard array_guard(&c_array.array);
SchemaExportGuard schema_guard(&c_schema);

ASSERT_OK(ExportType(*device_array->type(), &c_schema));
std::shared_ptr<Device::SyncEvent> sync{nullptr};
ASSERT_OK(ExportDeviceArray(*device_array, sync, &c_array));

std::shared_ptr<Array> device_array_roundtripped;
ASSERT_OK_AND_ASSIGN(device_array_roundtripped,
ImportDeviceArray(&c_array, &c_schema, DeviceMapper));
ASSERT_TRUE(ArrowSchemaIsReleased(&c_schema));
ASSERT_TRUE(ArrowArrayIsReleased(&c_array.array));

// Check value of imported array (copy to CPU to assert equality)
std::shared_ptr<Array> array_roundtripped;
ASSERT_OK_AND_ASSIGN(array_roundtripped,
device_array_roundtripped->CopyTo(default_cpu_memory_manager()));
{
std::shared_ptr<Array> expected;
ASSERT_OK_AND_ASSIGN(expected, factory_expected());
AssertTypeEqual(*expected->type(), *array_roundtripped->type());
AssertArraysEqual(*expected, *array_roundtripped, true);
}

// Re-export and re-import, now both at once
ASSERT_OK(ExportDeviceArray(*device_array, sync, &c_array, &c_schema));
device_array_roundtripped.reset();
ASSERT_OK_AND_ASSIGN(device_array_roundtripped,
ImportDeviceArray(&c_array, &c_schema, DeviceMapper));
ASSERT_TRUE(ArrowSchemaIsReleased(&c_schema));
ASSERT_TRUE(ArrowArrayIsReleased(&c_array.array));

// Check value of imported array (copy to CPU to assert equality)
array_roundtripped.reset();
ASSERT_OK_AND_ASSIGN(array_roundtripped,
device_array_roundtripped->CopyTo(default_cpu_memory_manager()));
{
std::shared_ptr<Array> expected;
ASSERT_OK_AND_ASSIGN(expected, factory_expected());
AssertTypeEqual(*expected->type(), *array_roundtripped->type());
AssertArraysEqual(*expected, *array_roundtripped, true);
}
}

void TestWithJSON(std::shared_ptr<DataType> type, const char* json) {
TestWithArrayFactory(JSONArrayFactory(type, json));
}
};

TEST_F(TestCUDADeviceArrayRoundtrip, Primitive) { TestWithJSON(int32(), "[4, 5, null]"); }

TEST_F(TestCUDADeviceArrayRoundtrip, Struct) {
auto type = struct_({field("ints", int16()), field("strs", utf8())});

TestWithJSON(type, "[]");
TestWithJSON(type, R"([[4, "foo"], [5, "bar"]])");
TestWithJSON(type, R"([[4, null], null, [5, "foo"]])");
}

TEST_F(TestCUDADeviceArrayRoundtrip, Dictionary) {
auto factory = []() {
auto values = ArrayFromJSON(utf8(), R"(["foo", "bar", "quux"])");
auto indices = ArrayFromJSON(uint16(), "[0, 2, 1, null, 1]");
return DictionaryArray::FromArrays(dictionary(indices->type(), values->type()),
indices, values);
};
TestWithArrayFactory(factory);
}

#endif

////////////////////////////////////////////////////////////////////////////
// Array stream export tests

Expand Down

0 comments on commit 02cff8c

Please sign in to comment.