Commit 6bbe03c

aakhundov authored and facebook-github-bot committed
Minor improvements in jagged tensor identification (#919)
Summary: Pull Request resolved: #919

ATT. Concretely:

1. Tagging jagged tensor nodes in the fx graph now proceeds even if the nodes have a rank different from 2. This accommodates the cases where jagged tensors are unsqueezed / reshaped on the path from the input (where they should be tagged for downstream shape inference) to the fbgemm op (where they are detected).
2. Inputs whose `shape[0]` equals one of the JT `shape[0]` values, but which have no offsets tag attached, are now ignored instead of failing the whole jagged tensor map inference.
3. Instead of falling back to the jagged batch-dim based JT shape inference, we now either rely fully on the inferred jagged tensor map or, if we fail to infer one, do not use it at all.
4. Add `jagged_index_select` to the list of anchor ops for recognizing and tagging the jagged tensors and offsets in the fx graph.

Reviewed By: qxy11

Differential Revision: D48825713

fbshipit-source-id: 32851a0180dd47bfbc6669216cdf79817af64670
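Point 2 can be illustrated with a minimal, self-contained sketch. It is not the code from `tensor_spec.py`: the function name `infer_jagged_tensor_map`, its parameters, and the way jagged tensor ids are assigned are all assumptions made for illustration. Only the skip-versus-fail decision mirrors the change described above.

```python
from typing import Dict, List, Optional, Sequence, Set


def infer_jagged_tensor_map(
    shapes: Sequence[Sequence[int]],
    offsets_tags: Sequence[Optional[Set[str]]],
    jagged_batch_dims: Set[int],
) -> Optional[Dict[int, int]]:
    """Hypothetical helper: map input index -> jagged tensor id.

    Returns None only when an offsets tag is ambiguous.
    """
    seen_offsets_names: List[str] = []
    jagged_tensor_map: Dict[int, int] = {}
    for i, shape in enumerate(shapes):
        if shape[0] not in jagged_batch_dims:
            continue  # batch dim does not match any jagged tensor
        tags = offsets_tags[i]
        if tags is None:
            # Untagged input that merely shares a batch dim with a
            # jagged tensor: skip it instead of aborting the inference.
            continue
        if len(tags) > 1:
            # More than one offsets name attached: ambiguous, give up.
            return None
        name = next(iter(tags))
        if name not in seen_offsets_names:
            seen_offsets_names.append(name)
        # Assumption: the id is the order in which the offsets name was seen.
        jagged_tensor_map[i] = seen_offsets_names.index(name)
    return jagged_tensor_map


shapes = [(6, 4), (6, 8), (2, 3)]
tags = [{"offsets_a"}, None, None]
print(infer_jagged_tensor_map(shapes, tags, {6}))
```

Here the second input shares `shape[0] == 6` with the tagged jagged tensor but carries no tag, so it is simply left out of the map; under the old combined check the whole inference would have returned `None`.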
1 parent 0f91ea5 commit 6bbe03c

1 file changed: +8 -5 lines changed

fx2ait/fx2ait/tensor_spec.py

@@ -345,9 +345,12 @@ def _try_getting_jagged_tensor_map(
         for i, inp in enumerate(inputs):
             if inp.shape[0] in jagged_tensor_batch_dims:
                 offsets_name = fx_inputs[i].meta.get("offsets_name", None)
-                if offsets_name is None or len(offsets_name) > 1:
+                if offsets_name is None:
+                    # not a jagged tensor
+                    continue
+                if len(offsets_name) > 1:
                     # offsets name attached to the jagged tensor's
-                    # fx.Node is either unavailable or ambiguous
+                    # fx.Node is either ambiguous: failing here
                     return None
                 offsets_name = list(offsets_name)[0]
                 if offsets_name not in seen_offsets_names:
@@ -387,7 +390,7 @@ def from_input_list_with_batch_size_jagged_tensor(
             jagged_tensor_batch_dims=jagged_tensor_batch_dims,
             fx_inputs=fx_inputs,
         )
-        if jagged_tensor_map is not None:
+        if jagged_tensor_map:
             logger.info("Successfully detected a jagged_tensor_map:")
             for input_id, jagged_tensor_id in jagged_tensor_map.items():
                 logger.info(f"{input_id=}, {jagged_tensor_id=}")
@@ -407,7 +410,7 @@ def from_input_list_with_batch_size_jagged_tensor(
             batch_dim_lower_bound: int = 0
             batch_dim_upper_bound: int = 0
             batch_dim_name: str = ""
-            if jagged_tensor_map is not None and ind in jagged_tensor_map:
+            if jagged_tensor_map and ind in jagged_tensor_map:
                 batch_dim_lower_bound = 0  # when all sequences are empty
                 # if the maximum sequence length for this jagged tensor was not
                 # inferred from the offsets, we use the globally configured
@@ -417,7 +420,7 @@ def from_input_list_with_batch_size_jagged_tensor(
                 )
                 batch_dim_upper_bound = max_batch_size * max_seq_len
                 batch_dim_name = f"batch_size_jagged_tensor_id_{jagged_tensor_map[ind]}"
-            elif batch_dim in jagged_tensor_batch_dims:
+            elif not jagged_tensor_map and batch_dim in jagged_tensor_batch_dims:
                 batch_dim_lower_bound = 0
                 max_seq_len = max_seq_lens_from_offsets.get(
                     batch_dim, max_sequence_length
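The effect of the last three hunks (point 3 of the summary) can be sketched as a small branch classifier. This is an illustrative reduction, not the code of `from_input_list_with_batch_size_jagged_tensor`: the function `classify_input` and its return labels are hypothetical, and only the two conditions are copied from the diff.

```python
from typing import Dict, Optional, Set


def classify_input(
    ind: int,
    batch_dim: int,
    jagged_tensor_map: Optional[Dict[int, int]],
    jagged_tensor_batch_dims: Set[int],
) -> str:
    """Hypothetical helper: report which branch of the new logic an input takes."""
    if jagged_tensor_map and ind in jagged_tensor_map:
        # A non-empty map was inferred and this input is in it.
        return "from_map"
    elif not jagged_tensor_map and batch_dim in jagged_tensor_batch_dims:
        # No map (None or empty): fall back to matching on the batch dim.
        return "from_batch_dim"
    # Neither branch applies; what happens next is outside the shown diff.
    return "other"


# With a usable map, an untagged input sharing the batch dim is no longer
# picked up by the batch-dim fallback.
print(classify_input(1, 6, {0: 0}, {6}))
# Without a map, the batch-dim fallback still applies.
print(classify_input(1, 6, None, {6}))
```

Note that switching from `is not None` to a truthiness check means an empty dict is treated the same as `None`. This matters because, once untagged inputs are skipped rather than failing the inference, an empty map becomes a possible result.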
