Skip to content

add functionality required for this test load_collection & collect #629

@github-actions

Description

@github-actions

Add the functionality required for this test: `load_collection` and `collect`.

# TODO: add functionality required for this test load_collection & collect

#  limitations under the License.
#

from pywy.dataquanta import WayangContext
from pywy.platforms.java import JavaPlugin
from pywy.platforms.spark import SparkPlugin
import pytest

# TODO: add functionality required for this test load_collection & collect
@pytest.mark.skip(reason="no way of currently testing this, since we are missing implementations for load_collection & collect")
def test_train_and_predict():
    """Train a decision-tree regression model and predict on the training set.

    The test builds a four-row feature collection and a matching label
    collection, trains a model with ``train_decision_tree_regression``,
    predicts on the same features, and checks that four float predictions
    in the range ``(1.0, 7.0]`` come back.

    Currently skipped: per the skip reason, ``load_collection`` and
    ``collect`` are not implemented yet.
    """
    # Initialize context with platforms
    ctx = WayangContext().register({JavaPlugin, SparkPlugin})

    # Input features and labels
    features = ctx.load_collection([
        [1.0, 2.0],
        [2.0, 3.0],
        [3.0, 4.0],
        [4.0, 5.0]
    ])
    labels = ctx.load_collection([3.0, 4.0, 5.0, 6.0])

    # Train the model
    model = features.train_decision_tree_regression(labels, max_depth=3, min_instances=1)

    # Run predictions on same features
    predictions = model.predict(features)

    # Collect and validate
    result = predictions.collect()
    print("Predictions:", result)

    # Compare by value: `is` tests object identity and only happens to work
    # for small ints because of an interpreter caching detail.
    assert len(result) == 4, f"Expected len(result) to be 4, but got: {len(result)}"
    for pred in result:
        # `pred is float` compared the value with the type object itself and
        # could never be true; check the value's type instead.
        assert isinstance(pred, float), f"Expected a float prediction, but got: {type(pred).__name__}"
        assert pred > 1.0
        assert pred <= 7.0

fac400e04a14d77f47335d07ed4a704d15f02e42

Metadata

Metadata

Assignees

No one assigned

    Labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions