#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import typing
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple

from apache_beam.portability.api import beam_fn_api_pb2
from apache_beam.portability.api import beam_runner_api_pb2
from apache_beam.portability.api import endpoints_pb2
from apache_beam.runners.portability.fn_api_runner import execution as fn_execution
from apache_beam.runners.portability.fn_api_runner import translations
from apache_beam.runners.portability.fn_api_runner import worker_handlers
from apache_beam.runners.worker import bundle_processor
from apache_beam.utils import proto_utils

import ray
from ray_beam_runner.portability.execution import RayRunnerExecutionContext


class RayBundleContextManager:
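  """Manages the context needed to execute a single fused stage as a bundle.

  The manager pairs the stage's transforms with the shared
  RayRunnerExecutionContext and lazily builds the ProcessBundleDescriptor
  that an SDK worker uses to process the bundle.
  """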

  def __init__(self,
               execution_context: RayRunnerExecutionContext,
               stage: translations.Stage,
               ) -> None:
    self.execution_context = execution_context
    self.stage = stage
    # self.extract_bundle_inputs_and_outputs()
    self.bundle_uid = self.execution_context.next_uid()

    # Properties that are lazily initialized
    self._process_bundle_descriptor = None  # type: Optional[beam_fn_api_pb2.ProcessBundleDescriptor]
    self._worker_handlers = None  # type: Optional[List[worker_handlers.WorkerHandler]]
    # a mapping of {(transform_id, timer_family_id): timer_coder_id}. The map
    # is built after self._process_bundle_descriptor is initialized.
    # This field can be used to tell whether the current bundle has timers.
    self._timer_coder_ids = None  # type: Optional[Dict[Tuple[str, str], str]]

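  # __reduce__ makes the manager picklable (e.g. when it is shipped to Ray
  # tasks) by rebuilding it from (execution_context, stage) rather than
  # serializing its lazily initialized state.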
  def __reduce__(self):
    data = (self.execution_context, self.stage)
    deserializer = lambda args: RayBundleContextManager(args[0], args[1])
    return (deserializer, data)

  @property
  def worker_handlers(self) -> List[worker_handlers.WorkerHandler]:
    return []

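  # The data API descriptor below uses a placeholder URL, and no state API
  # service is advertised to the SDK worker.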
  def data_api_service_descriptor(self) -> Optional[endpoints_pb2.ApiServiceDescriptor]:
    return endpoints_pb2.ApiServiceDescriptor(url='fake')

  def state_api_service_descriptor(self) -> Optional[endpoints_pb2.ApiServiceDescriptor]:
    return None

  @property
  def process_bundle_descriptor(self):
    # type: () -> beam_fn_api_pb2.ProcessBundleDescriptor
    if self._process_bundle_descriptor is None:
      self._process_bundle_descriptor = beam_fn_api_pb2.ProcessBundleDescriptor.FromString(
          self._build_process_bundle_descriptor())
      # Reuse the timer-coder map construction from Beam's BundleContextManager,
      # invoked as an unbound method on this compatible object.
      self._timer_coder_ids = fn_execution.BundleContextManager._build_timer_coders_id_map(self)
    return self._process_bundle_descriptor

  def _build_process_bundle_descriptor(self):
    # Cannot be invoked until *after* _extract_endpoints is called.
    # Always populate the timer_api_service_descriptor.
    pbd = beam_fn_api_pb2.ProcessBundleDescriptor(
        id=self.bundle_uid,
        transforms={
            transform.unique_name: transform
            for transform in self.stage.transforms
        },
        pcollections=dict(
            self.execution_context.pipeline_components.pcollections.items()),
        coders=dict(self.execution_context.pipeline_components.coders.items()),
        windowing_strategies=dict(
            self.execution_context.pipeline_components.windowing_strategies.items()),
        environments=dict(
            self.execution_context.pipeline_components.environments.items()),
        state_api_service_descriptor=self.state_api_service_descriptor(),
        timer_api_service_descriptor=self.data_api_service_descriptor())

    return pbd.SerializeToString()

  def extract_bundle_inputs_and_outputs(self):
    # type: () -> Tuple[Dict[str, Tuple[bytes, str]], DataOutput, OutputTimers]

    """Returns maps describing the stage's data inputs, outputs and timers.

    Also mutates the stage's IO transforms to point to the data
    ApiServiceDescriptor.

    Returns:
      A tuple of (transform_to_buffer_coder, data_output, expected_timer_output)
      dictionaries. `transform_to_buffer_coder` maps an input transform name to
      a (buffer_id, coder_id) pair; `data_output` maps an output transform name
      to a PCollection buffer ID; `expected_timer_output` maps a
      (transform_name, timer_family_id) pair to a timer buffer ID.
    """
    transform_to_buffer_coder: typing.Dict[str, typing.Tuple[bytes, str]] = {}
    data_output = {}  # type: DataOutput
    expected_timer_output = {}  # type: OutputTimers
    for transform in self.stage.transforms:
      if transform.spec.urn in (bundle_processor.DATA_INPUT_URN,
                                bundle_processor.DATA_OUTPUT_URN):
        pcoll_id = transform.spec.payload
        if transform.spec.urn == bundle_processor.DATA_INPUT_URN:
          coder_id = self.execution_context.data_channel_coders[
              translations.only_element(transform.outputs.values())]
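          # For impulse inputs, seed the input buffer (held in a Ray actor)
          # with the encoded impulse value so the bundle has data to read, and
          # key the buffer by the transform's unique name.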
          if pcoll_id == translations.IMPULSE_BUFFER:
            buffer_actor = ray.get(
                self.execution_context.pcollection_buffers.get.remote(
                    transform.unique_name))
            ray.get(buffer_actor.append.remote(fn_execution.ENCODED_IMPULSE_VALUE))
            pcoll_id = transform.unique_name.encode('utf8')
          else:
            pass
          transform_to_buffer_coder[transform.unique_name] = (
              pcoll_id,
              self.execution_context.safe_coders.get(coder_id, coder_id))
        elif transform.spec.urn == bundle_processor.DATA_OUTPUT_URN:
          data_output[transform.unique_name] = pcoll_id
          coder_id = self.execution_context.data_channel_coders[
              translations.only_element(transform.inputs.values())]
        else:
          raise NotImplementedError
        # TODO(pabloem): Figure out when we DO and we DONT need this particular rewrite of coders.
        data_spec = beam_fn_api_pb2.RemoteGrpcPort(coder_id=coder_id)
        # data_spec.api_service_descriptor.url = 'fake'
        transform.spec.payload = data_spec.SerializeToString()
      elif transform.spec.urn in translations.PAR_DO_URNS:
        payload = proto_utils.parse_Bytes(
            transform.spec.payload, beam_runner_api_pb2.ParDoPayload)
        for timer_family_id in payload.timer_family_specs.keys():
          expected_timer_output[(transform.unique_name, timer_family_id)] = (
              translations.create_buffer_id(timer_family_id, 'timers'))
    return transform_to_buffer_coder, data_output, expected_timer_output