Skip to content

Commit 377ffbd

Browse files
pabloempdames
andauthored
Initial Ray runner prototype based on FnApiRunner in Beam (#10)
* Adding the simplest state handler * fixup * Intermediate state: Working on Ray Beam Runner * Initial prototype of Ray Beam runner runs Batch pipelines * Suggested changes from @pdames Co-authored-by: Patrick Ames <[email protected]>
1 parent 9a084a3 commit 377ffbd

File tree

8 files changed

+3233
-1
lines changed

8 files changed

+3233
-1
lines changed
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one or more
3+
# contributor license agreements. See the NOTICE file distributed with
4+
# this work for additional information regarding copyright ownership.
5+
# The ASF licenses this file to You under the Apache License, Version 2.0
6+
# (the "License"); you may not use this file except in compliance with
7+
# the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one or more
3+
# contributor license agreements. See the NOTICE file distributed with
4+
# this work for additional information regarding copyright ownership.
5+
# The ASF licenses this file to You under the Apache License, Version 2.0
6+
# (the "License"); you may not use this file except in compliance with
7+
# the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
import typing
18+
from typing import List
19+
from typing import Optional
20+
21+
from apache_beam.portability.api import beam_fn_api_pb2
22+
from apache_beam.portability.api import beam_runner_api_pb2
23+
from apache_beam.portability.api import endpoints_pb2
24+
from apache_beam.runners.portability.fn_api_runner import execution as fn_execution
25+
from apache_beam.runners.portability.fn_api_runner import translations
26+
from apache_beam.runners.portability.fn_api_runner import worker_handlers
27+
from apache_beam.runners.worker import bundle_processor
28+
from apache_beam.utils import proto_utils
29+
30+
import ray
31+
from ray_beam_runner.portability.execution import RayRunnerExecutionContext
32+
33+
class RayBundleContextManager:
34+
35+
def __init__(self,
36+
execution_context: RayRunnerExecutionContext,
37+
stage: translations.Stage,
38+
) -> None:
39+
self.execution_context = execution_context
40+
self.stage = stage
41+
# self.extract_bundle_inputs_and_outputs()
42+
self.bundle_uid = self.execution_context.next_uid()
43+
44+
# Properties that are lazily initialized
45+
self._process_bundle_descriptor = None # type: Optional[beam_fn_api_pb2.ProcessBundleDescriptor]
46+
self._worker_handlers = None # type: Optional[List[worker_handlers.WorkerHandler]]
47+
# a mapping of {(transform_id, timer_family_id): timer_coder_id}. The map
48+
# is built after self._process_bundle_descriptor is initialized.
49+
# This field can be used to tell whether current bundle has timers.
50+
self._timer_coder_ids = None # type: Optional[Dict[Tuple[str, str], str]]
51+
52+
def __reduce__(self):
53+
data = (self.execution_context,
54+
self.stage)
55+
deserializer = lambda args: RayBundleContextManager(args[0], args[1])
56+
return (deserializer, data)
57+
58+
@property
59+
def worker_handlers(self) -> List[worker_handlers.WorkerHandler]:
60+
return []
61+
62+
def data_api_service_descriptor(self) -> Optional[endpoints_pb2.ApiServiceDescriptor]:
63+
return endpoints_pb2.ApiServiceDescriptor(url='fake')
64+
65+
def state_api_service_descriptor(self) -> Optional[endpoints_pb2.ApiServiceDescriptor]:
66+
return None
67+
68+
@property
69+
def process_bundle_descriptor(self):
70+
# type: () -> beam_fn_api_pb2.ProcessBundleDescriptor
71+
if self._process_bundle_descriptor is None:
72+
self._process_bundle_descriptor = beam_fn_api_pb2.ProcessBundleDescriptor.FromString(
73+
self._build_process_bundle_descriptor())
74+
self._timer_coder_ids = fn_execution.BundleContextManager._build_timer_coders_id_map(self)
75+
return self._process_bundle_descriptor
76+
77+
def _build_process_bundle_descriptor(self):
78+
# Cannot be invoked until *after* _extract_endpoints is called.
79+
# Always populate the timer_api_service_descriptor.
80+
pbd = beam_fn_api_pb2.ProcessBundleDescriptor(
81+
id=self.bundle_uid,
82+
transforms={
83+
transform.unique_name: transform
84+
for transform in self.stage.transforms
85+
},
86+
pcollections=dict(
87+
self.execution_context.pipeline_components.pcollections.items()),
88+
coders=dict(self.execution_context.pipeline_components.coders.items()),
89+
windowing_strategies=dict(
90+
self.execution_context.pipeline_components.windowing_strategies.
91+
items()),
92+
environments=dict(
93+
self.execution_context.pipeline_components.environments.items()),
94+
state_api_service_descriptor=self.state_api_service_descriptor(),
95+
timer_api_service_descriptor=self.data_api_service_descriptor())
96+
97+
return pbd.SerializeToString()
98+
99+
def extract_bundle_inputs_and_outputs(self):
100+
# type: () -> Tuple[Dict[str, PartitionableBuffer], DataOutput, Dict[TimerFamilyId, bytes]]
101+
102+
"""Returns maps of transform names to PCollection identifiers.
103+
104+
Also mutates IO stages to point to the data ApiServiceDescriptor.
105+
106+
Returns:
107+
A tuple of (data_input, data_output, expected_timer_output) dictionaries.
108+
`data_input` is a dictionary mapping (transform_name, output_name) to a
109+
PCollection buffer; `data_output` is a dictionary mapping
110+
(transform_name, output_name) to a PCollection ID.
111+
`expected_timer_output` is a dictionary mapping transform_id and
112+
timer family ID to a buffer id for timers.
113+
"""
114+
transform_to_buffer_coder: typing.Dict[str, typing.Tuple[bytes, str]] = {}
115+
data_output = {} # type: DataOutput
116+
expected_timer_output = {} # type: OutputTimers
117+
for transform in self.stage.transforms:
118+
if transform.spec.urn in (bundle_processor.DATA_INPUT_URN,
119+
bundle_processor.DATA_OUTPUT_URN):
120+
pcoll_id = transform.spec.payload
121+
if transform.spec.urn == bundle_processor.DATA_INPUT_URN:
122+
coder_id = self.execution_context.data_channel_coders[translations.only_element(
123+
transform.outputs.values())]
124+
if pcoll_id == translations.IMPULSE_BUFFER:
125+
buffer_actor = ray.get(self.execution_context.pcollection_buffers.get.remote(
126+
transform.unique_name))
127+
ray.get(buffer_actor.append.remote(fn_execution.ENCODED_IMPULSE_VALUE))
128+
pcoll_id = transform.unique_name.encode('utf8')
129+
else:
130+
pass
131+
transform_to_buffer_coder[transform.unique_name] = (
132+
pcoll_id,
133+
self.execution_context.safe_coders.get(coder_id, coder_id)
134+
)
135+
elif transform.spec.urn == bundle_processor.DATA_OUTPUT_URN:
136+
data_output[transform.unique_name] = pcoll_id
137+
coder_id = self.execution_context.data_channel_coders[translations.only_element(
138+
transform.inputs.values())]
139+
else:
140+
raise NotImplementedError
141+
# TODO(pabloem): Figure out when we DO and we DONT need this particular rewrite of coders.
142+
data_spec = beam_fn_api_pb2.RemoteGrpcPort(coder_id=coder_id)
143+
# data_spec.api_service_descriptor.url = 'fake'
144+
transform.spec.payload = data_spec.SerializeToString()
145+
elif transform.spec.urn in translations.PAR_DO_URNS:
146+
payload = proto_utils.parse_Bytes(
147+
transform.spec.payload, beam_runner_api_pb2.ParDoPayload)
148+
for timer_family_id in payload.timer_family_specs.keys():
149+
expected_timer_output[(transform.unique_name, timer_family_id)] = (
150+
translations.create_buffer_id(timer_family_id, 'timers'))
151+
return transform_to_buffer_coder, data_output, expected_timer_output

0 commit comments

Comments
 (0)