diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml
new file mode 100644
index 000000000..ada66361d
--- /dev/null
+++ b/.github/workflows/lint.yml
@@ -0,0 +1,24 @@
+name: lint
+on:
+  push
+jobs:
+
+  lint:
+    runs-on: ubuntu-latest
+    steps:
+      - name: Checkout repository
+        uses: actions/checkout@v4
+      - name: Install uv
+        uses: astral-sh/setup-uv@v5
+      - name: Set up Python
+        uses: actions/setup-python@v5
+        with:
+          python-version: '3.11'
+      - name: Set up Go
+        uses: actions/setup-go@v5
+        with:
+          go-version: '>=1.22'
+      - name: Install dependencies
+        run: make setup-lint
+      - name: Run lint
+        run: . venv/bin/activate && make lint
diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 8a0a9fde6..bd3bf6b57 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -13,7 +13,7 @@ jobs:
       - name: Set up Python
         uses: actions/setup-python@v5
         with:
-          python-version: '3.10'
+          python-version: '3.11'
       - name: Set up Go
         uses: actions/setup-go@v5
         with:
@@ -31,7 +31,7 @@ jobs:
      - name: Set up Python
        uses: actions/setup-python@v2
        with:
-         python-version: '3.10'
+         python-version: '3.11'
      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v2
 
diff --git a/Makefile b/Makefile
index 223e94100..3a1ac8518 100644
--- a/Makefile
+++ b/Makefile
@@ -74,19 +74,31 @@ setup-test-e2e:
 	( cd racetrack_commons && make setup )
 	@echo Activate your venv: . venv/bin/activate
 
+setup-lint:
+	uv venv venv &&\
+	. venv/bin/activate &&\
+	uv pip install -r requirements-test.txt -r requirements-dev.txt \
+		-r racetrack_client/requirements.txt -e racetrack_client@racetrack_client \
+		-r racetrack_commons/requirements.txt -e racetrack_commons@racetrack_commons \
+		-r lifecycle/requirements.txt -e lifecycle@lifecycle \
+		-r image_builder/requirements.txt -e image_builder@image_builder \
+		-r dashboard/requirements.txt -e dashboard@dashboard
+	@echo Activate your venv: . venv/bin/activate
+
 install-racetrack-client:
 	( cd racetrack_client && pip install -e . )
 
 lint:
-	-python -m mypy --ignore-missing-imports --exclude 'racetrack_client/build' racetrack_client
-	-python -m mypy --ignore-missing-imports racetrack_commons
-	-python -m mypy --ignore-missing-imports --exclude 'lifecycle/lifecycle/django/registry/migrations' lifecycle
-	-python -m mypy --ignore-missing-imports image_builder
-	-python -m mypy --ignore-missing-imports dashboard
-	-python -m flake8 --ignore E501 --per-file-ignores="__init__.py:F401" \
-		lifecycle image_builder dashboard
-	-python -m pylint --disable=R,C,W \
-		lifecycle/lifecycle image_builder/image_builder dashboard/dashboard
+	python -m mypy --ignore-missing-imports --exclude 'racetrack_client/build' racetrack_client; e1=$$?;\
+	python -m mypy --ignore-missing-imports racetrack_commons; e2=$$?;\
+	python -m mypy --ignore-missing-imports --exclude 'lifecycle/lifecycle/django/registry/migrations' lifecycle; e3=$$?;\
+	python -m mypy --ignore-missing-imports image_builder; e4=$$?;\
+	python -m mypy --ignore-missing-imports dashboard; e5=$$?;\
+	python -m flake8 --ignore E501 --per-file-ignores="__init__.py:F401 lifecycle/lifecycle/event_stream/server.py:E402" \
+		lifecycle image_builder dashboard; e6=$$?;\
+	python -m pylint --disable=R,C,W \
+		lifecycle/lifecycle image_builder/image_builder dashboard/dashboard; e7=$$?;\
+	exit "$$(( e1 || e2 || e3 || e4 || e5 || e6 || e7 ))"
 
 format:
 	python -m black -S --diff --color -l 120 \
diff --git a/dashboard/dashboard/api/api.py b/dashboard/dashboard/api/api.py
index 5bc8dcb21..8809bd3ce 100644
--- a/dashboard/dashboard/api/api.py
+++ b/dashboard/dashboard/api/api.py
@@ -1,7 +1,7 @@
 import os
 
 import httpx
-from fastapi import FastAPI, Request, Response, Body
+from fastapi import FastAPI, Request, Body
 from starlette.background import BackgroundTask
 from starlette.responses import StreamingResponse
 from starlette.datastructures import MutableHeaders
@@ -30,7 +30,7 @@ def _status():
             'grafana_url': get_external_grafana_url(),
             'site_name': os.environ.get('SITE_NAME', ''),
         }
-    
+
     setup_docs_endpoints(app)
     setup_proxy_endpoints(app)
 
@@ -40,7 +40,7 @@ def setup_proxy_endpoints(app: FastAPI):
     lifecycle_api_url = trim_url(os.environ.get('LIFECYCLE_URL', 'http://127.0.0.1:7202'))
     logger.info(f'Forwarding API requests to "{lifecycle_api_url}"')
     client = httpx.AsyncClient(base_url=f"{lifecycle_api_url}/")
-    
+
     async def _proxy_api_call(request: Request, path: str, payload=Body(default={})):
         """Forward API call to Lifecycle service"""
         subpath = f'/api/v1/{request.path_params["path"]}'
diff --git a/dashboard/dashboard/api/docs.py b/dashboard/dashboard/api/docs.py
index 11827f1af..e6eb5be33 100644
--- a/dashboard/dashboard/api/docs.py
+++ b/dashboard/dashboard/api/docs.py
@@ -35,7 +35,7 @@ def _get_docs_index() -> dict:
             'doc_pages': sorted(doc_pages, key=lambda x: x['title'].lower()),
             'plugin_pages': sorted(plugin_pages, key=lambda x: x['title'].lower()),
         }
-    
+
     @app.get('/api/docs/page/{doc_path:path}')
     def _get_docs_page(doc_path: str) -> dict:
         docs_path = _get_docs_root_dir()
@@ -51,7 +51,7 @@ def _get_docs_page(doc_path: str) -> dict:
             'doc_name': doc_path,
             'html_content': html,
         }
-    
+
     @app.get('/api/docs/plugin/{plugin_name}')
     def _get_docs_plugin_page(plugin_name: str) -> dict:
         plugin_client = LifecyclePluginClient()
diff --git a/dashboard/dashboard/api/webview.py b/dashboard/dashboard/api/webview.py
index 538dc737a..1495501e8 100644
--- a/dashboard/dashboard/api/webview.py
+++ b/dashboard/dashboard/api/webview.py
@@ -13,11 +13,11 @@ def setup_web_views(app: FastAPI):
     @app.get("/", tags=['ui'])
     def _ui_root_view():
         return RedirectResponse('/dashboard/ui/')
-    
+
     @app.get("/ui", tags=['ui'])
     def _ui_home_view():
         return FileResponse('static/index.html')
-    
+
     @app.get("/ui/", tags=['ui'])
     def _ui_home_slash_view():
         return FileResponse('static/index.html')
@@ -26,7 +26,7 @@ def _ui_home_slash_view():
         app.mount("/ui/assets", StaticFiles(directory="static/assets/"), name="static_front")
     else:
         logger.warning('No static/assets directory found')
-    
+
     @app.get("/ui/{subpath:path}", tags=['ui'])
     def _ui_home_any_view(subpath: str):
         return FileResponse('static/index.html')
diff --git a/image_builder/image_builder/api.py b/image_builder/image_builder/api.py
index 071e14b84..961666e9a 100644
--- a/image_builder/image_builder/api.py
+++ b/image_builder/image_builder/api.py
@@ -9,7 +9,6 @@
 from image_builder.health import health_response
 from image_builder.scheduler import schedule_tasks_async
 from racetrack_client.client_config.io import load_credentials_from_dict
