Skip to content

Commit

Permalink
apacheGH-41884: [Python] Fix RecordBatchReader.cast to support casting to equal schema for all types
Browse files Browse the repository at this point in the history
  • Loading branch information
jorisvandenbossche committed Jun 11, 2024
1 parent fe4d04f commit ee6dbf7
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 1 deletion.
3 changes: 2 additions & 1 deletion python/pyarrow/src/arrow/python/ipc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,8 @@ Status CastingRecordBatchReader::Init(std::shared_ptr<RecordBatchReader> parent,

// Ensure all columns can be cast before succeeding
for (int i = 0; i < num_fields; i++) {
- if (!compute::CanCast(*src->field(i)->type(), *schema->field(i)->type())) {
+ if ((!((src->field(i)->type()->Equals(schema->field(i)->type())))) &&
+ (!compute::CanCast(*src->field(i)->type(), *schema->field(i)->type()))) {
return Status::TypeError("Field ", i, " cannot be cast from ",
src->field(i)->type()->ToString(), " to ",
schema->field(i)->type()->ToString());
Expand Down
10 changes: 10 additions & 0 deletions python/pyarrow/tests/test_ipc.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.

from collections import UserList
import datetime
import io
import pathlib
import pytest
Expand Down Expand Up @@ -1272,6 +1273,15 @@ def test_record_batch_reader_cast():
with pytest.raises(pa.lib.ArrowTypeError, match='Field 0 cannot be cast'):
reader.cast(pa.schema([pa.field('a', pa.list_(pa.int32()))]))

# Cast to same type should always work also for date32
# (https://github.com/apache/arrow/issues/41884)
schema_src = pa.schema([pa.field('a', pa.date32())])
arr = pa.array([datetime.date(2024, 6, 11)], type=pa.date32())
data = [pa.record_batch([arr], names=['a']), pa.record_batch([arr], names=['a'])]
table_src = pa.Table.from_batches(data)
reader = pa.RecordBatchReader.from_batches(schema_src, data)
assert reader.cast(schema_src).read_all() == table_src


def test_record_batch_reader_cast_nulls():
schema_src = pa.schema([pa.field('a', pa.int64())])
Expand Down

0 comments on commit ee6dbf7

Please sign in to comment.