Conversation

@jessegrabowski (Member) commented May 15, 2025

Description

A pain point for me when testing different algorithms (e.g. MCMC vs VI) is that I don't want to write a second version of the model with pm.Minibatch on the data.

This PR adds a model transformation that applies pm.Minibatch to the data for the user. It's the reverse of the remove_minibatched_nodes transformer that @zaxtax implemented recently.

This is a WIP; it doesn't actually work yet, because I can't figure out how to rebuild the observed variable with total_size set correctly. Help wanted.
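
For illustration, here is the duplication this transform is meant to avoid. minibatch_model below is a placeholder name for the proposed transform, not a final API:

import numpy as np
import pymc as pm

x_data = np.random.normal(size=(1000,))

# Full-data version, e.g. for MCMC
with pm.Model() as model:
    mu = pm.Normal("mu")
    pm.Normal("obs", mu=mu, sigma=1, observed=x_data)

# Today a second, hand-written version is needed for minibatch VI...
mb_x = pm.Minibatch(x_data, batch_size=100)
with pm.Model() as mb_model:
    mu = pm.Normal("mu")
    pm.Normal("obs", mu=mu, sigma=1, observed=mb_x, total_size=1000)

# ...whereas the transform would derive it automatically (placeholder):
# mb_model = minibatch_model(model, batch_size=100)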

Related Issue

  • Closes #
  • Related to #

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

📚 Documentation preview 📚: https://pymc--7785.org.readthedocs.build/en/7785/

@jessegrabowski requested a review from zaxtax on May 15, 2025 12:28
@ricardoV94 (Member)

> This is a WIP; it doesn't actually work yet, because I can't figure out how to rebuild the observed variable with total_size set correctly. Help wanted.

You can use the lower-level utility:

def create_minibatch_rv(

and then make that a vanilla observed RV.
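
A minimal sketch of that suggestion; the variables rv and total_size here are assumptions standing in for the PR's actual graph objects:

from pymc.variational.minibatch_rv import create_minibatch_rv

# `rv` is the raw random variable to be observed and `total_size` the
# length of the full data along the minibatched dimension; the returned
# variable has its logp rescaled by total_size / batch_size.
minibatched_rv = create_minibatch_rv(rv, total_size)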

@ricardoV94 (Member)

Ah, you already did that, so your question is how to get total_size? Grab the batch shape of the variable and constant-fold it, without raising if it can't be fully folded.
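
Something along these lines, presumably (a sketch using the constant_fold helper from pymc.pytensorf; rv stands in for the observed RV):

from pymc.pytensorf import constant_fold

# Fold the batch dim of the observed RV to a constant when possible;
# with raise_not_constant=False a symbolic value is returned instead
# of raising when the shape can't be fully folded.
[total_size] = constant_fold([rv.shape[0]], raise_not_constant=False)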

@jessegrabowski (Member, Author)

My real issue was not understanding which of these needs to be the key and which the value in the replacements:

  1. The model variable
  2. The memo variable
  3. The fgraph variable
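
For future readers, roughly how the three relate in the fgraph round-trip (a sketch; model is an existing pm.Model and the data variable name "x" is hypothetical):

from pymc.model.fgraph import fgraph_from_model, model_from_fgraph

fgraph, memo = fgraph_from_model(model)

# memo maps the original model variables (1) to their fgraph clones (3),
# so replacements are expressed between fgraph-side variables: the key
# is memo[model_var], not the model variable itself.
old_data = memo[model["x"]]
new_data = old_data[:100]  # stand-in for the Minibatch slice
fgraph.replace(old_data, new_data, import_missing=True)

new_model = model_from_fgraph(fgraph)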

@ricardoV94 (Member) commented May 15, 2025

The best approach is usually to replace the whole fgraph ModelObservedRV by a new one. You probably have to discard any dims on the batch dimension, which are an input to that op.

@jessegrabowski (Member, Author)

I don't really understand what that answer means

@ricardoV94 (Member)

dprint the fgraph and it will perhaps be more obvious what I am mumbling about

@jessegrabowski (Member, Author)

The problem I was running into was that I ended up with two beta RVs after doing the replace. Beta was the only RV implicated in the ModelObservedRV sub-graph.

@zaxtax (Contributor) commented May 15, 2025 via email

@zaxtax force-pushed the model-to-minibatch branch from c1168de to 8d1b479 on June 9, 2025 12:52
minibatch_vars = Minibatch(*data_vars, batch_size=batch_size)
replacements = {datum: minibatch_vars[i] for i, datum in enumerate(data_vars)}
assert 0
# Add total_size to all observed RVs
Member

Shouldn't this only add total_size to the RVs that depend on the minibatched data?

Member

The correct thing would be a dim analysis like we do for MarginalModel, to confirm that the first dim of the data maps to the first dim of the observed RVs, which is when the rewrite is valid. We may not want to do that, but we should be clear about the assumptions in the docstrings.

An example where the minibatch rewrite will fail / do the wrong thing is if you transpose the data before using it in the observations.
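
For instance (a hypothetical model showing the failure mode):

import numpy as np
import pymc as pm

with pm.Model() as model:
    data = pm.Data("data", np.zeros((5, 100)))
    # The observations run along the data's *second* dimension, so
    # minibatching the first dimension of `data` would slice features
    # rather than observations, and total_size would rescale the wrong dim.
    pm.Normal("obs", mu=0, sigma=1, observed=data.T)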

replacements = {datum: minibatch_vars[i] for i, datum in enumerate(data_vars)}
assert 0
# Add total_size to all observed RVs
total_size = data_vars[0].get_value().shape[0]
Member

total_size can be symbolic, I think?
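
e.g., instead of reading the shape eagerly via get_value(), a sketch of the symbolic alternative:

# Symbolic length of the full data along the minibatched dimension;
# this stays correct if the user later resizes the data with pm.set_data.
total_size = data_vars[0].shape[0]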


data_vars = [
memo[datum].owner.inputs[0]
for datum in (model.named_vars[datum_name] for datum_name in model.named_vars)
Member

There's a model.data_vars. You should, however, allow users to specify which data vars should be minibatched (defaulting to all is fine). Alternatively, we could restrict this to models with dims, and have the user tell us which dim is being minibatched? That would make the graph analysis easier, as sketched below.
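
A possible signature along those lines (names are illustrative, not a final API):

from collections.abc import Sequence

from pymc import Model

def minibatch_model(
    model: Model,
    batch_size: int,
    data_vars: Sequence[str] | None = None,  # None -> minibatch all data vars
) -> Model:
    """Return a copy of `model` with the given data variables minibatched."""
    ...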

@zaxtax (Contributor) commented Jun 11, 2025 via email

@codecov (bot) commented Nov 16, 2025

Codecov Report

❌ Patch coverage is 1.81818% with 54 lines in your changes missing coverage. Please review.
✅ Project coverage is 91.22%. Comparing base (869503b) to head (0fbc7d9).
⚠️ Report is 1 commit behind head on main.

Files with missing lines          | Patch % | Lines
pymc/model/transform/minibatch.py |   0.00% | 54 Missing ⚠️
Additional details and impacted files


@@            Coverage Diff             @@
##             main    #7785      +/-   ##
==========================================
- Coverage   91.49%   91.22%   -0.27%     
==========================================
  Files         116      117       +1     
  Lines       18962    18999      +37     
==========================================
- Hits        17349    17332      -17     
- Misses       1613     1667      +54     
Files with missing lines          | Coverage Δ
pymc/model/transform/basic.py     | 95.00% <100.00%> (-2.30%) ⬇️
pymc/model/transform/minibatch.py |  0.00% <0.00%> (ø)

pymc/data.py Outdated
Comment on lines 95 to 96
# FIXME: __props__ should not be empty
__props__ = ()
Member

The underlying issue is that OpFromGraph doesn't have equality implemented: pymc-devs/pytensor#1606
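
For reference, PyTensor Ops derive __eq__ and __hash__ from __props__, so an empty tuple makes any two instances of the same Op class compare equal. A toy illustration, not the MinibatchRV code:

from pytensor.graph.op import Op

class ScaleOp(Op):
    __props__ = ("factor",)  # equality/hash are derived from these attributes

    def __init__(self, factor):
        self.factor = factor

    def make_node(self, x): ...
    def perform(self, node, inputs, output_storage): ...

assert ScaleOp(2) == ScaleOp(2)  # same props -> equal
assert ScaleOp(2) != ScaleOp(3)
# With __props__ = (), all instances of the class would compare equal.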

Member Author

You know this, but for future readers: the reason these lines were added in this PR was to let the assert_model_equality check pass on models with MinibatchRV, which is an OpFromGraph.

Member

Yeah, but we should remove this from the PR and test differently.

Member Author

sure

@ricardoV94 (Member) commented Nov 17, 2025

I pushed intermediate changes that I'll clean up later; probably broken atm.

@ricardoV94 (Member)

I pulled changes from pymc-devs/pymc-extras#211

In general this rewrite is type-unsafe: if the variable you're trying to minibatch has a static shape, you can't apply the rewrite directly. Instead of failing outright, I added the functionality from that draft PR to rebuild the graph when types change. This is very useful functionality that, together with toposort_replace, we probably want to upstream to PyTensor later.
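
A rough sketch of that rebuild idea (assumed names; the actual helper lives in this PR):

from pytensor.graph.basic import Variable
from pytensor.graph.fg import FunctionGraph

def rebuild_with_replacements(
    fgraph: FunctionGraph, replacements: dict[Variable, Variable]
) -> list[Variable]:
    """Rebuild the graph in topological order, tolerating type changes."""
    for node in fgraph.toposort():
        new_inputs = [replacements.get(inp, inp) for inp in node.inputs]
        if new_inputs != node.inputs:
            # strict=False allows output types to change, e.g. a static
            # shape becoming unknown once an input is minibatched
            new_node = node.clone_with_new_inputs(new_inputs, strict=False)
            for old, new in zip(node.outputs, new_node.outputs):
                replacements[old] = new
    return [replacements.get(out, out) for out in fgraph.outputs]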

This can also be used in the pre-existing remove_minibatch transform, which was doing clone_replace. That rewrite wasn't complete: it didn't remove minibatch RVs, which it should do as well. To do that properly we need the new implementation, because we also need to substitute the non-minibatched variables into the minibatch RV graph.

I'll clean everything up. One small thing is still failing in the new code.