-from racetrack_client.log.exception import log_exception
 from racetrack_client.log.logs import configure_logs
 from racetrack_client.manifest.load import load_manifest_from_dict
 from racetrack_client.utils.config import load_config
diff --git a/image_builder/image_builder/build.py b/image_builder/image_builder/build.py
index 8165feced..71e4f3224 100644
--- a/image_builder/image_builder/build.py
+++ b/image_builder/image_builder/build.py
@@ -18,7 +18,6 @@
 from image_builder.progress import update_deployment_phase
 from image_builder.verify import verify_manifest_consistency
 from image_builder.warnings import update_deployment_warnings
-from racetrack_client.client.env import merge_env_vars
 from racetrack_client.client_config.client_config import Credentials
 from racetrack_client.log.context_error import wrap_context
 from racetrack_client.log.logs import get_logger
diff --git a/image_builder/image_builder/docker/builder.py b/image_builder/image_builder/docker/builder.py
index 33d473f1c..ce81c4b32 100644
--- a/image_builder/image_builder/docker/builder.py
+++ b/image_builder/image_builder/docker/builder.py
@@ -57,7 +57,7 @@ def build(
     built_images: list[str] = []
     images_num = job_type.images_num
     for image_index in range(images_num):
-        progress = f'({image_index+1}/{images_num})'
+        progress = f'({image_index + 1}/{images_num})'
 
         try:
             _build_job_image(
@@ -97,7 +97,7 @@ def _build_job_image(
     built_images: list[str],
     build_flags: list[str],
 ):
-    progress = f'({image_index+1}/{images_num})'
+    progress = f'({image_index + 1}/{images_num})'
     image_def: JobTypeImageDef = job_type.images[image_index]
 
     with phase_context(f'building image {progress}', metric_labels, deployment_id, config, metric_phase='building image'):
@@ -161,7 +161,7 @@ def _build_container_image(
         'DOCKER_BUILDKIT=1 docker buildx build',
         f'-t {image_name}',
         f'-f {dockerfile_path}',
-        f'--network=host',
+        '--network=host',
         f'--build-context {JOBTYPE_BUILD_CONTEXT}="{jobtype_dir}"',
     ]
     cmd_parts.extend(build_flags)
diff --git a/image_builder/image_builder/verify.py b/image_builder/image_builder/verify.py
index 1c7e8956d..42042a334 100644
--- a/image_builder/image_builder/verify.py
+++ b/image_builder/image_builder/verify.py
@@ -42,11 +42,14 @@ def verify_manifest_consistency(submitted_yaml: str, workspace: Path, repo_dir:
         logger.info(f'Manifest file in Job repository:\n{repo_dict}')
         difference = _differentiate_dicts(repo_dict, submitted_dict)
         warning = ('Submitted job manifest is not consistent with the file found in a repository. '
-                   'Did you forget to do "git push"? '
-                   f'Difference: {difference}')
+                   'Did you forget to do "git push"? '
+                   f'Difference: {difference}')
         logger.warning(warning)
         return warning
 
+    return None
+
+
 def _find_workspace_manifest_file(workspace: Path, repo_dir: Path) -> Optional[Path]:
     paths_to_check = [
         workspace / JOB_MANIFEST_FILENAME,
diff --git a/lifecycle/lifecycle/auth/authenticate_password.py b/lifecycle/lifecycle/auth/authenticate_password.py
index 7f0613b89..4db8add4d 100644
--- a/lifecycle/lifecycle/auth/authenticate_password.py
+++ b/lifecycle/lifecycle/auth/authenticate_password.py
@@ -7,6 +7,7 @@
 UNUSABLE_PASSWORD_SUFFIX_LENGTH = 40
 UNUSABLE_PASSWORD_PREFIX = "!"
 
+
 def authenticate(username: str, password: str) -> Optional[User]:
     """
     If the given credentials are valid, return a User object.
@@ -18,17 +19,22 @@ def authenticate(username: str, password: str) -> Optional[User]:
 
     return user
 
+
 def authenticate_user(username: str, password: str) -> Optional[User]:
-    try:
-        user: User = LifecycleCache.record_mapper().find_one(User, username=username)
-    except EntityNotFound:
-        # Run the default password hasher once to reduce the timing
-        # difference between an existing and a nonexistent user
-
-        make_password(password)
-    else:
-        if check_password(password, user.password) and user.is_active:
-            return user
+    try:
+        user: User = LifecycleCache.record_mapper().find_one(User, username=username)
+    except EntityNotFound:
+        # Run the default password hasher once to reduce the timing
+        # difference between an existing and a nonexistent user
+
+        make_password(password)
+
+        return None
+
+    if check_password(password, user.password) and user.is_active:
+        return user
+
+    return None
 
 
 class PermissionDenied(Exception):
diff --git a/lifecycle/lifecycle/auth/authorize.py b/lifecycle/lifecycle/auth/authorize.py
index 30ec662a1..1787a1000 100644
--- a/lifecycle/lifecycle/auth/authorize.py
+++ b/lifecycle/lifecycle/auth/authorize.py
@@ -243,7 +243,7 @@ def list_permitted_jobs(
         family_to_job_ids[job.name].append(job.id)
 
     job_ids = set()
-    for permission in permissions: 
+    for permission in permissions:
         if permission.job_family_id is None and permission.job_id is None:
             return all_jobs
 
diff --git a/lifecycle/lifecycle/auth/hasher.py b/lifecycle/lifecycle/auth/hasher.py
index ced33ad1c..f9fdc80ef 100644
--- a/lifecycle/lifecycle/auth/hasher.py
+++ b/lifecycle/lifecycle/auth/hasher.py
@@ -1,10 +1,17 @@
+from __future__ import annotations
+
 import base64
 import functools
 import hashlib
 import math
-from typing import Optional
+from typing import Optional, Protocol
 import secrets
 
+from typing import TYPE_CHECKING
+if TYPE_CHECKING:
+    from _typeshed import ReadableBuffer
+
+
 RANDOM_STRING_CHARS = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
 
 
@@ -14,6 +21,11 @@ def make_password(password: str) -> str:
     return hasher.encode(password, salt)
 
 
+class Hash(Protocol):
+    def __call__(self, string: ReadableBuffer = b"", *, usedforsecurity: bool = True) -> hashlib._Hash:
+        ...
+
+
 class PBKDF2PasswordHasher:
     """
     Secure password hashing using the PBKDF2 algorithm (recommended)
@@ -25,10 +37,10 @@ class PBKDF2PasswordHasher:
 
     algorithm: str = "pbkdf2_sha256"
     iterations: int = 600000
-    digest: 'hashlib._Hash' = hashlib.sha256
+    digest: Hash = hashlib.sha256
     salt_entropy: int = 128
 
-    def encode(self, password: str, salt, iterations: Optional[int]=None):
+    def encode(self, password: str, salt, iterations: Optional[int] = None):
         iterations = iterations or self.iterations
         hash = pbkdf2(password, salt, iterations, digest=self.digest)
         hash = base64.b64encode(hash).decode("ascii").strip()
@@ -60,7 +72,7 @@ def salt(self) -> str:
         return get_random_string(char_count, allowed_chars=RANDOM_STRING_CHARS)
 
 
-def get_random_string(length: int, allowed_chars: str=RANDOM_STRING_CHARS) -> str:
+def get_random_string(length: int, allowed_chars: str = RANDOM_STRING_CHARS) -> str:
     """
     Return a securely generated random string.
 
@@ -74,14 +86,14 @@ def get_random_string(length: int, allowed_chars: str=RANDOM_STRING_CHARS) -> st
     return "".join(secrets.choice(allowed_chars) for i in range(length))
 
 
-def pbkdf2(password: str, salt: str, iterations: int, dklen: int=0, digest: Optional['hashlib._Hash']=None):
+def pbkdf2(password: str, salt: str, iterations: int, dklen: int = 0, digest: Optional[Hash] = None):
     """Return the hash of password using pbkdf2."""
     if digest is None:
         digest = hashlib.sha256
-    dklen = dklen or None
-    password = password.encode("utf-8", "strict")
-    salt = salt.encode("utf-8", "strict")
-    return hashlib.pbkdf2_hmac(digest().name, password, salt, iterations, dklen)
+
+    password_bytes = password.encode("utf-8", "strict")
+    salt_bytes = salt.encode("utf-8", "strict")
+    return hashlib.pbkdf2_hmac(digest().name, password_bytes, salt_bytes, iterations, dklen or None)
 
 
 def constant_time_compare(val1, val2):
diff --git a/lifecycle/lifecycle/auth/users.py b/lifecycle/lifecycle/auth/users.py
index 62a5cdb51..bc1bd9ca8 100644
--- a/lifecycle/lifecycle/auth/users.py
+++ b/lifecycle/lifecycle/auth/users.py
@@ -41,7 +41,7 @@ def register_user_account(username: str, password: str) -> tables.User:
         pass
 
     user = LifecycleCache.record_mapper().create_from_dict(
-        tables.User, 
+        tables.User,
        {
             "username": username,
             "password": make_password(password),
@@ -66,7 +66,7 @@ def register_user_account(username: str, password: str) -> tables.User:
     grant_permission(auth_subject, AuthScope.DEPLOY_JOB.value)
 
     logger.info(f'User account created: {username}')
-    return user
+    return user_record
 
 
 def change_user_password(username: str, old_password: str, new_password: str):
diff --git a/lifecycle/lifecycle/auth/validate.py b/lifecycle/lifecycle/auth/validate.py
index 143c3063b..7558bff98 100644
--- a/lifecycle/lifecycle/auth/validate.py
+++ b/lifecycle/lifecycle/auth/validate.py
@@ -1,28 +1,32 @@
 from pathlib import Path
-
-
+
+
 def validate_password(password: str):
     for validator in [validate_password_length, validate_numeric_password, validate_common_password]:
         validator(password)
 
+
 def validate_password_length(password: str):
     if len(password) < 8:
         raise ValidationError("This password is too short. It must contain at least 8 characters.")
-    
+
+
 def validate_numeric_password(password: str):
     if password.isdigit():
         raise ValidationError("This password is entirely numeric.")
-    
+
+
 def validate_common_password(password: str):
     # password list based on https://gist.github.com/roycewilliams/226886fd01572964e1431ac8afc999ce
     # by Royce Williams
-    passwords_path: str = Path(__file__).resolve().parent / "common-passwords.txt"
-    
+    passwords_path: Path = Path(__file__).resolve().parent / "common-passwords.txt"
+
     with open(passwords_path, "rt", encoding="utf-8") as f:
         passwords = {x.strip() for x in f}
 
     if password.lower().strip() in passwords:
         raise ValidationError("This password is too common.")
 
+
 class ValidationError(Exception):
     pass
diff --git a/lifecycle/lifecycle/database/base_engine.py b/lifecycle/lifecycle/database/base_engine.py
index 3999e5215..b39d423b8 100644
--- a/lifecycle/lifecycle/database/base_engine.py
+++ b/lifecycle/lifecycle/database/base_engine.py
@@ -31,7 +31,7 @@ class DbEngine(ABC):
     - SQLite implementation: file://./sqlite/engine.py
     """
     query_builder: BaseQueryBuilder
-    
+
     def check_connection(self) -> None:
         pass
 
diff --git a/lifecycle/lifecycle/database/condition_builder.py b/lifecycle/lifecycle/database/condition_builder.py
index d097f3b01..4d7815009 100644
--- a/lifecycle/lifecycle/database/condition_builder.py
+++ b/lifecycle/lifecycle/database/condition_builder.py
@@ -10,7 +10,7 @@ def __init__(self, expression: str, *params):
     @staticmethod
     def empty() -> 'QueryCondition':
         return QueryCondition('', [])
-    
+
     def is_empty(self) -> bool:
         return self.expression == ''
 
@@ -39,7 +39,7 @@ def filter_conditions(self) -> list[str] | None:
         if self.is_empty():
             return None
         return [self.expression]
-    
+
     @property
     def filter_params(self) -> list[Any] | None:
         return self.params if self.params else None
diff --git a/lifecycle/lifecycle/database/postgres/engine.py b/lifecycle/lifecycle/database/postgres/engine.py
index 35ff5c8f4..f1b5023d2 100644
--- a/lifecycle/lifecycle/database/postgres/engine.py
+++ b/lifecycle/lifecycle/database/postgres/engine.py
@@ -41,7 +41,7 @@ def __init__(self, max_pool_size: int, log_queries: bool):
             name='lifecycle',  # name to give to the pool, useful, for instance, to identify it in the logs
             timeout=5,  # The default maximum time in seconds that a client can wait to receive a connection from the pool
             max_waiting=0,  # Maximum number of requests that can be queued to the pool, after which new requests will fail, raising TooManyRequests. 0 means no queue limit.
-            max_lifetime=10*60,  # The maximum lifetime of a connection in the pool, in seconds. Connections used for longer get closed and replaced by a new one.
+            max_lifetime=10 * 60,  # The maximum lifetime of a connection in the pool, in seconds. Connections used for longer get closed and replaced by a new one.
             max_idle=60,  # Maximum time, in seconds, that a connection can stay unused in the pool before being closed, and the pool shrunk
             reconnect_timeout=5,  # Maximum time, in seconds, the pool will try to create a connection. If a connection attempt fails, the pool will try to reconnect a few times, using an exponential backoff and some random factor to avoid mass attempts. If repeated attempts fail, after reconnect_timeout second the connection attempt is aborted and the reconnect_failed() callback invoked
             reconnect_failed=self._on_reconnect_failed,  # Callback invoked if an attempt to create a new connection fails for more than reconnect_timeout seconds
@@ -77,7 +77,7 @@ def _on_reconnect_failed(self, _: ConnectionPool) -> None:
 
     def _on_reset_connection(self, _: Connection) -> None:
         metric_database_connection_closed.inc()
-    
+
     def check_connection(self) -> None:
         try:
             conn_params = get_connection_params()
@@ -221,9 +221,9 @@ def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
 
     @classmethod
-    def connect(cls, *args, **kwargs) -> "Connection":
+    def connect(cls, *args, **kwargs) -> "PgConnection":
         try:
-            return Connection.connect(*args, **kwargs)
+            return super(PgConnection, cls).connect(*args, **kwargs)
         except BaseException as e:
             metric_database_connection_failed.inc()
             log_exception(ContextError('Connection to database failed', e))
diff --git a/lifecycle/lifecycle/database/record_mapper.py b/lifecycle/lifecycle/database/record_mapper.py
index 58448b92b..8b7bbfc9d 100644
--- a/lifecycle/lifecycle/database/record_mapper.py
+++ b/lifecycle/lifecycle/database/record_mapper.py
@@ -1,5 +1,5 @@
 from collections import defaultdict
-from typing import Any, Type, TypeVar
+from typing import Any, Type, TypeVar, cast
 
 from lifecycle.database.base_engine import DbEngine
 from lifecycle.database.condition_builder import QueryCondition
@@ -81,7 +81,7 @@ def find_many(
             order_by=order_by,
         )
         return [self._convert_row_to_record_model(row, table_type) for row in rows]
-    
+
     def filter(
         self,
         table_type: Type[T],
@@ -226,7 +226,7 @@ def exists_record(
             filter_params=filter_params,
         )
         return row is not None
-    
+
     def exists_on_condition(
         self,
         table_type: Type[T],
@@ -447,11 +447,11 @@ def _convert_row_to_record_model(
         for column in row.keys():
             assert column in valid_fields, \
                 f'retrieved column "{column}" is not a valid field for the model {table_type_name(table_type)}'
-        record_model = parse_typed_object(row, table_type)
+        record_model = cast(TableModel, parse_typed_object(row, table_type))
 
         # remember original values to keep track of changed fields
         setattr(record_model, '_original_fields', record_to_dict(record_model))
-        return record_model
+        return cast(T, record_model)
 
     def _extract_record_data(self, record_model: TableModel) -> dict[str, Any]:
         metadata = self._tables_metadata[type(record_model)]
diff --git a/lifecycle/lifecycle/database/schema/tables.py b/lifecycle/lifecycle/database/schema/tables.py
index 829f6f06d..9383f1934 100644
--- a/lifecycle/lifecycle/database/schema/tables.py
+++ b/lifecycle/lifecycle/database/schema/tables.py
@@ -339,7 +339,7 @@ class Metadata:
         attempts: int
         pub_instance_addr: str
         retriable_error: bool
-    
+
     def __str__(self):
         return self.id
 
diff --git a/lifecycle/lifecycle/database/table_model.py b/lifecycle/lifecycle/database/table_model.py
index 66441eb56..01805ae31 100644
--- a/lifecycle/lifecycle/database/table_model.py
+++ b/lifecycle/lifecycle/database/table_model.py
@@ -2,7 +2,8 @@
 from dataclasses import asdict, is_dataclass
 from datetime import datetime
 from enum import Enum
-from typing import Any, Type, Callable
+from types import UnionType
+from typing import Any, Optional, Type, Callable
 import uuid
 
 
@@ -53,7 +54,7 @@ def table_type_name(cls: Type[TableModel] | TableModel) -> str:
 def table_metadata(cls: Type[TableModel] | TableModel) -> TableModel.Metadata:
     if isinstance(cls, TableModel):
         cls = type(cls)
-    metadata: TableModel.Metadata = getattr(cls, 'Metadata', None)
+    metadata: Optional[TableModel.Metadata] = getattr(cls, 'Metadata', None)
     assert metadata is not None, f'Metadata class not specified in {cls}'
 
     field_annotations: dict[str, type] = cls.__annotations__
@@ -91,8 +92,8 @@ def record_to_dict(self: TableModel) -> dict[str, Any]:
         raise ValueError(f"'{self.__class__.__name__}' is not a dataclass!")
 
 
-def build_column_type(annotation: type) -> ColumnType:
-    type_dict: dict[type, ColumnType] = {
+def build_column_type(annotation: type | UnionType) -> ColumnType:
+    type_dict: dict[type | UnionType, ColumnType] = {
         str: ColumnType.STRING,
         datetime: ColumnType.DATETIME,
         int: ColumnType.INT,
diff --git a/lifecycle/lifecycle/database/type_parser.py b/lifecycle/lifecycle/database/type_parser.py
index 20909f563..24d3a373c 100644
--- a/lifecycle/lifecycle/database/type_parser.py
+++ b/lifecycle/lifecycle/database/type_parser.py
@@ -2,15 +2,16 @@
 import json
 from datetime import datetime, timezone
 from dateutil import parser as dt_parser
-from typing import Type, TypeVar, Union, Any, get_origin, get_args
+from typing import Type, TypeVar, Union, Any, cast, get_origin, get_args
 import types
 
 from racetrack_client.log.context_error import ContextError
 
+
 T = TypeVar("T")
 
 
-def parse_typed_object(obj: Any, clazz: Type[T]) -> T:
+def parse_typed_object(obj: Any, clazz: Type[T]) -> T | None:
     """
     Cast object value to its expected type, using annotated types
     :param obj: object value to be transformed into its expected type
@@ -18,31 +19,32 @@ def parse_typed_object(obj: Any, clazz: Type[T]) -> T:
     """
     if obj is None:
         return None
-    
+
     # automatic type conversion
     if type(obj) is str and clazz is datetime:
-        return dt_parser.parse(obj).replace(tzinfo=timezone.utc)
-    
+        return cast(T, dt_parser.parse(obj).replace(tzinfo=timezone.utc))
+
     if dataclasses.is_dataclass(clazz):
         assert isinstance(obj, dict), f'expected dict type to parse into a dataclass, got {type(obj)}'
         field_types = {field.name: field.type for field in dataclasses.fields(clazz)}
-        dataclass_kwargs = dict()
+        dataclass_kwargs: dict[Any, Any] = dict()
         for key, value in obj.items():
             if key not in field_types:
                 raise KeyError(f'unexpected field "{key}" provided to type {clazz}')
-            dataclass_kwargs[key] = parse_typed_object(value, field_types[key])
-        return clazz(**dataclass_kwargs)
-    
+
+            field_type = cast(type, field_types[key])
+
+            dataclass_kwargs[key] = parse_typed_object(value, field_type)
+        return cast(T, clazz(**dataclass_kwargs))
+
     elif get_origin(clazz) in {Union, types.UnionType}:  # Union or Optional type
         union_types = get_args(clazz)
         left_types = []
         for union_type in union_types:
             if dataclasses.is_dataclass(union_type):
-                if obj is not None:
-                    return parse_typed_object(obj, union_type)
+                return parse_typed_object(obj, cast(type[T], union_type))
             elif union_type is types.NoneType:
-                if obj is None:
-                    return None
+                continue
             else:
                 left_types.append(union_type)
         if not left_types:
@@ -50,13 +52,13 @@ def parse_typed_object(obj: Any, clazz: Type[T]) -> T:
         if len(left_types) > 1:
             raise ValueError(f'too many ambiguous union types {left_types} ({clazz}) matching to a given value: {obj}')
         return parse_typed_object(obj, left_types[0])
-    
+
     elif get_origin(clazz) is None and isinstance(obj, clazz):
-        return obj
-    
+        return cast(T, obj)
+
     else:
         try:
-            return clazz(obj)
+            return clazz(obj)  # type: ignore
         except BaseException as e:
             raise ValueError(f'failed to parse "{obj}" ({type(obj)}) to type {clazz}: {e}')
 
@@ -76,7 +78,10 @@ def parse_dict_typed_values(data: dict[str, Any], clazz: Type[T]) -> dict[str, A
     for key, value in data.items():
         if key not in field_types:
             raise KeyError(f'unexpected field "{key}" provided for type {clazz}')
-        typed_data[key] = parse_typed_object(value, field_types[key])
+
+        field_type = cast(type, field_types[key])
+
+        typed_data[key] = parse_typed_object(value, field_type)
     return typed_data
 
diff --git a/lifecycle/lifecycle/deployer/deployers.py b/lifecycle/lifecycle/deployer/deployers.py
index 85dd3e045..e496ee9dd 100644
--- a/lifecycle/lifecycle/deployer/deployers.py
+++ b/lifecycle/lifecycle/deployer/deployers.py
@@ -6,4 +6,8 @@ def get_job_deployer(
     infrastructure_name: str | None,
 ) -> JobDeployer:
     infra_target = get_infrastructure_target(infrastructure_name)
+
+    if infra_target.job_deployer is None:
+        raise ValueError("job deployer is None")
+
     return infra_target.job_deployer
diff --git a/lifecycle/lifecycle/deployer/infra_target.py b/lifecycle/lifecycle/deployer/infra_target.py
deleted file mode 100644
index 9e029637b..000000000
--- a/lifecycle/lifecycle/deployer/infra_target.py
+++ /dev/null
@@ -1 +0,0 @@
-from lifecycle.infrastructure.model import InfrastructureTarget
diff --git a/lifecycle/lifecycle/deployer/redeploy.py b/lifecycle/lifecycle/deployer/redeploy.py
index f41ba3844..d89ea5478 100644
--- a/lifecycle/lifecycle/deployer/redeploy.py
+++ b/lifecycle/lifecycle/deployer/redeploy.py
@@ -20,7 +20,7 @@ def redeploy_job(
     plugin_engine: PluginEngine,
     deployer_username: str,
     auth_subject: tables.AuthSubject | None,
-    build_flags: list[str] = [], # Setting a default here let's us ignore build_flags when redeploying
+    build_flags: list[str] = [],  # Setting a default here lets us ignore build_flags when redeploying
 ):
     """Deploy (rebuild and reprovision) Job once again without knowing secrets"""
     job = read_job(job_name, job_version, config)
diff --git a/lifecycle/lifecycle/django/database/base.py b/lifecycle/lifecycle/django/database/base.py
index c46f3f582..ca8c95a55 100644
--- a/lifecycle/lifecycle/django/database/base.py
+++ b/lifecycle/lifecycle/django/database/base.py
@@ -11,7 +11,7 @@ class DatabaseWrapper(base.DatabaseWrapper):
     """Subclass of django.db.backends.postgresql.base.DatabaseWrapper, measuring connection statistics"""
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-    
+
     def get_new_connection(self, conn_params):
         try:
             connection: Connection = super().get_new_connection(conn_params)
diff --git a/lifecycle/lifecycle/django/registry/migrations/0023_default_superadmin.py b/lifecycle/lifecycle/django/registry/migrations/0023_default_superadmin.py
index 8433b7c4f..c559f48e1 100644
--- a/lifecycle/lifecycle/django/registry/migrations/0023_default_superadmin.py
+++ b/lifecycle/lifecycle/django/registry/migrations/0023_default_superadmin.py
@@ -37,7 +37,7 @@ def create_default_superuser(apps, schema_editor):
     with DisableSignals():
         try:
             user = User.objects.get(username='admin')
-        except User.DoesNotExist: 
+        except User.DoesNotExist:  # pylint: disable=E1101
             user = User()
             user.is_active = True
             user.is_superuser = True
diff --git a/lifecycle/lifecycle/django/registry/models.py b/lifecycle/lifecycle/django/registry/models.py
index e7ddf3c23..3e0a0d836 100644
--- a/lifecycle/lifecycle/django/registry/models.py
+++ b/lifecycle/lifecycle/django/registry/models.py
@@ -249,4 +249,4 @@ class Meta:
         retriable_error = models.BooleanField()
 
     def __str__(self):
-        return self.id
+        return str(self.id)
diff --git a/lifecycle/lifecycle/endpoints/auth.py b/lifecycle/lifecycle/endpoints/auth.py
index 109744f51..9d6e53d9e 100644
--- a/lifecycle/lifecycle/endpoints/auth.py
+++ b/lifecycle/lifecycle/endpoints/auth.py
@@ -70,7 +70,7 @@ def _auth_allowed_job_endpoint(job_name: str, job_version: str, scope: str, endp
             return JSONResponse(content={'error': msg}, status_code=401)
 
         return Response(content='', status_code=202)
-    
+
     class JobCallAuthData(BaseModel):
         job: JobDto
         caller: str | None = None
@@ -93,7 +93,7 @@ def _auth_can_call_job_endpoint(
         endpoint = _normalize_endpoint_path(endpoint)
 
         job_model = models_registry.resolve_job_model(job_name, job_version)
-        
+
         try:
             auth_subject = _authorize_job_caller(job_model, endpoint, request)
             caller: str | None = get_description_from_auth_subject(auth_subject) if auth_subject else None
@@ -155,19 +155,19 @@ def _generate_tokens_for_all_users(request: Request):
         """Generate new tokens for all Users"""
         check_auth(request, scope=AuthScope.CALL_ADMIN_API)
         regenerate_all_user_tokens()
-    
+
     @api.post('/auth/token/job_family/regenerate')
     def _generate_tokens_for_all_job_families(request: Request):
         """Generate new tokens for all Job Families"""
         check_auth(request, scope=AuthScope.CALL_ADMIN_API)
         regenerate_all_job_family_tokens()
-    
+
     @api.post('/auth/token/esc/regenerate')
     def _generate_tokens_for_all_escs(request: Request):
         """Generate new tokens for all ESCs"""
         check_auth(request, scope=AuthScope.CALL_ADMIN_API)
         regenerate_all_esc_tokens()
-    
+
 
 def _normalize_endpoint_path(endpoint: str) -> str:
     if endpoint.endswith('/'):
@@ -191,7 +191,7 @@ def _authorize_job_caller(job_model: tables.Job, endpoint: str, request: Request
         for public_endpoint in public_endpoints:
             if endpoint.startswith(public_endpoint):
                 return None
-    
+
     scope = AuthScope.CALL_JOB.value
     token_payload, auth_subject = authenticate_token(request)
 
diff --git a/lifecycle/lifecycle/endpoints/deploy.py b/lifecycle/lifecycle/endpoints/deploy.py
index 625f885b8..adcc7d537 100644
--- a/lifecycle/lifecycle/endpoints/deploy.py
+++ b/lifecycle/lifecycle/endpoints/deploy.py
@@ -128,7 +128,6 @@ def _update_deployment_phase(deploy_id: str, payload: DeploymentPhase, request:
         check_auth(request)
         save_deployment_phase(deploy_id, payload.phase)
 
-
     class DeploymentWarnings(BaseModel):
         warnings: str = Field(description='deployment warnings')
 
diff --git a/lifecycle/lifecycle/endpoints/records.py b/lifecycle/lifecycle/endpoints/records.py
index f75f85077..325a3d64c 100644
--- a/lifecycle/lifecycle/endpoints/records.py
+++ b/lifecycle/lifecycle/endpoints/records.py
@@ -167,7 +167,7 @@ def list_table_records(mapper: RecordMapper, payload: FetchManyRecordsRequest, t
     filters = parse_dict_typed_values(payload.filters or {}, table_type)
     columns = payload.columns or metadata.fields
     records: list[dict] = mapper.filter_dicts(
-        table_type, columns=payload.columns, filters=filters, order_by=payload.order_by, 
+        table_type, columns=payload.columns, filters=filters, order_by=payload.order_by,
         offset=payload.offset, limit=payload.limit)
     record_payloads: list[RecordFieldsPayload] = [
         RecordFieldsPayload(fields=convert_to_json_serializable(record))
diff --git a/lifecycle/lifecycle/event_stream/server.py b/lifecycle/lifecycle/event_stream/server.py
index b9a7a9087..902f3e8ac 100644
--- a/lifecycle/lifecycle/event_stream/server.py
+++ b/lifecycle/lifecycle/event_stream/server.py
@@ -34,7 +34,7 @@ def __init__(self, config: Config):
 
     class WebSocketResource:
         async def on_websocket(self, _: Request, ws: WebSocket):
            await ws.accept()
-            logger.debug(f'Client connected to Event Stream')
+            logger.debug('Client connected to Event Stream')
             metric_event_stream_client_connected.inc()
             server.clients.append(ws)
 
@@ -48,7 +48,7 @@ async def on_websocket(self, _: Request, ws: WebSocket):
                     logger.debug(f'Received websocket message: {message}')
 
             finally:
-                logger.debug(f'Client disconnected from Event Stream')
+                logger.debug('Client disconnected from Event Stream')
                 metric_event_stream_client_disconnected.inc()
                 server.clients.remove(ws)
 
@@ -72,7 +72,7 @@ def watch_database_events(self):
                 jobs = list_job_registry(self.config)
                 current_jobs: dict[str, JobDto] = {job.id: job for job in jobs}
                 if last_jobs is not None and current_jobs != last_jobs:
-                    logger.debug(f'Detected change in job models')
+                    logger.debug('Detected change in job models')
                     self.notify_clients({
                         'event': 'job_models_changed',
                     })
diff --git a/lifecycle/lifecycle/monitor/base.py b/lifecycle/lifecycle/monitor/base.py
index fef0436e1..db21903b9 100644
--- a/lifecycle/lifecycle/monitor/base.py
+++ b/lifecycle/lifecycle/monitor/base.py
@@ -17,7 +17,7 @@ def list_jobs(self, config: Config) -> Iterable[JobDto]:
     def check_job_condition(self,
                             job: JobDto,
                             deployment_timestamp: int = 0,
-                            on_job_alive: Callable = None,
+                            on_job_alive: Callable | None = None,
                             logs_on_error: bool = True,
                             ):
         """
diff --git a/lifecycle/lifecycle/monitor/health.py b/lifecycle/lifecycle/monitor/health.py
index d0d148713..767239fd5 100644
--- a/lifecycle/lifecycle/monitor/health.py
+++ b/lifecycle/lifecycle/monitor/health.py
@@ -10,7 +10,7 @@
 def check_until_job_is_operational(
     base_url: str,
     deployment_timestamp: int = 0,
-    on_job_alive: Callable = None,
+    on_job_alive: Callable | None = None,
     headers: dict[str, str] | None = None,
 ):
     """
diff --git a/lifecycle/lifecycle/monitor/monitors.py b/lifecycle/lifecycle/monitor/monitors.py
index 20139ae9b..55ec5af66 100644
--- a/lifecycle/lifecycle/monitor/monitors.py
+++ b/lifecycle/lifecycle/monitor/monitors.py
@@ -14,6 +14,9 @@ def list_infrastructure_jobs(config: Config, plugin_engine: PluginEngine) -> Ite
     infrastructures = list_infrastructure_targets(plugin_engine)
     for infrastructure in infrastructures:
         try:
+            if infrastructure.job_monitor is None:
+                raise ValueError("job monitor is None")
+
             yield from infrastructure.job_monitor.list_jobs(config)
         except BaseException as e:
             log_exception(ContextError(f'failed to list jobs from {infrastructure}', e))
@@ -27,6 +30,10 @@ def check_job_condition(job: JobDto, on_job_alive: Callable):
     (server running already, but still initializing)
     """
     infrastructure = get_infrastructure_target(job.infrastructure_target)
+
+    if infrastructure.job_monitor is None:
+        raise ValueError("job monitor is None")
+
     infrastructure.job_monitor.check_job_condition(job, job.update_time, on_job_alive, logs_on_error=True)
 
 
@@ -35,11 +42,15 @@ def read_recent_logs(job: JobDto, tail: int) -> str:
     """Return last output logs from a job"""
     with wrap_context('reading Job logs'):
         infrastructure = get_infrastructure_target(job.infrastructure_target)
+
+        if infrastructure.job_monitor is None:
+            raise ValueError("job monitor is None")
+
         return infrastructure.job_monitor.read_recent_logs(job, tail=tail)
 
 
 def list_log_streamers(
     plugin_engine: PluginEngine,
-) -> list[LogsStreamer]:
+) -> list[LogsStreamer | None]:
     infrastructures = list_infrastructure_targets(plugin_engine)
     return [infra.logs_streamer for infra in infrastructures]
diff --git a/lifecycle/lifecycle/server/socketio.py b/lifecycle/lifecycle/server/socketio.py
index a3dd8e408..3f59ea87a 100644
--- a/lifecycle/lifecycle/server/socketio.py
+++ b/lifecycle/lifecycle/server/socketio.py
@@ -67,6 +67,8 @@ def subscribe_for_logs(client_id: str, data: dict) -> str:
             except BaseException as e:
                 log_exception(e)
 
+            return ""
+
         @self.sio.event
         def disconnect(client_id: str):
             if client_id in self.log_sessions_by_client:
@@ -79,13 +81,16 @@ def open_logs_session(self, client_id: str, resource_properties: dict[str, str])
         logger.info(f'Creating log session for client: {client_id}')
         job_name = resource_properties['job_name']
         job_version = resource_properties['job_version']
-        tail = resource_properties.get('tail')
+        tail = resource_properties.get('tail', '')
         job = self.job_retriever.get_job(job_name, job_version)
         job_version = job.version  # contains resolved (non-aliased) version
         session_id = f'{client_id}_{job_name}_{job_version}'
 
         infrastructure = get_infrastructure_target(job.infrastructure_target)
+        if infrastructure.logs_streamer is None:
+            raise ValueError("logs streamer is None")
+
         session = LogSessionDetails(
             client_id=client_id,
             job_name=job_name,
diff --git a/lifecycle/lifecycle/supervisor/cleanup.py b/lifecycle/lifecycle/supervisor/cleanup.py
index f6fa04553..dcacb6c70 100644
--- a/lifecycle/lifecycle/supervisor/cleanup.py
+++ b/lifecycle/lifecycle/supervisor/cleanup.py
@@ -16,7 +16,7 @@ def clean_up_async_job_calls(config: Config):
     mapper = LifecycleCache.record_mapper()
     condition = QueryCondition(f'started_at < {mapper.placeholder}', older_than)
     records = mapper.filter(tables.AsyncJobCall, condition=condition)
-    
+
     if records:
         for record in records:
             mapper.delete_record(record)
diff --git a/lifecycle/tests/auth/test_password.py b/lifecycle/tests/auth/test_password.py
index 077affba3..9bfd61cb0 100644
--- a/lifecycle/tests/auth/test_password.py
+++ b/lifecycle/tests/auth/test_password.py
@@ -5,19 +5,21 @@
 
 def test_password_verification():
     test_password = "complex_password_123!"
     hashed = make_password(test_password)
-    
+
     assert check_password(test_password, hashed)
     assert not check_password("wrong_password", hashed)
 
+
 def test_password_uniqueness():
     password = "test_password"
     hash1 = make_password(password)
     hash2 = make_password(password)
-    
+
     assert hash1 != hash2  # Different salts should produce different hashes
     assert check_password(password, hash1)
     assert check_password(password, hash2)
 
+
 def test_password_compatibility():
     django_hash = "pbkdf2_sha256$600000$ToUqATye4PAvPsSe9rzUVY$eMWAQNMhb1L22JOXMI92AzSmvUtZKeTzZvbWyJMvakE="
 
diff --git a/lifecycle/tests/database/test_table_records.py b/lifecycle/tests/database/test_table_records.py
index 31e3a0668..c7e49eb76 100644
--- a/lifecycle/tests/database/test_table_records.py
+++ b/lifecycle/tests/database/test_table_records.py
@@ -45,7 +45,7 @@ def test_record_operations():
 
         mapper.delete(JobFamily, id=record_id)
         assert mapper.count(JobFamily) == 0
-        
+
         try:
             mapper.delete(JobFamily, id=record_id)
             assert False, 'it should raise NoRowsAffected exception'
diff --git a/lifecycle/tests/server/test_logs_stream.py b/lifecycle/tests/server/test_logs_stream.py
index d538fde00..2f42a50ee 100644
--- a/lifecycle/tests/server/test_logs_stream.py
+++ b/lifecycle/tests/server/test_logs_stream.py
@@ -40,7 +40,7 @@ def create_session(self, session_id: str, resource_properties: Dict[str, str], o
         on_next_line(session_id, f'hello {job_name}')
 
         def in_background():
-            on_next_line(session_id, f'more logs')
+            on_next_line(session_id, 'more logs')
 
         Thread(target=in_background, daemon=True).start()
 
diff --git a/lifecycle/tests/server/test_version.py b/lifecycle/tests/server/test_version.py
index 589fa51cd..2d426803a 100644
--- a/lifecycle/tests/server/test_version.py
+++ b/lifecycle/tests/server/test_version.py
@@ -21,4 +21,3 @@ def test_health_version_endpoint():
     obj = response.json()
     assert obj['git_version'] == '0.0.1-g32c4b29-dirty'
     assert obj['live'] is True
-
diff --git a/racetrack_client/racetrack_client/client/deploy.py b/racetrack_client/racetrack_client/client/deploy.py
index 929e20eed..6c1db8925 100644
--- a/racetrack_client/racetrack_client/client/deploy.py
+++ b/racetrack_client/racetrack_client/client/deploy.py
@@ -65,7 +65,7 @@ def send_deploy_request(
     lifecycle_url: Optional[str] = None,
     force: bool = False,
     build_context_method: BuildContextMethod = BuildContextMethod.default,
-    extra_vars: Dict[str, str] = None,
+    extra_vars: Optional[Dict[str, str]] = None,
     build_flags: List[str] = [],
 ):
     """
@@ -90,7 +90,7 @@ def send_deploy_request(
     if client_config is None:
         client_config = load_client_config()
     manifest: Manifest = load_validated_manifest(workdir, extra_vars)
-    manifest_dict: Dict = load_merged_manifest_dict(get_manifest_path(workdir), extra_vars)
+    manifest_dict: Dict = load_merged_manifest_dict(get_manifest_path(workdir), extra_vars or {})
     logger.debug(f'Manifest loaded: {manifest}')
 
     lifecycle_url = resolve_lifecycle_url(client_config, lifecycle_url)
diff --git a/racetrack_client/racetrack_client/client/env.py b/racetrack_client/racetrack_client/client/env.py
index 44f7e8ab2..d9f57d54d 100644
--- a/racetrack_client/racetrack_client/client/env.py
+++ b/racetrack_client/racetrack_client/client/env.py
@@ -37,7 +37,6 @@ def read_secret_vars_from_file(workdir: str, secret_env_file: Optional[str], var
 def read_env_vars_from_file(path: Path) -> Dict[str, str]:
     with wrap_context('reading vars from env file'):
         config = ConfigParser()
-        config.optionxform = str
         config.read_string("[config]\n" + path.read_text())
         env_dict = dict()
         for k, v in config["config"].items():
diff --git a/racetrack_client/racetrack_client/client/logs.py b/racetrack_client/racetrack_client/client/logs.py
index 56e7a6b2e..298d2c5c5 100644
--- a/racetrack_client/racetrack_client/client/logs.py
+++ b/racetrack_client/racetrack_client/client/logs.py
@@ -26,7 +26,7 @@ def show_runtime_logs(name: str, version: str, remote: Optional[str], tail: int,
         headers=get_auth_request_headers(user_auth),
     )
     response = parse_response_object(r, 'Lifecycle response error')
-    exact_version = response.get('version')
+    exact_version = response.get('version', '')
 
     logger.info(f'Retrieving runtime logs of job "{name}" {exact_version} from {lifecycle_url}...')
 
diff --git a/racetrack_client/racetrack_client/client/manage.py b/racetrack_client/racetrack_client/client/manage.py
index 8a2e77214..20907cadf 100644
--- a/racetrack_client/racetrack_client/client/manage.py
+++ b/racetrack_client/racetrack_client/client/manage.py
@@ -130,7 +130,7 @@ def complete_job_name(incomplete: str) -> List[str]:
         headers=get_auth_request_headers(user_auth),
     )
     jobs: List[Dict] = parse_response_list(r, 'Lifecycle response error')
-    job_names = [job.get('name') for job in jobs]
+    job_names = [job.get('name', '') for job in jobs]
 
     for name in job_names:
         if name.startswith(incomplete):
diff --git a/racetrack_client/racetrack_client/client/run.py b/racetrack_client/racetrack_client/client/run.py
index 0efe7f2fa..fa833aa3f 100644
--- a/racetrack_client/racetrack_client/client/run.py
+++ b/racetrack_client/racetrack_client/client/run.py
@@ -27,13 +27,13 @@ def run_job_locally(
     lifecycle_url: str,
     build_context_method: BuildContextMethod = BuildContextMethod.default,
     port: Optional[int] = None,
-    extra_vars: Dict[str, str] = None,
+    extra_vars: Optional[Dict[str, str]] = None,
     build_flags: List[str] = [],
     cmd: Optional[str] = None,
 ):
     client_config = load_client_config()
     manifest: Manifest = load_validated_manifest(workdir, extra_vars)
-    manifest_dict: Dict = load_merged_manifest_dict(get_manifest_path(workdir), extra_vars)
+    manifest_dict: Dict = load_merged_manifest_dict(get_manifest_path(workdir), extra_vars or {})
     lifecycle_url = resolve_lifecycle_url(client_config, lifecycle_url)
     user_auth = get_user_auth(client_config, lifecycle_url)
 
diff --git a/racetrack_client/racetrack_client/log/logs.py b/racetrack_client/racetrack_client/log/logs.py
index 0739f568f..7c61eeb2d 100644
--- a/racetrack_client/racetrack_client/log/logs.py
+++ b/racetrack_client/racetrack_client/log/logs.py
@@ -25,13 +25,19 @@
 
 def configure_logs(log_level: Optional[str] = None):
     """Configure root logger with a log level"""
-    log_level = log_level or os.environ.get('LOG_LEVEL', 'debug')
-    level = _parse_logging_level(log_level)
+    log_level_str: str
+    if log_level:
+        log_level_str = log_level
+    else:
+        log_level_str = os.environ.get('LOG_LEVEL', 'debug')
+    level: int = _parse_logging_level(log_level_str)
     # Set root level to INFO to avoid printing a ton of garbage DEBUG logs from imported libraries
     log_format = LOG_FORMAT_DEBUG if debug_format_enabled else LOG_FORMAT
     logging.basicConfig(stream=sys.stdout, format=log_format, level=logging.INFO, datefmt=LOG_DATE_FORMAT, force=True)
 
     original_formatter = logging.getLogger().handlers[0].formatter
+
+    formatter: logging.Formatter
     if structured_logs_on:
         formatter = StructuredFormatter()
     else:
diff --git a/racetrack_client/racetrack_client/main.py b/racetrack_client/racetrack_client/main.py
index 9cbf34099..28f4ce9db 100644
--- a/racetrack_client/racetrack_client/main.py
+++ b/racetrack_client/racetrack_client/main.py
@@ -223,7 +223,7 @@ def _get_remote(
 
 
 @cli_get.command('pub')
-def _get_remote(
+def _get_pub(
     quiet: bool = typer.Option(False, '--quiet', '-q', help='print only the URL address'),
 ):
     """Get current Racetrack's Pub address"""
diff --git a/racetrack_client/racetrack_client/manifest/manifest.py b/racetrack_client/racetrack_client/manifest/manifest.py
index fa887298a..4439748fa 100644
--- a/racetrack_client/racetrack_client/manifest/manifest.py
+++ b/racetrack_client/racetrack_client/manifest/manifest.py
@@ -103,7 +103,7 @@ def get_deprecated_fields(self):
             'wrapper_properties': 'jobtype_extra:'
         }
 
-    def get_jobtype(self) -> str:
+    def get_jobtype(self) -> Optional[str]:
         return self.jobtype if self.jobtype else self.lang
 
     def get_jobtype_extra(self) -> Dict[str, Any]:
diff --git a/racetrack_client/racetrack_client/manifest/validate.py b/racetrack_client/racetrack_client/manifest/validate.py
index 014c9c116..ca34b8c3a 100644
--- a/racetrack_client/racetrack_client/manifest/validate.py
+++ b/racetrack_client/racetrack_client/manifest/validate.py
@@ -1,6 +1,6 @@
 from pathlib import Path
 import json
-from typing import Dict
+from typing import Dict, Optional
 
 from jsonschema import validate
 
@@ -17,7 +17,7 @@
 
 def load_validated_manifest(
     path: str,
-    extra_vars: Dict[str, str] = None,
+    extra_vars: Optional[Dict[str, str]] = None,
 ) -> Manifest:
     """
     Load and validate manifest from a path. Raise exception in case of a defect.
@@ -51,7 +51,7 @@ def validate_manifest(manifest: Manifest):
 
 def validate_and_show_manifest(
     path: str,
-    extra_vars: Dict[str, str] = None,
+    extra_vars: Optional[Dict[str, str]] = None,
 ):
     manifest = load_validated_manifest(path, extra_vars)
     logger.info(f'Manifest file "{path}" is valid')
diff --git a/racetrack_client/racetrack_client/plugin/bundler/filename_matcher.py b/racetrack_client/racetrack_client/plugin/bundler/filename_matcher.py
index 1f971b23f..ba31e5be8 100644
--- a/racetrack_client/racetrack_client/plugin/bundler/filename_matcher.py
+++ b/racetrack_client/racetrack_client/plugin/bundler/filename_matcher.py
@@ -1,6 +1,6 @@
 from pathlib import Path
 import fnmatch
-from typing import List, Iterable
+from typing import List, Iterable, Optional
 
 
 DEFAULT_IGNORE_PATTERNS = [
     '*.zip',
@@ -28,7 +28,7 @@ class FilenameMatcher:
     -whole_dir/but_this/without_that
     """
 
-    def __init__(self, file_patterns: List[str] = None, apply_defaults: bool = True) -> None:
+    def __init__(self, file_patterns: Optional[List[str]] = None, apply_defaults: bool = True) -> None:
         self.patterns: List[str] = []
 
         if file_patterns:
diff --git a/racetrack_client/racetrack_client/utils/shell.py b/racetrack_client/racetrack_client/utils/shell.py
index 43673f89d..0a52046d7 100644
--- a/racetrack_client/racetrack_client/utils/shell.py
+++ b/racetrack_client/racetrack_client/utils/shell.py
@@ -78,6 +78,9 @@ def shell(
 
     if read_bytes:
         while True:
+            if process.stdout is None:
+                raise CommandError(cmd, 'could not read stdout', process.returncode)
+
             chunk: bytes = process.stdout.read(1)
             if chunk == b'':
                 break
@@ -91,6 +94,9 @@ def shell(
                 captured_stream.write(chunk_str)
 
     else:
+        if process.stdout is None:
+            raise CommandError(cmd, 'could not read stdout', process.returncode)
+
         for line in iter(process.stdout.readline, b''):
             line_str = line.decode()
 
@@ -118,7 +124,7 @@ class CommandOutputStream:
     def __init__(self,
                  cmd: str,
                  on_next_line: Callable[[str], None],
-                 on_error: Callable[['CommandError'], None] = None,
+                 on_error: Optional[Callable[['CommandError'], None]] = None,
                  workdir: Optional[Path] = None,
                  print_stdout: bool = False):
         """
diff --git a/racetrack_commons/racetrack_commons/api/asgi/asgi_server.py b/racetrack_commons/racetrack_commons/api/asgi/asgi_server.py
index a76ee6cc6..bd1c10900 100644
--- a/racetrack_commons/racetrack_commons/api/asgi/asgi_server.py
+++ b/racetrack_commons/racetrack_commons/api/asgi/asgi_server.py
@@ -156,9 +156,9 @@ def format(self, record: logging.LogRecord):
 
 class NeedlessRequestsFilter(logging.Filter):
     def filter(self, record: logging.LogRecord):
-        method: str = record.args[1]
-        uri: str = record.args[2]
-        response_code: int = record.args[4]
+        method: str = record.args[1]  # type: ignore
+        uri: str = record.args[2]  # type: ignore
+        response_code: int = record.args[4]  # type: ignore
         log_line = f'{method} {uri} {response_code}'
         if log_line in HIDDEN_ACCESS_LOGS:
             return False
diff --git a/racetrack_commons/racetrack_commons/api/asgi/fastapi.py b/racetrack_commons/racetrack_commons/api/asgi/fastapi.py
index 795d0782f..8c7dd2d16 100644
--- a/racetrack_commons/racetrack_commons/api/asgi/fastapi.py
+++ b/racetrack_commons/racetrack_commons/api/asgi/fastapi.py
@@ -101,6 +101,6 @@ def custom_openapi():
         fastapi_app.openapi_schema = openapi_schema
         return fastapi_app.openapi_schema
 
-    fastapi_app.openapi = custom_openapi
+    fastapi_app.openapi = custom_openapi  # type: ignore
 
     return fastapi_app
diff --git a/racetrack_commons/racetrack_commons/api/metrics.py b/racetrack_commons/racetrack_commons/api/metrics.py
index 10ff86602..fe0d20239 100644
--- a/racetrack_commons/racetrack_commons/api/metrics.py
+++ b/racetrack_commons/racetrack_commons/api/metrics.py
@@ -29,7 +29,7 @@ def setup_metrics_endpoint(api: FastAPI):
 
     metrics_app = make_wsgi_app(REGISTRY)
-    api.mount('/metrics', WSGIMiddleware(metrics_app))
+    api.mount('/metrics', WSGIMiddleware(metrics_app))  # type: ignore
     TrailingSlashForwarder.mount_path('/metrics')
 
     @api.get('/metrics', tags=['root'])
diff --git a/racetrack_commons/racetrack_commons/api/server_sent_events.py b/racetrack_commons/racetrack_commons/api/server_sent_events.py
index 45444b8a9..cf265b865 100644
--- a/racetrack_commons/racetrack_commons/api/server_sent_events.py
+++ b/racetrack_commons/racetrack_commons/api/server_sent_events.py
@@ -20,7 +20,7 @@ def stream_result_with_heartbeat(result_runner: Callable[[], Dict]):
     """
     Return result dict in SSE (Server-Sent Events) response, streaming heartbeat events to keep the connection alive
     """
-    result_channel = queue.Queue(maxsize=0)
+    result_channel: queue.Queue[str] = queue.Queue(maxsize=0)
 
     def _runner():
         try:
diff --git a/racetrack_commons/racetrack_commons/deploy/job_type.py b/racetrack_commons/racetrack_commons/deploy/job_type.py
index e7336d097..cb7316591 100644
--- a/racetrack_commons/racetrack_commons/deploy/job_type.py
+++ b/racetrack_commons/racetrack_commons/deploy/job_type.py
@@ -2,6 +2,7 @@
 from dataclasses import dataclass
 from enum import Enum
 from pathlib import Path
+from typing_extensions import Any
 
 import backoff
 
@@ -138,12 +139,12 @@ def _parse_job_type_definition(
         )
 
     if isinstance(job_type_value, dict):
-        images_dicts: list[dict] = job_type_value.get('images', [])
+        images_dicts: list[dict[str, Any]] = job_type_value.get('images', [])
         for images_dict in images_dicts:
             source: str = images_dict.get('source', 'jobtype')
             assert source in {'jobtype', 'job'}, "'source' can be either 'jobtype' or 'job'"
             source_enum = ImageSourceLocation(source)
-            dockerfile_path_str: str = images_dict.get('dockerfile_path')
+            dockerfile_path_str: str = images_dict.get('dockerfile_path', '')
             assert dockerfile_path_str, '"dockerfile_path" is not specified in job type data'
             dockerfile_path = Path(dockerfile_path_str)
             if source_enum == ImageSourceLocation.jobtype:
diff --git a/racetrack_commons/racetrack_commons/plugin/loader.py b/racetrack_commons/racetrack_commons/plugin/loader.py
index 21a24613c..a21ea3fb7 100644
--- a/racetrack_commons/racetrack_commons/plugin/loader.py
+++ b/racetrack_commons/racetrack_commons/plugin/loader.py
@@ -132,6 +132,7 @@ def _load_plugin_class(plugin_dir: Path, config_path: Path, plugin_manifest: Plu
     with wrap_context(f'loading plugin class'):
         module_name = f'racetrack_plugin_{random.randint(0, 999999)}'
         spec = importlib.util.spec_from_file_location(module_name, plugin_filename)
+        assert spec is not None, 'no module spec'
        ext_module = importlib.util.module_from_spec(spec)
         loader: Optional[Loader] = spec.loader
         assert loader is not None, 'no module loader'
diff --git a/requirements-dev.txt b/requirements-dev.txt
index bf13aaa31..6c65acf54 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -1,7 +1,10 @@
 # Developer tools
-mypy==0.991
-flake8==6.0.0
-pylint==2.15.8
+mypy==1.16.1
+flake8==7.3.0
+pylint==3.3.7
 black==24.3.0
 twine==6.1.0
 wheel==0.38.4
+types-PyYAML==6.0.12
+types-python-dateutil==2.9.0.20250708
+types-Markdown==3.3.18
\ No newline at end of file