diff --git a/docs/concepts/parallel-task.rst b/docs/concepts/parallel-task.rst index b3898b0e9..a85cff6c2 100644 --- a/docs/concepts/parallel-task.rst +++ b/docs/concepts/parallel-task.rst @@ -266,3 +266,10 @@ to set what should be determined to be collected at DAG construction time: # Then in the driver building pass in the configuration: .with_config(_config) + +Parallelizable Subclassing +========================== + +When annotating a function with `Parallelizable`, it is not possible to specify in the annotation what the type returned by the function will actually be, and these are not identified by a linter or other tools as static type checking. Especially for functions that can be used with or without Hamilton, this can be a problem. + +To solve this problem, it is possible to create subclasses of the `Parallelizable` classes. The ["Parallelizable Subclass" example](https://github.com/dagworks-inc/hamilton/blob/main/examples/parallelism/parallelizable_subclass) showcases how to do that. diff --git a/examples/parallelism/parallelizable_subclass/README.md b/examples/parallelism/parallelizable_subclass/README.md new file mode 100644 index 000000000..69ffa9ae3 --- /dev/null +++ b/examples/parallelism/parallelizable_subclass/README.md @@ -0,0 +1,11 @@ +# Parallelizable Subclass + +## Overview + +When annotating a function with `Parallelizable`, it is not possible to specify in the annotation what the type returned by the function will actually be, and these are not identified by a linter or other tools as static type checking. Especially for functions that can be used with or without Hamilton, this can be a problem. + +To solve this problem, it is possible to create subclasses of the `Parallelizable` classes, as demonstrated in this example. + +## Running + +The `notebook.ipynb` exemplifies how to use a `Parallelizable` subclass. diff --git a/examples/parallelism/parallelizable_subclass/functions.py b/examples/parallelism/parallelizable_subclass/functions.py new file mode 100644 index 000000000..a3cae5992 --- /dev/null +++ b/examples/parallelism/parallelizable_subclass/functions.py @@ -0,0 +1,15 @@ +from parallelizable_list import ParallelizableList + +from hamilton.htypes import Collect + + +def hello_list() -> ParallelizableList[str]: + return ["h", "e", "l", "l", "o", " ", "l", "i", "s", "t"] + + +def uppercase(hello_list: str) -> str: + return hello_list.upper() + + +def hello_uppercase(uppercase: Collect[str]) -> str: + return "".join(uppercase) diff --git a/examples/parallelism/parallelizable_subclass/notebook.ipynb b/examples/parallelism/parallelizable_subclass/notebook.ipynb new file mode 100644 index 000000000..b070a38f6 --- /dev/null +++ b/examples/parallelism/parallelizable_subclass/notebook.ipynb @@ -0,0 +1,299 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "#Install Hamilton if not avaiable\n", + "\n", + "try:\n", + " import hamilton\n", + "except ModuleNotFoundError:\n", + " %pip install sf-hamilton" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Parallelism: Paralellizable Subclass [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/dagworks-inc/hamilton/blob/main/examples/parallelism/parallelizable_subclass/notebook.ipynb) [![GitHub badge](https://img.shields.io/badge/github-view_source-2b3137?logo=github)](https://github.com/dagworks-inc/hamilton/blob/main/examples/parallelism/parallelizable_subclass/notebook.ipynb)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "When annotating a function with `Parallelizable`, it is not possible to specify in the annotation what the type returned by the function will actually be, and these are not identified by a linter or other tools as static type checking. Especially for functions that can be used with or without Hamilton, this can be a problem.\n", + "\n", + "To solve this problem, it is possible to create subclasses of the `Parallelizable` classes, as demonstrated in this example." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We start by importing Hamilton and the created example functions:" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "from hamilton import driver\n", + "\n", + "import functions" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Creating a driver and displaing all the module functions, we can see the `hello_list` function, that returns a `ParallelizableList`. This is a example `Parallelizable` subclass created for annotate functions that returns `list`. Is important to note that all `Parallelizable` subclasses must return a `Iterable` subclass, as for example list.\n", + "\n", + "The `ParallelizableList` implementation can be found in the [\"parallelizable_list.py\" file](https://github.com/dagworks-inc/hamilton/blob/main/examples/parallelism/parallelizable_subclass/parallelizable_list.py)." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "cluster__legend\n", + "\n", + "Legend\n", + "\n", + "\n", + "\n", + "hello_uppercase\n", + "\n", + "\n", + "hello_uppercase\n", + "str\n", + "\n", + "\n", + "\n", + "uppercase\n", + "\n", + "uppercase\n", + "str\n", + "\n", + "\n", + "\n", + "uppercase->hello_uppercase\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "hello_list\n", + "\n", + "\n", + "hello_list\n", + "ParallelizableList\n", + "\n", + "\n", + "\n", + "hello_list->uppercase\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "function\n", + "\n", + "function\n", + "\n", + "\n", + "\n", + "expand\n", + "\n", + "\n", + "expand\n", + "\n", + "\n", + "\n", + "collect\n", + "\n", + "\n", + "collect\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dr = (\n", + " driver.Builder()\n", + " .with_modules(functions)\n", + " .enable_dynamic_execution(allow_experimental_mode=True)\n", + " .build()\n", + " )\n", + "\n", + "dr.display_all_functions()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In this simple example, the created flow generates a list with \"hello list\" letters, converts each letter to uppercase in parallel, and then joins the letters together:" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'hello_uppercase': 'HELLO LIST'}" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dr.execute([\"hello_uppercase\"])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Focusing attention on the function that was annotated with ParallelizableList, running it manually we can see that it actually returns a list:" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['h', 'e', 'l', 'l', 'o', ' ', 'l', 'i', 's', 't']" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "functions.hello_list()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Checking the annotation, we can see the return annotation as \"ParallelizableList[str]\":" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'return': parallelizable_list.ParallelizableList[str]}" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "functions.hello_list.__annotations__" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "And, the key point of using subtypes of `Parallelizable`, it is considered a list instance:" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "issubclass(functions.hello_list.__annotations__[\"return\"], list)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This means that when using a linter or static type checking, it will correctly identify the return type as a list instance." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.8" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/examples/parallelism/parallelizable_subclass/parallelizable_list.py b/examples/parallelism/parallelizable_subclass/parallelizable_list.py new file mode 100644 index 000000000..21d9f9d0c --- /dev/null +++ b/examples/parallelism/parallelizable_subclass/parallelizable_list.py @@ -0,0 +1,16 @@ +from typing import Generic, List + +from hamilton.htypes import Parallelizable, ParallelizableElement + + +class ParallelizableList( + List[ParallelizableElement], Parallelizable, Generic[ParallelizableElement] +): + """ + Marks the output of a function node as parallelizable and also as a list. + + It has the same usage as "Parallelizable", but for returns that are specifically + lists, for correct functioning of linters and other tools. + """ + + pass diff --git a/examples/validate_examples.py b/examples/validate_examples.py index d85dcbbf7..e7eabfa0e 100644 --- a/examples/validate_examples.py +++ b/examples/validate_examples.py @@ -15,13 +15,13 @@ def _create_github_badge(path: pathlib.Path) -> str: - github_url = f"https://github.com/dagworks-inc/hamilton/blob/main/{path}" + github_url = f"https://github.com/dagworks-inc/hamilton/blob/main/{path.as_posix()}" github_badge = f"[![GitHub badge](https://img.shields.io/badge/github-view_source-2b3137?logo=github)]({github_url})" return github_badge def _create_colab_badge(path: pathlib.Path) -> str: - colab_url = f"https://colab.research.google.com/github/dagworks-inc/hamilton/blob/main/{path}" + colab_url = f"https://colab.research.google.com/github/dagworks-inc/hamilton/blob/main/{path.as_posix()}" colab_badge = ( f"[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)]({colab_url})" ) diff --git a/hamilton/htypes.py b/hamilton/htypes.py index c2ac13314..623cdea18 100644 --- a/hamilton/htypes.py +++ b/hamilton/htypes.py @@ -1,7 +1,8 @@ import inspect import sys import typing -from typing import Any, Iterable, Optional, Protocol, Tuple, Type, TypeVar, Union +from typing import ( + Any, Iterable, Optional, Protocol, Tuple, Type, TypeVar, Union) import typing_inspect @@ -68,7 +69,7 @@ def custom_subclass_check(requested_type: Type, param_type: Type): has_generic = True # TODO -- consider moving into a graph adapter or elsewhere -- this is perhaps a little too # low-level - if has_generic and requested_origin_type in (Parallelizable,): + if has_generic and is_parallelizable(requested_origin_type): (requested_type_arg,) = _get_args(requested_type) return custom_subclass_check(requested_type_arg, param_type) if has_generic and param_origin_type == Collect: @@ -298,6 +299,19 @@ class Parallelizable(Iterable[ParallelizableElement], Protocol[ParallelizableEle pass +def is_parallelizable(type: Type) -> bool: + """ + Checks if a type is parallelizable. + + :param type: Type to check. + :return: True if the type is parallelizable, False otherwise. + """ + if type is None: + return False + + return type == Parallelizable or Parallelizable in type.__bases__ + + def is_parallelizable_type(type_: Type) -> bool: return _get_origin(type_) == Parallelizable diff --git a/hamilton/node.py b/hamilton/node.py index 5755b0ed7..bc16ca8b7 100644 --- a/hamilton/node.py +++ b/hamilton/node.py @@ -6,7 +6,7 @@ import typing_inspect -from hamilton.htypes import Collect, Parallelizable +from hamilton.htypes import Collect, is_parallelizable """ Module that contains the primitive components of the graph. @@ -285,7 +285,7 @@ def from_fn(fn: Callable, name: str = None) -> "Node": node_source = NodeType.STANDARD # TODO - extract this into a function + clean up! if typing_inspect.is_generic_type(return_type): - if typing_inspect.get_origin(return_type) == Parallelizable: + if is_parallelizable(typing_inspect.get_origin(return_type)): node_source = NodeType.EXPAND for hint in typing.get_type_hints(fn, **type_hint_kwargs).values(): if typing_inspect.is_generic_type(hint):