Skip to content
1 change: 1 addition & 0 deletions sotodlib/mapmaking/ml_mapmaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -644,6 +644,7 @@ def add_obs(self, id, obs, nmat, Nd):
"""
# First scan our mask to find which samples need this
# treatment
ctime = obs.timestamps
if self.recenter:
rec = smutils.evaluate_recentering(self.recenter, ctime=ctime[len(ctime) // 2],
geom=(self.mask.shape, self.mask.wcs), site=smutils.unarr(obs.site))
Expand Down
55 changes: 47 additions & 8 deletions sotodlib/mapmaking/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import sys
from typing import Optional, Any, Union
from sqlalchemy import create_engine, exc
from sqlalchemy.orm import declarative_base, Mapped, mapped_column, sessionmaker
Expand Down Expand Up @@ -377,14 +378,19 @@ def expand_ids(obs_ids, context=None, bands=None):
sub_ids.append("%s:ws%d:%s" % (obs_id, si, band))
return sub_ids

def filter_subids(subids, wafers=None, bands=None):
def filter_subids(subids, wafers=None, bands=None, ots=None):
    """Keep only the sub-ids matching the requested wafers, bands and optics tubes.

    Sub-ids are ":"-separated strings of the form "obs_id:wafer:band";
    astr_tok is used to pull out the individual fields.

    Arguments:

    subids: Sequence of sub-id strings.

    wafers: Wafer names to keep, or None to skip wafer filtering.

    bands: Band names to keep, or None to skip band filtering.

    ots: Optics-tube names, or None to skip optics-tube filtering. A sub-id
        is kept when its obs_id field contains at least one of these names
        as a substring.

    Returns:

    Numpy array holding the sub-ids that passed every requested filter.
    """
    subids = np.asarray(subids)
    if wafers is not None:
        wafs = astr_tok(subids, ":", 1)
        subids = subids[np.isin(wafs, wafers)]
    if bands is not None:
        bpass = astr_tok(subids, ":", 2)
        subids = subids[np.isin(bpass, bands)]
    if ots is not None:
        obs_ids = astr_tok(subids, ":", 0)
        # np.char.find returns -1 for "not found" and the match index
        # otherwise, so test each tube separately and OR the results.
        # Multiplying the raw indices (as done previously) breaks for more
        # than one tube: two misses multiply to +1 and a hit combined with
        # a miss multiplies to a negative number.
        has_ot = np.zeros(np.shape(obs_ids), dtype=bool)
        for ot in ots:
            has_ot |= np.char.find(obs_ids, ot) >= 0
        subids = subids[has_ot]
    return subids

def astr_cat(*arrs):
Expand Down Expand Up @@ -471,7 +477,7 @@ def to_cel(lonlat, sys, ctime=None, site=None, weather=None):
if sys == "cel" or sys == "equ":
return lonlat
elif sys == "hor":
return so3g.proj.CelestialSightLine.az_el(ctime, lonlat[0], lonlat[1], site=site, weather=weather).coords()[0,:2]
return so3g.proj.CelestialSightLine.az_el(ctime, -1*lonlat[0], np.pi/2 - lonlat[1], site=site, weather=weather).coords()[0,:2]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The pi/2 - el still looks wrong, so run this by @amaurea or otherwise figure out what's up.

else:
raise NotImplementedError
def get_pos(name, ctime, sys=None):
Expand All @@ -481,10 +487,16 @@ def get_pos(name, ctime, sys=None):
elif name == "auto":
return np.array([0,0]) # would use geom here
else:
obj = getattr(ephem, name)()
djd = ctime/86400 + 40587.0 + 2400000.5 - 2415020
obj.compute(djd)
return np.array([obj.a_ra, obj.a_dec])
try:
planet = coords.planets.SlowSource.for_named_source(
name, ctime)
ra0, dec0 = planet.pos(tod.timestamps.mean())
except:
obj = getattr(ephem, name)()
djd = ctime/86400 + 40587.0 + 2400000.5 - 2415020
Copy link
Member

@iparask iparask Nov 24, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What are all these numbers? It would be better to define them as named variables or constants that we can import.

obj.compute(djd)
ra0, dec0 = obj.a_ra, obj.a_dec
return np.array([ra0, dec0])
else:
return to_cel(name, sys, ctime, site, weather)
p1 = get_pos(info["from"], ctime, info["from_sys"])
Expand Down Expand Up @@ -566,8 +578,11 @@ def rangemat_sum(rangemat):
res[i] = np.sum(ra[:,1]-ra[:,0])
return res

def find_usable_detectors(obs, maxcut=0.1, glitch_flags: str = "flags.glitch_flags"):
ncut = rangemat_sum(obs[glitch_flags])
def _has_nested_field(obs, name):
    """Return True if `name` is reachable through `_fields`, following dots.

    A literal match in the top-level `_fields` is accepted first; otherwise
    the name is split on "." and each part is looked up in the `_fields` of
    the container found so far.
    """
    top_fields = getattr(obs, "_fields", None)
    if top_fields is not None and name in top_fields:
        return True
    node = obs
    for part in name.split("."):
        fields = getattr(node, "_fields", None)
        if fields is None or part not in fields:
            return False
        node = fields[part]
    return True


def find_usable_detectors(obs, maxcut=0.1, glitch_flags: str = "flags.glitch_flags", to_null : str = "flags.expected_flags"):
    """Return the detectors whose cut-sample count is below the allowed fraction.

    Arguments:

    obs: Observation container, indexed with dotted field names and
        exposing `samps.count` and `dets.vals`.

    maxcut: A detector is kept when its number of cut samples is strictly
        less than `obs.samps.count * maxcut`.

    glitch_flags: Field name of the flags that count as cut samples.

    to_null: Field name of flags to remove from the glitch flags before
        counting. Pass "" to disable; it is also skipped when the field is
        not present in `obs`.

    Returns:

    The entries of `obs.dets.vals` for the detectors that pass the cut.
    """
    flag = obs[glitch_flags]
    # The default to_null is a dotted path, so a plain membership test on the
    # top-level `_fields` would never match it; walk the nested containers.
    # NOTE(review): assumes nested containers expose `_fields` the same way
    # the top-level one does -- confirm against the AxisManager in use.
    if to_null != "" and _has_nested_field(obs, to_null):
        flag = flag*~obs[to_null]
    ncut = rangemat_sum(flag)
    good = ncut < obs.samps.count * maxcut
    return obs.dets.vals[good]

Expand Down Expand Up @@ -800,3 +815,27 @@ def atomic_db_aux(atomic_db, info: list[AtomicInfo]):
session.commit()
except exc.IntegrityError:
session.rollback()


def prune_mpi(comm, ranks_to_keep):
    """
    Drop the MPI processes that are no longer needed.

    A new communicator is built from the group of `comm` restricted to
    `ranks_to_keep`. Processes whose rank in `comm` is not listed then
    terminate through sys.exit(0); the remaining ones get the new
    communicator back.

    Arguments:

    comm: The MPI communicator currently in use.

    ranks_to_keep: Ranks (in `comm`) that should survive in the new
        communicator.

    Returns:

    The communicator containing only the kept processes.
    """
    kept_group = comm.Get_group().Incl(ranks_to_keep)
    # Every process builds the communicator before any of them exits.
    # NOTE(review): communicator creation is normally collective over
    # `comm`, so this ordering looks deliberate -- confirm.
    pruned_comm = comm.Create(kept_group)
    if comm.rank in ranks_to_keep:
        return pruned_comm
    sys.exit(0)
31 changes: 26 additions & 5 deletions sotodlib/site_pipeline/make_ml_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ def get_parser(parser=None):
parser.add_argument("prefix", nargs="?")
parser.add_argument( "--comps", type=str, default="TQU",help="List of components to solve for. T, QU or TQU, but only TQU is consistent with the actual data")
parser.add_argument("-W", "--wafers", type=str, default=None, help="Detector wafer subsets to map with. ,-sep")
parser.add_argument("-O", "--ots", type=str, default=None, help="Optics tubes to map with. ,-sep")
parser.add_argument("-B", "--bands", type=str, default=None, help="Bandpasses to map. ,-sep")
parser.add_argument("-C", "--context", type=str, default="/mnt/so1/shared/todsims/pipe-s0001/v4/context.yaml")
parser.add_argument( "--tods", type=str, default=None, help="Arbitrary slice to apply to the list of tods to analyse")
Expand All @@ -34,6 +35,7 @@ def get_parser(parser=None):
parser.add_argument("-T", "--tiled" , type=int, default=1, help="0: untiled maps. Nonzero: tiled maps")
parser.add_argument( "--srcsamp", type=str, default=None, help="path to mask file where True regions indicate where bright object mitigation should be applied. Mask is in equatorial coordinates. Not tiled, so should be low-res to not waste memory.")
parser.add_argument( "--unit", type=str, default="uK", help="Unit of the maps")
parser.add_argument( "--maxcut", type=float, default=.3, help="Maximum fraction of cut samples in a detector.")
return parser

sens_limits = {"f030":120, "f040":80, "f090":100, "f150":140, "f220":300, "f280":750}
Expand Down Expand Up @@ -112,10 +114,11 @@ class DataMissing(Exception): pass
with bench.mark('context'):
context = Context(args.context)

ots = args.ots.split(",") if args.ots else None
wafers = args.wafers.split(",") if args.wafers else None
bands = args.bands .split(",") if args.bands else None
sub_ids = mapmaking.get_subids(args.query, context=context)
sub_ids = mapmaking.filter_subids(sub_ids, wafers=wafers, bands=bands)
sub_ids = mapmaking.filter_subids(sub_ids, wafers=wafers, bands=bands, ots=ots)

# restrict tod selection further. E.g. --tods [0], --tods[:1], --tods[::100], --tods[[0,1,5,10]], etc.
if args.tods:
Expand Down Expand Up @@ -144,6 +147,7 @@ class DataMissing(Exception): pass
sys.exit(1)

passes = mapmaking.setup_passes(downsample=args.downsample, maxiter=args.maxiter, interpol=args.interpol)
to_skip = []
for ipass, passinfo in enumerate(passes):
L.info("Starting pass %d/%d maxit %d down %d interp %s" % (ipass+1, len(passes), passinfo.maxiter, passinfo.downsample, passinfo.interpol))
pass_prefix = prefix + "pass%d_" % (ipass+1)
Expand Down Expand Up @@ -190,11 +194,12 @@ class DataMissing(Exception): pass
signal_map = mapmaking.SignalMap(shape, wcs, comm, comps=comps, dtype=dtype_map, recenter=recenter, tiled=args.tiled>0, interpol=args.interpol)
signals = [signal_cut, signal_map]
if args.srcsamp:
signal_srcsamp = mapmaking.SignalSrcsamp(comm, srcsamp_mask, dtype=dtype_tod)
signal_srcsamp = mapmaking.SignalSrcsamp(comm, srcsamp_mask, recenter=recenter, dtype=dtype_tod)
signals.append(signal_srcsamp)
mapmaker = mapmaking.MLMapmaker(signals, noise_model=noise_model, dtype=dtype_tod, verbose=verbose>0)

nkept = 0
to_skip_all = comm.allreduce(to_skip)
# TODO: Fix the task distribution. The current one doesn't care which mpi
# task gets which tods, which sabotages the pixel-saving effects of tiled maps!
# To be able to distribute the tods sensibly, we need a rough estimate of where
Expand All @@ -206,6 +211,10 @@ class DataMissing(Exception): pass
obs_id, wafer, band = sub_id.split(":")
name = sub_id.replace(":", "_")
L.debug("Processing %s" % sub_id)
if sub_id in to_skip_all:
L.debug("Skipped %s (Cut in previous pass)" % (sub_id))
continue

try:
meta = context.get_meta(sub_id)
# Optionally restrict to maximum number of detectors. This is mainly
Expand All @@ -223,12 +232,14 @@ class DataMissing(Exception): pass
if obs.dets.count < 50:
L.debug("Skipped %s (Not enough detectors)" % (sub_id))
L.debug("Datacount: %s full" % (sub_id))
to_skip += [sub_id]
continue
# Check nans
mask = np.logical_not(np.isfinite(obs.signal))
if mask.sum() > 0:
L.debug("Skipped %s (a nan in signal)" % (sub_id))
L.debug("Datacount: %s full" % (sub_id))
to_skip += [sub_id]
continue
# Check all 0s
zero_dets = np.sum(obs.signal, axis=1)
Expand Down Expand Up @@ -290,6 +301,7 @@ class DataMissing(Exception): pass
if np.logical_not(good).sum() / obs.dets.count > 0.5:
L.debug("Skipped %s (more than 50pc of detectors cut by sens)" % (sub_id))
L.debug("Datacount: %s full" % (sub_id))
to_skip += [sub_id]
continue
else:
obs.restrict("dets", good)
Expand All @@ -300,9 +312,10 @@ class DataMissing(Exception): pass
# Calibrate to K_cmb
#obs.signal = np.multiply(obs.signal.T, obs.abscal.abscal_cmb).T
# Disqualify overly cut detectors
good_dets = mapmaking.find_usable_detectors(obs, maxcut=0.3)
good_dets = mapmaking.find_usable_detectors(obs, args.maxcut)
obs.restrict("dets", good_dets)
if obs.dets.count == 0:
to_skip += [sub_id]
L.debug("Skipped %s (all dets cut)" % (sub_id))
L.debug("Datacount: %s full" % (sub_id))
continue
Expand Down Expand Up @@ -380,11 +393,19 @@ class DataMissing(Exception): pass
L.debug("Datacount: %s full" % (sub_id))
continue

nkept = comm.allreduce(nkept)
if nkept == 0:
nkept_all = np.array(comm.allgather(nkept))
if np.sum(nkept_all) == 0:
if comm.rank == 0:
L.info("All tods failed. Giving up")
sys.exit(1)

if np.any(nkept_all == 0):
if nkept == 0:
L.info("No tods assigned to this process. Pruning")
comm = mapmaking.prune_mpi(comm, np.where(nkept_all > 0)[0])
for signal in mapmaker.signals:
if hasattr(signal, "comm"):
signal.comm = comm

L.info("Done building")

Expand Down