|
2 | 2 | import random |
3 | 3 | from typing import BinaryIO, Iterator |
4 | 4 |
|
5 | | -import openprotein.config as config |
6 | 5 | import requests |
| 6 | + |
| 7 | +import openprotein.config as config |
7 | 8 | from openprotein.base import APISession |
8 | 9 | from openprotein.csv import csv_stream |
9 | 10 | from openprotein.errors import APIError, InvalidParameterError, MissingParameterError |
@@ -81,34 +82,6 @@ def get_input( |
81 | 82 | return csv_stream(response) |
82 | 83 |
|
83 | 84 |
|
def get_prompt(
    session: APISession, job: Job, prompt_index: int | None = None
) -> Iterator[list[str]]:
    """
    Retrieve the prompt for a given job.

    Parameters
    ----------
    session : APISession
        The API session.
    job : Job
        The job whose prompt should be fetched.
    prompt_index : int | None, default=None
        Index of the prompt to fetch. If None, all prompts are returned.

    Returns
    -------
    Iterator[list[str]]
        A CSV reader for the prompt data.
    """
    # Thin wrapper over the generic input fetcher, selecting the PROMPT type.
    return get_input(
        session=session,
        job=job,
        input_type=AlignType.PROMPT,
        prompt_index=prompt_index,
    )
110 | | - |
111 | | - |
112 | 85 | def get_seed(session: APISession, job: Job) -> Iterator[list[str]]: |
113 | 86 | """ |
114 | 87 | Get the seed for a given MSA job. |
@@ -197,6 +170,135 @@ def msa_post( |
197 | 170 | return Job.model_validate(response.json()) |
198 | 171 |
|
199 | 172 |
|
| 173 | +# TODO - document the `ep` and `op` parameters |
def mafft_post(
    session: APISession,
    sequence_file: BinaryIO,
    auto: bool = True,
    ep: float | None = None,
    op: float | None = None,
) -> Job:
    """
    Align sequences using the `mafft` algorithm. Sequences can be provided as `fasta` or `csv` formats. If `csv`, the file must be headerless with either a single sequence column or name, sequence columns.

    Set auto to True to automatically attempt the best params. Leave a parameter as None to use system defaults.

    Parameters
    ----------
    session : APISession
    sequence_file : BinaryIO
        Sequences to align in fasta or csv format.
    auto : bool = True, optional
        Set to true to automatically set algorithm parameters.
    ep : float, optional
        MAFFT offset value (``--ep``), which works like a gap extension
        penalty. None uses the system default.
        (Semantics per the MAFFT manual — confirm against the align service.)
    op : float, optional
        MAFFT gap opening penalty (``--op``). None uses the system default.
        (Semantics per the MAFFT manual — confirm against the align service.)

    Returns
    -------
    Job
        Job details.
    """
    endpoint = "v1/align/mafft"

    files = {"file": sequence_file}
    # Only forward the optional algorithm parameters the caller actually set,
    # so the service can apply its own defaults otherwise.
    params: dict = {"auto": auto}
    if ep is not None:
        params["ep"] = ep
    if op is not None:
        params["op"] = op

    response = session.post(endpoint, files=files, params=params)
    return Job.model_validate(response.json())
| 214 | + |
| 215 | + |
| 216 | +# TODO - document the `clustersize` and `iterations` parameters |
def clustalo_post(
    session: APISession,
    sequence_file: BinaryIO,
    clustersize: int | None = None,
    iterations: int | None = None,
) -> Job:
    """
    Align sequences using the `clustal omega` algorithm. Sequences can be provided as `fasta` or `csv` formats. If `csv`, the file must be headerless with either a single sequence column or name, sequence columns.

    Leave a parameter as None to use system defaults.

    Parameters
    ----------
    session : APISession
    sequence_file : BinaryIO
        Sequences to align in fasta or csv format.
    clustersize : int, optional
        Soft maximum cluster size for Clustal Omega's mBed-like clustering
        (``--cluster-size``). None uses the system default.
        (Semantics per the Clustal Omega docs — confirm against the align service.)
    iterations : int, optional
        Number of combined guide-tree/HMM iterations (``--iterations``).
        None uses the system default.
        (Semantics per the Clustal Omega docs — confirm against the align service.)

    Returns
    -------
    Job
        Job details.
    """
    endpoint = "v1/align/clustalo"

    files = {"file": sequence_file}
    # Only forward the optional algorithm parameters the caller actually set,
    # so the service can apply its own defaults otherwise.
    params: dict = {}
    if clustersize is not None:
        params["clustersize"] = clustersize
    if iterations is not None:
        params["iterations"] = iterations

    response = session.post(endpoint, files=files, params=params)
    return Job.model_validate(response.json())
| 254 | + |
| 255 | + |
def abnumber_post(
    session: APISession,
    sequence_file: BinaryIO,
    scheme: str = "imgt",
) -> Job:
    """
    Align antibody using `AbNumber`. Sequences can be provided as `fasta` or `csv` formats. If `csv`, the file must be headerless with either a single sequence column or name, sequence columns.

    The antibody numbering scheme can be specified from `imgt` (default), `chothia`, `kabat`, or `aho`.

    Parameters
    ----------
    session : APISession
    sequence_file : BinaryIO
        Sequences to align in fasta or csv format.
    scheme : str = 'imgt'
        Antibody numbering scheme. Can be one of 'imgt', 'chothia', 'kabat', or 'aho'

    Returns
    -------
    Job
        Job details.

    Raises
    ------
    InvalidParameterError
        If `scheme` is not one of the recognized numbering schemes.
    """
    endpoint = "v1/align/abnumber"

    valid_schemes = ["imgt", "chothia", "kabat", "aho"]
    if scheme not in valid_schemes:
        # Use the module's error type (imported at top of file) rather than a
        # bare Exception, so callers can catch it specifically.
        raise InvalidParameterError(
            f"Antibody numbering {scheme} not recognized. Must be one of {valid_schemes}."
        )

    files = {"file": sequence_file}
    params = {"scheme": scheme}

    response = session.post(endpoint, files=files, params=params)
    return Job.model_validate(response.json())
| 292 | + |
| 293 | + |
| 294 | +# TODO - implement support for getting the antibody numbering from an `AbNumber` job |
| 295 | +def antibody_schema_get(session: APISession, job_id: str): |
| 296 | + """ |
| 297 | + Return the antibody numbering for an `AbNumber` job. |
| 298 | + """ |
| 299 | + raise NotImplementedError() |
| 300 | + |
| 301 | + |
200 | 302 | def prompt_post( |
201 | 303 | session: APISession, |
202 | 304 | msa_id: str, |
@@ -296,126 +398,3 @@ def prompt_post( |
296 | 398 |
|
297 | 399 | response = session.post(endpoint, params=params) |
298 | 400 | return Job.model_validate(response.json()) |
299 | | - |
300 | | - |
def upload_prompt_post(
    session: APISession,
    prompt_file: BinaryIO,
):
    """
    Directly upload a prompt.

    Bypass post_msa and prompt_post steps entirely. In this case PoET will use the prompt as is.
    You can specify multiple prompts (one per replicate) with an `<END_PROMPT>\n` between CSVs.

    Parameters
    ----------
    session : APISession
        An instance of APISession to manage interactions with the API.
    prompt_file : BinaryIO
        Binary I/O object representing the prompt file.

    Raises
    ------
    APIError
        If there is an issue with the API request.

    Returns
    -------
    Job
        An object representing the status and results of the prompt job.
    """
    endpoint = "v1/align/upload_prompt"

    try:
        response = session.post(endpoint, files={"prompt_file": prompt_file})
        return Job.model_validate(response.json())
    except Exception as exc:
        # Surface any upload/parse failure as the module's APIError.
        raise APIError(f"Failed to upload prompt post: {exc}") from exc
336 | | - |
337 | | - |
def poet_score_post(
    session: APISession, prompt_id: str, queries: list[bytes | str]
) -> Job:
    """
    Submits a job to score sequences based on the given prompt.

    Parameters
    ----------
    session : APISession
        An instance of APISession to manage interactions with the API.
    prompt_id : str
        The ID of the prompt.
    queries : list[bytes | str]
        Query sequences to be scored.

    Raises
    ------
    MissingParameterError
        If no queries or no prompt_id are provided.
    APIError
        If there is an issue with the API request.

    Returns
    -------
    Job
        An object representing the status and results of the scoring job.
    """
    endpoint = "v1/poet/score"

    # Validate inputs up front so bad requests never reach the API.
    if len(queries) == 0:
        raise MissingParameterError("Must include queries for scoring!")
    if not prompt_id:
        raise MissingParameterError("Must include prompt_id in request!")

    # Normalize all queries to bytes before building the upload payload.
    encoded = [q.encode() if isinstance(q, str) else q for q in queries]
    try:
        response = session.post(
            endpoint,
            files={"variant_file": io.BytesIO(b"\n".join(encoded))},
            params={"prompt_id": prompt_id},
        )
        return Job.model_validate(response.json())
    except Exception as exc:
        raise APIError(f"Failed to post poet score: {exc}") from exc
380 | | - |
381 | | - |
def poet_score_get(
    session: APISession, job_id, page_size=config.POET_PAGE_SIZE, page_offset=0
) -> Job:
    """
    Fetch a page of results from a PoET score job.

    Parameters
    ----------
    session : APISession
        An instance of APISession to manage interactions with the API.
    job_id : str
        The ID of the PoET scoring job to fetch results from.
    page_size : int, optional
        The number of results to fetch in a single page. Defaults to config.POET_PAGE_SIZE.
    page_offset : int, optional
        The offset (number of results) to start fetching results from. Defaults to 0.

    Raises
    ------
    APIError
        If the provided page size is larger than the maximum allowed page size.

    Returns
    -------
    Job
        An object representing the PoET scoring job, including its current status and results (if any).
    """
    endpoint = "v1/poet/score"

    # Reject oversized pages locally instead of letting the API fail.
    if page_size > config.POET_MAX_PAGE_SIZE:
        raise APIError(
            f"Page size must be less than the max for PoET: {config.POET_MAX_PAGE_SIZE}"
        )

    params = {"job_id": job_id, "page_size": page_size, "page_offset": page_offset}
    response = session.get(endpoint, params=params)
    return Job.model_validate(response.json())
0 commit comments