You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Minor improvements in jagged tensor identification (#919)
Summary:
Pull Request resolved: #919
ATT. Concretely:
1. Tagging jagged tensor nodes in the fx graph now proceeds even if the nodes have a rank different from 2. This is to acomodate for the cases when jagged tensors are unsqueezed / reshaped at the path from the input (where they should be tagged for downstream shape inference) to the fbgemm op (where they are detected).
2. Inputs with the `shape[0]` being equal to one of the JT `shape[0]`, but not having an offsets tag attached are now ignored instead of failing the whole jagged tensor map inference.
3. Instead of falling back to the jagged batch-dim based JT shape inference, we now either fully rely on the inferred jagged tensor map or not at all if we fail to infer one.
4. Add `jagged_index_select` to the list of anchor ops for recognizing and tagging the jagged tensors and offsets in the fx graph.
Reviewed By: qxy11
Differential Revision: D48825713
fbshipit-source-id: 32851a0180dd47bfbc6669216cdf79817af64670
0 commit comments