33from typing import TYPE_CHECKING
44
55from openprotein .base import APISession
6- from openprotein .common import FeatureType , ModelMetadata , ReductionType
6+ from openprotein .common import (
7+ Feature ,
8+ FeatureType ,
9+ ModelMetadata ,
10+ Reduction ,
11+ ReductionType ,
12+ )
713from openprotein .data import AssayDataset , AssayMetadata , DataAPI
814from openprotein .errors import InvalidParameterError
915
@@ -199,9 +205,9 @@ def logits(
199205 def fit_svd (
200206 self ,
201207 sequences : list [bytes ] | list [str ] | None = None ,
202- assay : AssayDataset | None = None ,
208+ assay : AssayDataset | AssayMetadata | None = None ,
203209 n_components : int = 1024 ,
204- reduction : ReductionType | None = None ,
210+ reduction : Reduction | ReductionType | None = None ,
205211 ** kwargs ,
206212 ) -> "SVDModel" :
207213 """
@@ -236,6 +242,11 @@ def fit_svd(
236242 # local import for cyclic dep
237243 from openprotein .svd import SVDAPI
238244
245+ # runtime check on value
246+ if isinstance (reduction , str ):
247+ reduction = ReductionType (reduction )
248+ reduction = reduction .value
249+
239250 svd_api = getattr (self .session , "svd" , None )
240251 assert isinstance (svd_api , SVDAPI )
241252
@@ -246,9 +257,8 @@ def fit_svd(
246257 raise InvalidParameterError (
247258 "Expected either assay or sequences to fit SVD on!"
248259 )
249- model_id = self .id
250260 return svd_api .fit_svd (
251- model_id = model_id ,
261+ model = self ,
252262 sequences = sequences ,
253263 assay = assay ,
254264 n_components = n_components ,
@@ -259,9 +269,9 @@ def fit_svd(
259269 def fit_umap (
260270 self ,
261271 sequences : list [bytes ] | list [str ] | None = None ,
262- assay : AssayDataset | None = None ,
272+ assay : AssayDataset | AssayMetadata | None = None ,
263273 n_components : int = 2 ,
264- reduction : ReductionType | None = ReductionType . MEAN ,
274+ reduction : Reduction | ReductionType = " MEAN" ,
265275 ** kwargs ,
266276 ) -> "UMAPModel" :
267277 """
@@ -274,11 +284,11 @@ def fit_umap(
274284 ----------
275285 sequences : list of bytes or list of str or None, optional
276286 Optional sequences to fit UMAP with. Either use sequences or assay. Sequences is preferred.
277- assay : AssayDataset or None, optional
287+ assay : AssayDataset or AssayMetadata or None, optional
278288 Optional assay containing sequences to fit UMAP with. Either use sequences or assay. Ignored if sequences are provided.
279289 n_components : int, optional
280290 Number of components in UMAP fit. Determines output shapes. Default is 2.
281- reduction : ReductionType or None, optional
291+ reduction : Reduction or ReductionType or None, optional
282292 Embeddings reduction to use (e.g. mean). Defaults to MEAN.
283293 kwargs :
284294 Additional keyword arguments to be used from foundational models, e.g. prompt_id for PoET models.
@@ -296,6 +306,16 @@ def fit_umap(
296306 # local import for cyclic dep
297307 from openprotein .umap import UMAPAPI
298308
309+ if reduction is None :
310+ raise InvalidParameterError (
311+ "Expected reduction if using EmbeddingModel to fit UMAP"
312+ )
313+
314+ # runtime check on value
315+ if isinstance (reduction , str ):
316+ reduction = ReductionType (reduction )
317+ reduction = reduction .value
318+
299319 umap_api = getattr (self .session , "umap" , None )
300320 assert isinstance (umap_api , UMAPAPI )
301321
@@ -306,20 +326,26 @@ def fit_umap(
306326 raise InvalidParameterError (
307327 "Expected either assay or sequences to fit UMAP on!"
308328 )
329+ # get assay_id
330+ assay_id = (
331+ assay .assay_id
332+ if isinstance (assay , AssayMetadata )
333+ else assay .id if isinstance (assay , AssayDataset ) else assay
334+ )
309335 model_id = self .id
310336 return umap_api .fit_umap (
311337 model_id = model_id ,
312338 feature_type = FeatureType .PLM ,
313339 sequences = sequences ,
314- assay_id = assay . id if assay is not None else None ,
340+ assay_id = assay_id ,
315341 n_components = n_components ,
316342 reduction = reduction ,
317343 ** kwargs ,
318344 )
319345
320346 def fit_gp (
321347 self ,
322- assay : AssayMetadata | AssayDataset | str ,
348+ assay : AssayDataset | AssayMetadata | str ,
323349 properties : list [str ],
324350 reduction : ReductionType ,
325351 name : str | None = None ,
@@ -358,26 +384,9 @@ def fit_gp(
358384 # local import to resolve cyclic
359385 from openprotein .predictor import PredictorAPI
360386
361- data_api = getattr (self .session , "data" , None )
362- assert isinstance (data_api , DataAPI )
363387 predictor_api = getattr (self .session , "predictor" , None )
364388 assert isinstance (predictor_api , PredictorAPI )
365389
366- # get assay if str
367- assay = data_api .get (assay_id = assay ) if isinstance (assay , str ) else assay
368- # extract assay_id
369- if len (properties ) == 0 :
370- raise InvalidParameterError ("Expected (at-least) 1 property to train" )
371- if not set (properties ) <= set (assay .measurement_names ):
372- raise InvalidParameterError (
373- f"Expected all provided properties to be a subset of assay's measurements: { assay .measurement_names } "
374- )
375- # TODO - support multitask
376- if len (properties ) > 1 :
377- raise InvalidParameterError (
378- "Training a multitask GP is not yet supported (i.e. number of properties should only be 1 for now)"
379- )
380-
381390 # inject into predictor api
382391 return predictor_api .fit_gp (
383392 assay = assay ,
0 commit comments