-
Notifications
You must be signed in to change notification settings - Fork 1
refactor: add type annotations and improve type handling #363
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
dc73ec2
3acca2b
dbdfac5
81a77ad
e68d932
b1cb0f4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||
|---|---|---|---|---|---|---|---|---|---|---|
| @@ -1,5 +1,5 @@ | ||||||||||
| from dataclasses import dataclass | ||||||||||
| from typing import Literal | ||||||||||
| from typing import Literal, Self | ||||||||||
|
|
||||||||||
|
Comment on lines
+2
to
3
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🧩 Analysis chainFix Optional typing for fields that default to These fields are annotated as non-Optional but initialized with -from typing import Literal, Self
+from typing import Literal, Self
@@
- infer_type: InferenceType = None
+ infer_type: InferenceType | None = None
@@
- model_conf: ModelConfig = None
+ model_conf: ModelConfig | None = None
@@
- decoder_layers: int = None
- motion_dim: int = None
- mfcc: bool = None
- face_scale: bool = None
- face_location: bool = None
+ decoder_layers: int | None = None
+ motion_dim: int | None = None
+ mfcc: bool | None = None
+ face_scale: bool | None = None
+ face_location: bool | None = NoneAlso, confirm project Python version; Also applies to: 30-33, 50-53, 92-97 Length of output: 9377 🏁 Script executed: #!/bin/bash
# Check Python version requirement in project config
fd -e pyproject.toml -e setup.py -e setup.cfg | head -5 | xargs cat 2>/dev/nullLength of output: 0 🏁 Script executed: #!/bin/bash
# Check if there's a Python version constraint in README or .python-version
fd -e README.md -e .python-version | xargs cat 2>/dev/null | grep -i pythonLength of output: 0 🏁 Script executed: #!/bin/bash
# Search for pyproject.toml or setup.py anywhere in repo
find . -maxdepth 3 -type f \( -name "pyproject.toml" -o -name "setup.py" -o -name "setup.cfg" \) 2>/dev/null | head -10Length of output: 16 🏁 Script executed: #!/bin/bash
# Check common root files for Python version info
ls -la | grep -E "^-.*\.(toml|py|cfg|txt|yml|yaml)$"Length of output: 324 🏁 Script executed: #!/bin/bash
# Attempt to find any Python version constraints via broader search
rg -i "python.*3\.[0-9]+" -B2 -A2 | head -30Length of output: 1510 Import The project requires Python 3.10 ( -from typing import Literal, Self
+from typing import Literal
+from typing_extensions import SelfThen update these field annotations: - infer_type: InferenceType = None
+ infer_type: InferenceType | None = None
- model_conf: ModelConfig = None
+ model_conf: ModelConfig | None = None
- decoder_layers: int = None
- motion_dim: int = None
- mfcc: bool = None
- face_scale: bool = None
- face_location: bool = None
+ decoder_layers: int | None = None
+ motion_dim: int | None = None
+ mfcc: bool | None = None
+ face_scale: bool | None = None
+ face_location: bool | None = NoneAlso applies to any other occurrences in lines 30–33, 50–53, 92–97. 📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||
| from visualizr.anitalker.choices import ( | ||||||||||
| Activation, | ||||||||||
|
|
@@ -24,17 +24,12 @@ | |||||||||
| ModelConfig, | ||||||||||
| ) | ||||||||||
| from visualizr.anitalker.model.latentnet import LatentNetType, MLPSkipNetConfig | ||||||||||
| from visualizr.app.types import InferenceType | ||||||||||
|
|
||||||||||
|
|
||||||||||
| @dataclass | ||||||||||
| class TrainConfig(BaseConfig): | ||||||||||
| infer_type: Literal[ | ||||||||||
| "mfcc_full_control", | ||||||||||
| "mfcc_pose_only", | ||||||||||
| "hubert_pose_only", | ||||||||||
| "hubert_audio_only", | ||||||||||
| "hubert_full_control", | ||||||||||
| ] = None | ||||||||||
| infer_type: InferenceType = None | ||||||||||
| # random seed | ||||||||||
| seed: int = 0 | ||||||||||
| train_mode: TrainMode = TrainMode.diffusion | ||||||||||
|
|
@@ -94,19 +89,22 @@ class TrainConfig(BaseConfig): | |||||||||
| T: int = 1_000 | ||||||||||
| # to be overridden | ||||||||||
| name: str = "" | ||||||||||
| decoder_layers = None | ||||||||||
| motion_dim = None | ||||||||||
| decoder_layers: int = None | ||||||||||
| motion_dim: int = None | ||||||||||
| mfcc: bool = None | ||||||||||
| face_scale: bool = None | ||||||||||
| face_location: bool = None | ||||||||||
|
|
||||||||||
| def __post_init__(self): | ||||||||||
| def __post_init__(self) -> None: | ||||||||||
| self.batch_size_eval = self.batch_size_eval or self.batch_size | ||||||||||
| self.data_val_name = self.data_val_name or self.data_name | ||||||||||
|
|
||||||||||
| def scale_up_gpus(self, num_gpus, num_nodes=1): | ||||||||||
| def scale_up_gpus(self, num_gpus: int, num_nodes: int = 1) -> Self: | ||||||||||
| self.batch_size *= num_gpus * num_nodes | ||||||||||
| self.batch_size_eval *= num_gpus * num_nodes | ||||||||||
| return self | ||||||||||
|
|
||||||||||
| def _make_diffusion_conf(self, t: int): | ||||||||||
| def _make_diffusion_conf(self, t: int) -> SpacedDiffusionBeatGansConfig: | ||||||||||
| if self.diffusion_type != "beatgans": | ||||||||||
| raise NotImplementedError | ||||||||||
| # can use t < `self.t` for evaluation | ||||||||||
|
|
@@ -132,21 +130,21 @@ def _make_diffusion_conf(self, t: int): | |||||||||
| ) | ||||||||||
|
|
||||||||||
| @property | ||||||||||
| def model_out_channels(self): | ||||||||||
| def model_out_channels(self) -> int: | ||||||||||
| return 3 | ||||||||||
|
|
||||||||||
| def make_t_sampler(self) -> UniformSampler: | ||||||||||
| if self.T_sampler != "uniform": | ||||||||||
| raise NotImplementedError | ||||||||||
| return UniformSampler(self.T) | ||||||||||
|
|
||||||||||
| def make_diffusion_conf(self): | ||||||||||
| def make_diffusion_conf(self) -> SpacedDiffusionBeatGansConfig: | ||||||||||
| return self._make_diffusion_conf(self.T) | ||||||||||
|
|
||||||||||
| def make_eval_diffusion_conf(self): | ||||||||||
| def make_eval_diffusion_conf(self) -> SpacedDiffusionBeatGansConfig: | ||||||||||
| return self._make_diffusion_conf(self.T_eval) | ||||||||||
|
|
||||||||||
| def make_model_conf(self): | ||||||||||
| def make_model_conf(self) -> BeatGANsAutoencConfig | BeatGANsUNetConfig: | ||||||||||
| if self.model_name == ModelName.beatgans_ddpm: | ||||||||||
| self.model_type = ModelType.ddpm | ||||||||||
| self.model_conf = BeatGANsUNetConfig( | ||||||||||
|
|
@@ -172,16 +170,9 @@ def make_model_conf(self): | |||||||||
| resnet_two_cond=self.net_beatgans_resnet_two_cond, | ||||||||||
| resnet_use_zero_module=self.net_beatgans_resnet_use_zero_module, | ||||||||||
| ) | ||||||||||
| elif self.model_name in [ | ||||||||||
| ModelName.beatgans_autoenc, | ||||||||||
| ]: | ||||||||||
| elif self.model_name == ModelName.beatgans_autoenc: | ||||||||||
| cls = BeatGANsAutoencConfig | ||||||||||
| # supports both autoenc and vaeddpm | ||||||||||
| if self.model_name == ModelName.beatgans_autoenc: | ||||||||||
| self.model_type = ModelType.autoencoder | ||||||||||
| else: | ||||||||||
| raise NotImplementedError | ||||||||||
|
|
||||||||||
| self.model_type = ModelType.autoencoder | ||||||||||
| if self.net_latent_net_type == LatentNetType.none: | ||||||||||
| latent_net_conf = None | ||||||||||
| elif self.net_latent_net_type == LatentNetType.skip: | ||||||||||
|
|
||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,41 +17,40 @@ class BaseConfig: | |
| and serialize/deserialize configurations to/from JSON. | ||
| """ | ||
|
|
||
| def inherit(self, another): | ||
| def inherit(self, another) -> None: | ||
| """Inherit common keys from a given config.""" | ||
| common_keys = set(self.__dict__.keys()) & set(another.__dict__.keys()) | ||
| common_keys: set[str] = set(self.__dict__.keys()) & set(another.__dict__.keys()) | ||
| for k in common_keys: | ||
| setattr(self, k, getattr(another, k)) | ||
|
|
||
| def propagate(self): | ||
| def propagate(self) -> None: | ||
| """Push down the configuration to all members.""" | ||
| for _, v in self.__dict__.items(): | ||
| if isinstance(v, BaseConfig): | ||
| v.inherit(self) | ||
| v.propagate() | ||
|
|
||
| def from_dict(self, config_dict, strict=False): | ||
| def from_dict(self, config_dict: dict, strict: bool = False) -> None: | ||
| """ | ||
| Populate configuration attributes from a dictionary. | ||
|
|
||
| Optionally, enforcing strict key checking. | ||
| """ | ||
| for k, v in config_dict.items(): | ||
| if not hasattr(self, k): | ||
| _msg: str = f"loading extra '{k}'" | ||
| if strict: | ||
| raise ValueError(f"loading extra '{k}'") | ||
| _msg = f"loading extra '{k}'" | ||
| logger.info(_msg) | ||
| Info(_msg) | ||
| raise ValueError(_msg) | ||
| logger.warning(_msg) | ||
|
Comment on lines
+41
to
+44
|
||
| continue | ||
| if isinstance(self.__dict__[k], BaseConfig): | ||
| self.__dict__[k].from_dict(v) | ||
| else: | ||
| self.__dict__[k] = v | ||
|
|
||
| def as_dict_jsonable(self): | ||
| def as_dict_jsonable(self) -> dict: | ||
| """Convert the configuration to a JSON-serializable dictionary.""" | ||
| conf = {} | ||
| conf: dict = {} | ||
| for k, v in self.__dict__.items(): | ||
| if isinstance(v, BaseConfig): | ||
| conf[k] = v.as_dict_jsonable() | ||
|
|
@@ -61,10 +60,10 @@ def as_dict_jsonable(self): | |
|
|
||
|
|
||
| @lru_cache | ||
| def jsonable(x: Any) -> bool: | ||
| def jsonable(x) -> bool: | ||
| """Check if the object x is JSON serializable.""" | ||
| try: | ||
| dumps(x) | ||
| return True | ||
| except TypeError: | ||
| return False | ||
| return True | ||
|
Comment on lines
65
to
+69
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Import of 'Literal' is not used.