@@ -145,7 +145,7 @@ def decompile_ptx(asm_path, ptx_path, define_list):
145
145
print (ptx , end = '' )
146
146
147
147
148
- def assemble (asm_path , out_cubin_path , define_list , out_asm_path , sort_banks ):
148
+ def assemble (asm_path , out_cubin_path , define_list , out_asm_path , sort_banks , strip ):
149
149
if not out_cubin_path :
150
150
out_cubin_path = 'out.cubin'
151
151
cubin = Cubin ()
@@ -202,69 +202,12 @@ def assemble(asm_path, out_cubin_path, define_list, out_asm_path, sort_banks):
202
202
for constant in cubin .constant_dict .values ():
203
203
constant_asm += constant .print () + '\n '
204
204
asm = header_asm + global_asm + constant_asm + kernel_asm
205
- with open (out_asm_path , 'w' ) as f :
206
- f .write (asm )
207
-
208
-
209
- def preprocess (asm_path , out_asm_path , define_list , strip ):
210
- cubin = Cubin ()
211
-
212
- define_dict = {}
213
- for define in define_list :
214
- if not define :
215
- continue
216
- d = define .split ('=' )
217
- if len (d ) < 2 :
218
- exec (f'{ define } = True' , define_dict )
219
- else :
220
- exec (f'{ define } ' , define_dict )
221
205
222
- cubin .load_asm (asm_path , define_dict )
223
- header_asm = cubin .header .print () + '\n '
224
- global_asm = ''
225
- constant_asm = ''
226
- kernel_asm = ''
206
+ if strip :
207
+ asm = strip_comment (asm )
227
208
228
- consts = set ()
229
- globals_ = set ()
230
- for kernel in cubin .kernel_dict .values ():
231
- # Unmap global, const0, const3
232
- cubin .unmap_constant3 (kernel )
233
- kernel .unmap_reg ()
234
- kernel .unmap_constant0 ()
235
- kernel .unmap_jump ()
236
- kernel .unmap_global ()
237
- kernel .schedule ()
238
- kernel .sort_banks ()
239
- cubin .map_constant3 (kernel )
240
- kernel .map_global ()
241
- kernel .map_constant0 ()
242
- kernel .map_jump (rel = True )
243
- kernel .mark_const2 ()
244
- kernel_asm += '\n ' + kernel .print ()
245
- consts = consts .union (kernel .consts )
246
- globals_ = globals_ .union (kernel .globals )
247
-
248
- for global_ in cubin .global_dict .values ():
249
- if global_ .name in globals_ :
250
- global_asm += global_ .print () + '\n '
251
- for global_ in cubin .global_init_dict .values ():
252
- if global_ .name in globals_ :
253
- global_asm += global_ .print () + '\n '
254
- for constant in cubin .constant_dict .values ():
255
- if constant .name in consts or 'ALL_CONST3' in consts :
256
- constant_asm += constant .print () + '\n '
257
-
258
- asm = header_asm + global_asm + constant_asm + kernel_asm
259
-
260
- if strip :
261
- asm = strip_comment (asm )
262
-
263
- if out_asm_path :
264
209
with open (out_asm_path , 'w' ) as f :
265
210
f .write (asm )
266
- else :
267
- print (asm , end = '' )
268
211
269
212
270
213
def test_cubin (cubin_path , kernel_names , global_only , check = False ):
@@ -397,14 +340,8 @@ def main():
397
340
parser_as .add_argument ('-D' , '--define' , metavar = 'DEFINE' , nargs = '+' , type = str , default = '' ,
398
341
help = 'define variable for embedded python code' )
399
342
parser_as .add_argument ('-d' , '--debug' , metavar = 'OUTPUT_ASM' , type = str , default = '' , help = 'output asm for debug' )
400
- parser_as .add_argument ('-s' , '--sort' , action = 'store_true' , help = 'sort banks' )
401
-
402
- parser_pre = subparsers .add_parser ('pre' , help = 'preprocess asm' )
403
- parser_pre .add_argument ('asm' , help = 'input asm' , metavar = 'ASM' )
404
- parser_pre .add_argument ('-D' , '--define' , metavar = 'DEFINE' , nargs = '+' , type = str , default = '' ,
405
- help = 'define variable for embedded python code' )
406
- parser_pre .add_argument ('-o' , '--output' , metavar = 'OUTPUT' , type = str , default = '' , help = 'output asm file path' )
407
- parser_pre .add_argument ('-s' , '--strip' , action = 'store_true' , help = 'strip comment' )
343
+ parser_as .add_argument ('-b' , '--bank' , action = 'store_true' , help = 'sort banks' )
344
+ parser_as .add_argument ('-s' , '--strip' , action = 'store_true' , help = 'strip comment' )
408
345
409
346
parser_pdas = subparsers .add_parser ('dcc' , help = 'decompile asm to ptx' )
410
347
parser_pdas .add_argument ('asm' , help = 'input asm' , metavar = 'ASM' )
@@ -435,9 +372,7 @@ def main():
435
372
global_only = args .global_only , no_line_info = args .no_line_info )
436
373
elif args .cmd == 'as' :
437
374
assemble (asm_path = args .asm , out_cubin_path = args .output , define_list = args .define , out_asm_path = args .debug ,
438
- sort_banks = args .sort )
439
- elif args .cmd == 'pre' :
440
- preprocess (asm_path = args .asm , out_asm_path = args .output , define_list = args .define , strip = args .strip )
375
+ sort_banks = args .bank , strip = args .strip )
441
376
elif args .cmd == 'dcc' :
442
377
decompile_ptx (asm_path = args .asm , ptx_path = args .output , define_list = args .define )
443
378
elif args .cmd == 'test' :
0 commit comments