diff --git a/docs/concepts/parallel-task.rst b/docs/concepts/parallel-task.rst
index b3898b0e9..a85cff6c2 100644
--- a/docs/concepts/parallel-task.rst
+++ b/docs/concepts/parallel-task.rst
@@ -266,3 +266,10 @@ to set what should be determined to be collected at DAG construction time:
# Then in the driver building pass in the configuration:
.with_config(_config)
+
+Parallelizable Subclassing
+==========================
+
+When annotating a function with `Parallelizable`, it is not possible to specify in the annotation what the type returned by the function will actually be, and these are not identified by a linter or other tools as static type checking. Especially for functions that can be used with or without Hamilton, this can be a problem.
+
+To solve this problem, it is possible to create subclasses of the `Parallelizable` classes. The ["Parallelizable Subclass" example](https://github.com/dagworks-inc/hamilton/blob/main/examples/parallelism/parallelizable_subclass) showcases how to do that.
diff --git a/examples/parallelism/parallelizable_subclass/README.md b/examples/parallelism/parallelizable_subclass/README.md
new file mode 100644
index 000000000..69ffa9ae3
--- /dev/null
+++ b/examples/parallelism/parallelizable_subclass/README.md
@@ -0,0 +1,11 @@
+# Parallelizable Subclass
+
+## Overview
+
+When annotating a function with `Parallelizable`, it is not possible to specify in the annotation what the type returned by the function will actually be, and these are not identified by a linter or other tools as static type checking. Especially for functions that can be used with or without Hamilton, this can be a problem.
+
+To solve this problem, it is possible to create subclasses of the `Parallelizable` classes, as demonstrated in this example.
+
+## Running
+
+The `notebook.ipynb` exemplifies how to use a `Parallelizable` subclass.
diff --git a/examples/parallelism/parallelizable_subclass/functions.py b/examples/parallelism/parallelizable_subclass/functions.py
new file mode 100644
index 000000000..a3cae5992
--- /dev/null
+++ b/examples/parallelism/parallelizable_subclass/functions.py
@@ -0,0 +1,15 @@
+from parallelizable_list import ParallelizableList
+
+from hamilton.htypes import Collect
+
+
+def hello_list() -> ParallelizableList[str]:
+ return ["h", "e", "l", "l", "o", " ", "l", "i", "s", "t"]
+
+
+def uppercase(hello_list: str) -> str:
+ return hello_list.upper()
+
+
+def hello_uppercase(uppercase: Collect[str]) -> str:
+ return "".join(uppercase)
diff --git a/examples/parallelism/parallelizable_subclass/notebook.ipynb b/examples/parallelism/parallelizable_subclass/notebook.ipynb
new file mode 100644
index 000000000..b070a38f6
--- /dev/null
+++ b/examples/parallelism/parallelizable_subclass/notebook.ipynb
@@ -0,0 +1,299 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "#Install Hamilton if not avaiable\n",
+ "\n",
+ "try:\n",
+ " import hamilton\n",
+ "except ModuleNotFoundError:\n",
+ " %pip install sf-hamilton"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Parallelism: Paralellizable Subclass [](https://colab.research.google.com/github/dagworks-inc/hamilton/blob/main/examples/parallelism/parallelizable_subclass/notebook.ipynb) [](https://github.com/dagworks-inc/hamilton/blob/main/examples/parallelism/parallelizable_subclass/notebook.ipynb)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "When annotating a function with `Parallelizable`, it is not possible to specify in the annotation what the type returned by the function will actually be, and these are not identified by a linter or other tools as static type checking. Especially for functions that can be used with or without Hamilton, this can be a problem.\n",
+ "\n",
+ "To solve this problem, it is possible to create subclasses of the `Parallelizable` classes, as demonstrated in this example."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We start by importing Hamilton and the created example functions:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from hamilton import driver\n",
+ "\n",
+ "import functions"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Creating a driver and displaing all the module functions, we can see the `hello_list` function, that returns a `ParallelizableList`. This is a example `Parallelizable` subclass created for annotate functions that returns `list`. Is important to note that all `Parallelizable` subclasses must return a `Iterable` subclass, as for example list.\n",
+ "\n",
+ "The `ParallelizableList` implementation can be found in the [\"parallelizable_list.py\" file](https://github.com/dagworks-inc/hamilton/blob/main/examples/parallelism/parallelizable_subclass/parallelizable_list.py)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "image/svg+xml": [
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 3,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "dr = (\n",
+ " driver.Builder()\n",
+ " .with_modules(functions)\n",
+ " .enable_dynamic_execution(allow_experimental_mode=True)\n",
+ " .build()\n",
+ " )\n",
+ "\n",
+ "dr.display_all_functions()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "In this simple example, the created flow generates a list with \"hello list\" letters, converts each letter to uppercase in parallel, and then joins the letters together:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "{'hello_uppercase': 'HELLO LIST'}"
+ ]
+ },
+ "execution_count": 4,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "dr.execute([\"hello_uppercase\"])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Focusing attention on the function that was annotated with ParallelizableList, running it manually we can see that it actually returns a list:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "['h', 'e', 'l', 'l', 'o', ' ', 'l', 'i', 's', 't']"
+ ]
+ },
+ "execution_count": 5,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "functions.hello_list()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Checking the annotation, we can see the return annotation as \"ParallelizableList[str]\":"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "{'return': parallelizable_list.ParallelizableList[str]}"
+ ]
+ },
+ "execution_count": 6,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "functions.hello_list.__annotations__"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "And, the key point of using subtypes of `Parallelizable`, it is considered a list instance:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "True"
+ ]
+ },
+ "execution_count": 7,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "issubclass(functions.hello_list.__annotations__[\"return\"], list)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "This means that when using a linter or static type checking, it will correctly identify the return type as a list instance."
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "venv",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.8"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/examples/parallelism/parallelizable_subclass/parallelizable_list.py b/examples/parallelism/parallelizable_subclass/parallelizable_list.py
new file mode 100644
index 000000000..21d9f9d0c
--- /dev/null
+++ b/examples/parallelism/parallelizable_subclass/parallelizable_list.py
@@ -0,0 +1,16 @@
+from typing import Generic, List
+
+from hamilton.htypes import Parallelizable, ParallelizableElement
+
+
+class ParallelizableList(
+ List[ParallelizableElement], Parallelizable, Generic[ParallelizableElement]
+):
+ """
+ Marks the output of a function node as parallelizable and also as a list.
+
+ It has the same usage as "Parallelizable", but for returns that are specifically
+ lists, for correct functioning of linters and other tools.
+ """
+
+ pass
diff --git a/examples/validate_examples.py b/examples/validate_examples.py
index d85dcbbf7..e7eabfa0e 100644
--- a/examples/validate_examples.py
+++ b/examples/validate_examples.py
@@ -15,13 +15,13 @@
def _create_github_badge(path: pathlib.Path) -> str:
- github_url = f"https://github.com/dagworks-inc/hamilton/blob/main/{path}"
+ github_url = f"https://github.com/dagworks-inc/hamilton/blob/main/{path.as_posix()}"
github_badge = f"[]({github_url})"
return github_badge
def _create_colab_badge(path: pathlib.Path) -> str:
- colab_url = f"https://colab.research.google.com/github/dagworks-inc/hamilton/blob/main/{path}"
+ colab_url = f"https://colab.research.google.com/github/dagworks-inc/hamilton/blob/main/{path.as_posix()}"
colab_badge = (
f"[]({colab_url})"
)
diff --git a/hamilton/htypes.py b/hamilton/htypes.py
index c2ac13314..623cdea18 100644
--- a/hamilton/htypes.py
+++ b/hamilton/htypes.py
@@ -1,7 +1,8 @@
import inspect
import sys
import typing
-from typing import Any, Iterable, Optional, Protocol, Tuple, Type, TypeVar, Union
+from typing import (
+ Any, Iterable, Optional, Protocol, Tuple, Type, TypeVar, Union)
import typing_inspect
@@ -68,7 +69,7 @@ def custom_subclass_check(requested_type: Type, param_type: Type):
has_generic = True
# TODO -- consider moving into a graph adapter or elsewhere -- this is perhaps a little too
# low-level
- if has_generic and requested_origin_type in (Parallelizable,):
+ if has_generic and is_parallelizable(requested_origin_type):
(requested_type_arg,) = _get_args(requested_type)
return custom_subclass_check(requested_type_arg, param_type)
if has_generic and param_origin_type == Collect:
@@ -298,6 +299,19 @@ class Parallelizable(Iterable[ParallelizableElement], Protocol[ParallelizableEle
pass
+def is_parallelizable(type: Type) -> bool:
+ """
+ Checks if a type is parallelizable.
+
+ :param type: Type to check.
+ :return: True if the type is parallelizable, False otherwise.
+ """
+ if type is None:
+ return False
+
+ return type == Parallelizable or Parallelizable in type.__bases__
+
+
def is_parallelizable_type(type_: Type) -> bool:
return _get_origin(type_) == Parallelizable
diff --git a/hamilton/node.py b/hamilton/node.py
index 5755b0ed7..bc16ca8b7 100644
--- a/hamilton/node.py
+++ b/hamilton/node.py
@@ -6,7 +6,7 @@
import typing_inspect
-from hamilton.htypes import Collect, Parallelizable
+from hamilton.htypes import Collect, is_parallelizable
"""
Module that contains the primitive components of the graph.
@@ -285,7 +285,7 @@ def from_fn(fn: Callable, name: str = None) -> "Node":
node_source = NodeType.STANDARD
# TODO - extract this into a function + clean up!
if typing_inspect.is_generic_type(return_type):
- if typing_inspect.get_origin(return_type) == Parallelizable:
+ if is_parallelizable(typing_inspect.get_origin(return_type)):
node_source = NodeType.EXPAND
for hint in typing.get_type_hints(fn, **type_hint_kwargs).values():
if typing_inspect.is_generic_type(hint):