Scan support #16028
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/16028
Note: Links to docs will display an error until the docs builds have been completed.
❌ 1 New Failure, 1 Unrelated Failure as of commit 90e55dd with merge base 9eaea4a.
NEW FAILURE - The following job has failed:
UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@JacobSzwejbka has imported this pull request. If you are a Meta employee, you can view this in D88107948.
fbf2088 to d563028
    op_table = program.execution_plan[0].operators
    instructions = program.execution_plan[0].chains[0].instructions

    # Collect all operator names in the program
honestly all the ops seem like implementation details and should not be tested
I was using it as a sort of proxy that the general pattern was emitted. If you want, we can just test the end-to-end behavior instead.
Don't have a strong opinion, but you might have to maintain this test if there's a change to the exported graph in the future
The ops we are querying over are the ones *not* in the original model definition, but instead created by the emitter to maintain the semantics of scan.
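For reference, a minimal sketch of what an end-to-end check could assert against instead of the op table; the reference loop and the example combine_fn below are illustrative, not code from this PR:

```python
import torch

def reference_scan(combine_fn, init, xs):
    # Eager reference for scan semantics: thread the carry through each
    # slice of xs along dim 0 and stack the per-step outputs.
    carry = init
    ys = []
    for i in range(xs.shape[0]):
        carry, y = combine_fn(carry, xs[i])
        ys.append(y)
    return carry, torch.stack(ys)

# Hypothetical combine_fn: running sum, where the carry is also the output.
combine = lambda c, x: (c + x, c + x)
final_carry, ys = reference_scan(combine, torch.zeros(3), torch.randn(5, 3))
# An end-to-end test would compare the lowered program's outputs to these.
```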
    2. et_copy_index(y_outputs, combine_fn's y output, iter_idx)
    This explicit copy approach is used because in-place op.out(x, out=x) is unsafe.
I was under the impression that this might be fine. We basically emit scan at the very end of the lowering process and I'm not convinced we still require the graph to be functional.
No, the problem isn't being functional; it's that ATen (and ET) ops are not guaranteed to work when in and out alias the same memory.
You could very easily write before read over sections of the tensor.
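A toy illustration of that write-before-read hazard (made up for this comment, not from the PR): an element-wise stencil that is correct with a separate output buffer but silently corrupts results when the output aliases the input.

```python
import torch

x = torch.arange(5, dtype=torch.float32)

# Safe: separate output buffer, every read sees the original input.
out = x.clone()
for i in range(1, x.numel()):
    out[i] = x[i] + x[i - 1]

# Unsafe: "out" aliases "in". By the time iteration i reads y[i - 1], that
# slot already holds the value written in iteration i - 1 (write before read).
y = torch.arange(5, dtype=torch.float32)
for i in range(1, y.numel()):
    y[i] = y[i] + y[i - 1]

print(out)  # tensor([0., 1., 3., 5., 7.])
print(y)    # tensor([0., 1., 3., 6., 10.]) -- diverges from the intended result
```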
        meta,
    )

    def call_scan(
@angelayi can you check that I'm not doing anything stupid here?
        # Use the placeholder's val which has the correct shape
        xs_element_data.append(ph.meta["val"])

    combine_fn_result = self.call_submodule(
I mostly copied torch.cond here with running call_submodule. Is this just so the subgraph also gets a chance to be run over by spec prop before calling the original? It just seems weird that I'm calling scan on this subgraph instead of the original one passed in as an arg.
    for i in range(0, len(xs)):
        ph = combine_fn_placeholders[num_init + i]
        # Use the placeholder's val which has the correct shape
        xs_element_data.append(ph.meta["val"])
I think this part is a little sus where you look at the subgraph's placeholder nodes. I think the xs_element_data should just be something like xs[0]?
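A rough sketch of that suggestion (assuming each entry of `xs` is an fx node whose `meta["val"]` holds the stacked fake tensor; the exact types flowing through this pass are an assumption):

```python
# Hypothetical alternative: derive the per-iteration element from the outer
# xs arguments rather than from the combine_fn subgraph's placeholders.
xs_element_data = [x.meta["val"].select(0, 0) for x in xs]
```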
Add support for the higher-order op scan. It's inefficient today because we are manually deep copying from output to input for every carry. We could do better by shallow swapping the pointers, but I'll do that in a follow-up if needed.
Test plan: Unit tests and internal verification against harder patterns
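To make the trade-off concrete, here is an illustrative eager-mode model of the two carry-handling strategies; the step function and buffer names are made up, and the emitted program's actual mechanics (et_copy_index into a value table) differ:

```python
import torch

def step(carry, x):
    # Hypothetical combine_fn: returns (new carry, per-step output).
    return carry + x, carry * x

xs = torch.randn(4, 3)

# Current approach: deep-copy the carry output back into the carry input
# buffer on every iteration.
carry_buf = torch.zeros(3)
for i in range(xs.shape[0]):
    new_carry, _y = step(carry_buf, xs[i])
    carry_buf.copy_(new_carry)  # one full copy per carry, per iteration

# Follow-up idea: swap which tensor plays the "carry input" role instead of
# copying its data (a shallow pointer swap in the emitted program).
carry_ref = torch.zeros(3)
for i in range(xs.shape[0]):
    new_carry, _y = step(carry_ref, xs[i])
    carry_ref = new_carry  # rebind; no data movement
```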