Add model_to_minibatch transformation to convert all pm.Data to pm.Minibatch
#7785
You can use the lower level utility: pymc/pymc/variational/minibatch_rv.py Line 53 in ef26ae8
Then make that a vanilla observed RV |
|
Ah you already did that, so your question is how to get total size? Grab the batch shape of the variable and constant fold it without raising if it can't be fully folded |
|
My real issue was not understanding what needs to be the key and value in the replacements, between:
|
|
Usually it is best to replace the whole fgraph |
|
I don't really understand what that answer means |
|
dprint the fgraph and it will perhaps be more obvious what I am mumbling |
|
The problem I was running into was that I ended up with two |
|
Because Minibatch assumes the data variables have the same length, it might make sense to take a variables argument. Or have some way to group data variables of the same size (same dim name maybe?)
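The grouping idea above can be sketched in plain Python. This is only an illustration: `group_by_leading_length` is a hypothetical helper (not part of PyMC), and plain lists stand in for data containers. It shows how data of equal first-axis length could be grouped so that each group is minibatched together.

```python
from collections import defaultdict


def group_by_leading_length(data):
    """Group named datasets by the length of their first axis."""
    groups = defaultdict(list)
    for name, values in data.items():
        groups[len(values)].append(name)
    return dict(groups)


# Plain lists stand in for data containers (assumption for illustration).
data = {
    "X": [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]],  # 3 rows
    "y": [1.0, 2.0, 3.0],  # 3 rows
    "site_effects": [0.0, 1.0],  # 2 rows: cannot share a batch index with X and y
}

groups = group_by_leading_length(data)
print(groups)
```

Grouping by dim name instead of by length would avoid lumping together data that only happen to have the same length.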
|
c1168de to
8d1b479
pymc/model/transform/basic.py (outdated)
    minibatch_vars = Minibatch(*data_vars, batch_size=batch_size)
    replacements = {datum: minibatch_vars[i] for i, datum in enumerate(data_vars)}
    assert 0
    # Add total_size to all observed RVs
Should only add to those that depend on the minibatch data no?
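A rough sketch of that dependency check, in plain Python rather than on a PyTensor graph. The `parents` mapping and the variable names are made up for illustration, and `depends_on` is a hypothetical helper. Only observed variables with a minibatched data variable among their ancestors would receive `total_size`.

```python
def depends_on(node, targets, parents):
    """Return True if any of `targets` is an ancestor of `node`."""
    stack = list(parents.get(node, ()))
    seen = set()
    while stack:
        current = stack.pop()
        if current in targets:
            return True
        if current not in seen:
            seen.add(current)
            stack.extend(parents.get(current, ()))
    return False


# Toy graph: each key lists the variables it is computed from (illustrative names).
parents = {
    "mu": ["X", "beta"],
    "y_obs": ["mu", "sigma"],
    "z_obs": ["tau"],
}
minibatched = {"X"}

needs_total_size = [v for v in ("y_obs", "z_obs") if depends_on(v, minibatched, parents)]
print(needs_total_size)
```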
The correct thing would be a dim analysis like we do for MarginalModel to confirm the first dim of the data maps to the first dim of the observed RVs, which is when the rewrite is valid. We may not want to do that, but we should be clear about the assumptions in the docstrings.
An example where the minibatch rewrite will fail or do the wrong thing is if you transpose the data before using it in the observations.
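The transposition case can be sketched in plain Python. This assumes, as a simplification not confirmed in this thread, that an integer `total_size` rescales the log-likelihood by `total_size` divided by the length of axis 0 of the observed batch. `scale_factor` is a hypothetical helper, not a PyMC function.

```python
def scale_factor(total_size, observed_shape):
    """Rescaling for a minibatch log-likelihood, assuming axis 0 is the batch axis."""
    return total_size / observed_shape[0]


total_size, batch_size, n_features = 1000, 50, 4

# Data used as-is: the observed batch has shape (batch_size, n_features).
ok = scale_factor(total_size, (batch_size, n_features))

# Data transposed before use: the observed batch has shape (n_features, batch_size),
# so axis 0 is no longer the batch axis and the factor is silently wrong.
wrong = scale_factor(total_size, (n_features, batch_size))

print(ok, wrong)
```

Under this assumption the transposed model would not error; it would weight the likelihood by the wrong factor, which is why the docstring should state the first-axis requirement.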
pymc/model/transform/basic.py (outdated)
    replacements = {datum: minibatch_vars[i] for i, datum in enumerate(data_vars)}
    assert 0
    # Add total_size to all observed RVs
    total_size = data_vars[0].get_value().shape[0]
total size can be symbolic I think?
pymc/model/transform/basic.py (outdated)
    data_vars = [
        memo[datum].owner.inputs[0]
        for datum in (model.named_vars[datum_name] for datum_name in model.named_vars)
There's a model.data_vars. You should however allow users to specify which data vars to be minibatched (default to all is fine). Alternatively we could restrict this to models with dims, and the user has to tell us which dim is being minibatched?
That makes the graph analysis easier
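One way the dims-based restriction could look, sketched in plain Python. Everything here is hypothetical: `select_minibatch_vars` is not a PyMC function, and `data_dims` is a stand-in mapping from data variable names to their dim names. The user names the dim to minibatch, all data variables are considered by default, and a variable carrying that dim on a non-leading axis is rejected.

```python
def select_minibatch_vars(data_dims, minibatch_dim, names=None):
    """Pick data variables whose leading dim is `minibatch_dim`.

    names=None means "consider every data variable".
    """
    candidates = list(data_dims) if names is None else list(names)
    selected = []
    for name in candidates:
        dims = data_dims[name]
        if minibatch_dim not in dims:
            if names is not None:
                # The user asked for this variable explicitly, so fail loudly.
                raise ValueError(f"{name} does not have dim {minibatch_dim!r}")
            continue  # unrelated data is left alone by default
        if dims[0] != minibatch_dim:
            raise ValueError(f"{name} has {minibatch_dim!r} on a non-leading axis: {dims}")
        selected.append(name)
    return selected


# Stand-in for a model's data variables and their dims (illustrative names).
data_dims = {"X": ("obs", "feature"), "y": ("obs",), "prior_scale": ("feature",)}

selected = select_minibatch_vars(data_dims, "obs")
print(selected)
```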
|
Yep, I have reworked this code and need to push my changes!
…On Wed, 11 Jun 2025, 23:07 Ricardo Vieira, ***@***.***> wrote:
***@***.**** commented on this pull request.
------------------------------
In pymc/model/transform/basic.py
<#7785 (comment)>:
> @@ -62,6 +66,47 @@ def parse_vars(model: Model, vars: ModelVariable | Sequence[ModelVariable]) -> l
return [model[var] if isinstance(var, str) else var for var in vars_seq]
+def model_to_minibatch(model: Model, batch_size: int) -> Model:
+ """Replace all Data containers with pm.Minibatch, and add total_size to all observed RVs."""
+ from pymc.variational.minibatch_rv import create_minibatch_rv
+
+ fgraph, memo = fgraph_from_model(model, inlined_views=True)
+
+ # obs_rvs, data_vars = model.rvs_to_values.items()
+
+ data_vars = [
+ memo[datum].owner.inputs[0]
+ for datum in (model.named_vars[datum_name] for datum_name in model.named_vars)
|
8d1b479 to
9df25b9
Codecov Report
❌ Patch coverage is
Additional details and impacted files
@@ Coverage Diff @@
## main #7785 +/- ##
==========================================
- Coverage 91.49% 91.22% -0.27%
==========================================
Files 116 117 +1
Lines 18962 18999 +37
==========================================
- Hits 17349 17332 -17
- Misses 1613 1667 +54
|
pymc/data.py (outdated)
    # FIXME: __props__ should not be empty
    __props__ = ()
The underlying issue is that OpFromGraph doesn't have equality implemented: pymc-devs/pytensor#1606
You know this, but for future readers: these lines were added in this PR to let the assert_model_equality check pass on models with MinibatchRV, which is an OpFromGraph.
Yeah but we should remove this from the PR and test differently
sure
This reverts commit 4d18e37.
7e581f6 to
0fbc7d9
|
I pushed intermediate changes that I'll clean better, probably broken atm |
2fc116c to
a3bc54b
|
I pulled changes from pymc-devs/pymc-extras#211.
In general this rewrite is type-unsafe. If the variable you're trying to minibatch has static shape you can't apply the rewrite. Instead of doing a lame failure I added the functionality from that draft PR to rebuild the graph when types change.
This is a very useful functionality that together with
This can be used in the
I'll clean everything. One small thing still failing in the new code |
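The static-shape problem can be illustrated with a small plain-Python sketch. It assumes, as a simplification, that an in-place replacement is only type-safe when every statically known length of the old variable is preserved by the new one, with `None` standing for a length unknown until runtime. `can_replace_in_place` is a hypothetical helper and not the actual PyTensor type check.

```python
def can_replace_in_place(old_static_shape, new_static_shape):
    """Return True if the new variable preserves every known length of the old one.

    None stands for "length unknown until runtime".
    """
    if len(old_static_shape) != len(new_static_shape):
        return False
    return all(
        old is None or old == new
        for old, new in zip(old_static_shape, new_static_shape)
    )


# Data declared with a fixed length of 100 rows cannot be swapped for a 10-row batch.
fixed = can_replace_in_place((100, 3), (10, 3))

# Data with an unknown number of rows can accept a 10-row batch.
flexible = can_replace_in_place((None, 3), (10, 3))

print(fixed, flexible)
```

When the check fails, the alternative described above is to rebuild the downstream graph with the new types rather than replace in place.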
Description
A pain point for me when testing different algorithms (e.g. MCMC vs VI) is that I don't want to write a 2nd version of the model with pm.Minibatch on the data.
This PR adds a model transformation that does that for the user. It's the reverse of the remove_minibatched_nodes transformer that @zaxtax implemented recently.
This is a WIP. It doesn't actually work now, because I can't figure out how to rebuild the observed variable with the total_size set correctly. Help wanted.
Related Issue
Checklist
Type of change
📚 Documentation preview 📚: https://pymc--7785.org.readthedocs.build/en/7785/