@@ -338,7 +338,9 @@ def groupby_reduce(
338
338
f"the number of partitions along { axis = } is not equal: "
339
339
+ f"{ partitions .shape [axis ]} != { by .shape [axis ]} "
340
340
)
341
- mapped_partitions = cls .apply (axis , map_func , left = partitions , right = by )
341
+ mapped_partitions = cls .broadcast_apply (
342
+ axis , map_func , left = partitions , right = by
343
+ )
342
344
else :
343
345
mapped_partitions = cls .map_partitions (partitions , map_func )
344
346
@@ -437,7 +439,7 @@ def get_partitions(index):
437
439
438
440
@classmethod
439
441
@wait_computations_if_benchmark_mode
440
- def broadcast_apply (cls , axis , apply_func , left , right ):
442
+ def base_broadcast_apply (cls , axis , apply_func , left , right ):
441
443
"""
442
444
Broadcast the `right` partitions to `left` and apply `apply_func` function.
443
445
@@ -490,57 +492,6 @@ def map_func(df, *others):
490
492
]
491
493
)
492
494
493
- @classmethod
494
- @wait_computations_if_benchmark_mode
495
- def apply_axis_partitions (
496
- cls ,
497
- axis ,
498
- apply_func ,
499
- left ,
500
- right ,
501
- ):
502
- """
503
- Broadcast the `right` partitions to `left` and apply `apply_func` along full `axis`.
504
-
505
- Parameters
506
- ----------
507
- axis : {0, 1}
508
- Axis to apply and broadcast over.
509
- apply_func : callable
510
- Function to apply.
511
- left : NumPy 2D array
512
- Left partitions.
513
- right : NumPy 2D array
514
- Right partitions.
515
-
516
- Returns
517
- -------
518
- NumPy array
519
- An array of partition objects.
520
-
521
- Notes
522
- -----
523
- This method differs from `broadcast_axis_partitions` in that it does not send
524
- all right partitions for each remote task based on the left partitions.
525
- """
526
- preprocessed_map_func = cls .preprocess_func (apply_func )
527
- left_partitions = cls .axis_partition (left , axis )
528
- right_partitions = None if right is None else cls .axis_partition (right , axis )
529
-
530
- result_blocks = np .array (
531
- [
532
- left_partitions [i ].apply (
533
- preprocessed_map_func ,
534
- other_axis_partition = right_partitions [i ],
535
- )
536
- for i in np .arange (len (left_partitions ))
537
- ]
538
- )
539
- # If we are mapping over columns, they are returned to use the same as
540
- # rows, so we need to transpose the returned 2D NumPy array to return
541
- # the structure to the correct order.
542
- return result_blocks .T if not axis else result_blocks
543
-
544
495
@classmethod
545
496
@wait_computations_if_benchmark_mode
546
497
def broadcast_axis_partitions (
@@ -552,6 +503,7 @@ def broadcast_axis_partitions(
552
503
keep_partitioning = False ,
553
504
num_splits = None ,
554
505
apply_indices = None ,
506
+ send_all_right = True ,
555
507
enumerate_partitions = False ,
556
508
lengths = None ,
557
509
apply_func_args = None ,
@@ -580,6 +532,8 @@ def broadcast_axis_partitions(
580
532
then the number of splits is preserved.
581
533
apply_indices : list of ints, default: None
582
534
Indices of `axis ^ 1` to apply function over.
535
+ send_all_right : bool, default: True
536
+ Whether or not to pass all right axis partitions to each of the left axis partitions. If False, the i-th left axis partition receives only the i-th right axis partition.
583
537
enumerate_partitions : bool, default: False
584
538
Whether or not to pass partition index into `apply_func`.
585
539
Note that `apply_func` must be able to accept `partition_idx` kwarg.
@@ -626,7 +580,6 @@ def broadcast_axis_partitions(
626
580
# load-balance the data as well.
627
581
kw = {
628
582
"num_splits" : num_splits ,
629
- "other_axis_partition" : right_partitions ,
630
583
"maintain_partitioning" : keep_partitioning ,
631
584
}
632
585
if lengths :
@@ -641,6 +594,9 @@ def broadcast_axis_partitions(
641
594
left_partitions [i ].apply (
642
595
preprocessed_map_func ,
643
596
* (apply_func_args if apply_func_args else []),
597
+ other_axis_partition = (
598
+ right_partitions if send_all_right else right_partitions [i ]
599
+ ),
644
600
** kw ,
645
601
** ({"partition_idx" : idx } if enumerate_partitions else {}),
646
602
** kwargs ,
@@ -698,7 +654,7 @@ def base_map_partitions(
698
654
699
655
@classmethod
700
656
@wait_computations_if_benchmark_mode
701
- def apply (
657
+ def broadcast_apply (
702
658
cls ,
703
659
axis ,
704
660
apply_func ,
@@ -731,24 +687,22 @@ def apply(
731
687
# partitions of the left and right dataframes are possible for the `apply`,
732
688
# as a result of which it is necessary to merge partitions on both axes at once,
733
689
# which leads to large slowdowns.
734
- if (
735
- np .prod (left .shape ) <= 1.5 * CpuCount .get ()
736
- or left .shape [axis ] < CpuCount .get () // 5
737
- ):
690
+ if np .prod (left .shape ) <= 1.5 * CpuCount .get ():
738
691
# block-wise broadcast
739
- new_partitions = cls .broadcast_apply (
692
+ new_partitions = cls .base_broadcast_apply (
740
693
axis ,
741
694
apply_func ,
742
695
left ,
743
696
right ,
744
697
)
745
698
else :
746
699
# axis-wise broadcast
747
- new_partitions = cls .apply_axis_partitions (
700
+ new_partitions = cls .broadcast_axis_partitions (
748
701
axis = axis ^ 1 ,
749
702
left = left ,
750
703
right = right ,
751
704
apply_func = apply_func ,
705
+ send_all_right = False ,
752
706
)
753
707
return new_partitions
754
708
0 commit comments