refactor: add type annotations and improve type handling #363
…string to check_missing_paths method
…le modules
- Added type hints to functions and class methods in discriminator.py, encoder.py, styledecoder.py, templates.py, and utils.py for better clarity and type checking.
- Updated the fused_leaky_relu function and its corresponding class to use type annotations for input parameters and return types.
- Enhanced the upfirdn2d_native and upfirdn2d functions with type hints for input parameters and return types.
- Refactored the make_kernel function to include type hints and improved its documentation.
- Modified various classes (e.g., Blur, EqualConv2d, EqualLinear) to include type hints for attributes and method parameters.
- Improved logging messages to include type hints where applicable.
- Ensured consistent use of type hints across all modules for better maintainability and readability.
Review these changes at https://app.gitnotebooks.com/AlphaSphereDotAI/visualizr/pull/363
Reviewer's Guide
This PR uniformly adds static type annotations and local variable typing throughout the codebase, refactors exception handling to standardize error messages, enhances typing imports and TYPE_CHECKING usage, updates Pydantic model field types, and refines config/enumeration declarations, improving type safety and code clarity.
Class diagram for updated TrainConfig and related config classes
classDiagram
class TrainConfig {
+InferenceType infer_type
+int seed
+TrainMode train_mode
+ModelType model_type
+ModelName model_name
+BeatGANsAutoencConfig|BeatGANsUNetConfig model_conf
+int decoder_layers
+int motion_dim
+bool mfcc
+bool face_scale
+bool face_location
+void __post_init__()
+Self scale_up_gpus(int num_gpus, int num_nodes=1)
+SpacedDiffusionBeatGansConfig _make_diffusion_conf(int t)
+int model_out_channels()
+UniformSampler make_t_sampler()
+SpacedDiffusionBeatGansConfig make_diffusion_conf()
+SpacedDiffusionBeatGansConfig make_eval_diffusion_conf()
+BeatGANsAutoencConfig|BeatGANsUNetConfig make_model_conf()
}
class BeatGANsAutoencConfig {
<<inherits BeatGANsUNetConfig>>
+tuple[int]|None use_timesteps
+SpacedDiffusionBeatGans make_sampler()
}
class BeatGANsUNetConfig {
+int model_channels
+int out_channels
+int num_res_blocks
+int|None num_input_res_blocks
+int embed_channels
+tuple[int] attention_resolutions
+int|None time_embed_channels
+tuple[int] channel_mult
+tuple[int]|None input_channel_mult
+int dims
+int num_heads
+int num_heads_upsample
+bool resblock_updown
+bool use_new_attention_order
+bool resnet_two_cond
+int|None resnet_cond_channels
+bool resnet_use_zero_module
+bool attn_checkpoint
+BeatGANsUNetModel make_model()
}
TrainConfig --> BeatGANsAutoencConfig
TrainConfig --> BeatGANsUNetConfig
BeatGANsAutoencConfig --|> BeatGANsUNetConfig
Class diagram for updated Enum types (choices.py)
classDiagram
class TrainMode {
+str manipulate
+str diffusion
}
class ModelType {
+str ddpm
+str autoencoder
+bool has_autoenc()
}
class ModelName {
+str beatgans_ddpm
+str beatgans_autoenc
}
class ModelMeanType {
+str eps
}
class ModelVarType {
+str fixed_small
+str fixed_large
}
class LossType {
+str mse
}
class GenerativeType {
+str ddpm
+str ddim
}
class Activation {
+str none
+str relu
+str lrelu
+str silu
+str tanh
+Identity|ReLU|LeakyReLU|SiLU|Tanh get_act()
}
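As a rough illustration of the enum shape in the diagram above, the sketch below models `ModelType` as a string-valued `Enum` with a `has_autoenc()` helper. The member values and the body of `has_autoenc()` are assumptions made for this example; the diagram only lists member names and the return type.

```python
from enum import Enum


class ModelType(str, Enum):
    """Sketch of a string-valued enum; the member values are assumed."""

    ddpm = "ddpm"
    autoencoder = "autoencoder"

    def has_autoenc(self) -> bool:
        # Assumption: only the autoencoder variant carries an encoder.
        return self is ModelType.autoencoder
```

Mixing in `str` keeps members comparable to plain strings, which may or may not match how the project declares its enums.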
Class diagram for updated LitModel and DiffusionPredictor
classDiagram
class LitModel {
+TrainConfig conf
+DiffusionPredictor model
+DiffusionPredictor ema_model
+SpacedDiffusionBeatGans sampler
+SpacedDiffusionBeatGans eval_sampler
+UniformSampler T_sampler
+Tensor x_T
+void render(...)
+Tensor forward(...)
}
class DiffusionPredictor {
+TrainConfig conf
+Encoder speech_encoder
+Encoder coarse_decoder
+Tensor weights
+Tensor out_proj
+Tensor init_code_proj
+Tensor direction_code_proj
+Tensor pose_predictor
+Tensor pose_encoder
+Encoder create_conformer_encoder(...)
+Tensor forward(...)
+Tensor adjust_features(...)
+Tensor adjust_pose(...)
+Tensor combine_features(...)
+Tensor decode_features(...)
}
LitModel --> DiffusionPredictor
Class diagram for updated ModelSettings (Pydantic)
classDiagram
class ModelSettings {
+int step_t
+int seed
+int motion_dim
+FilePath|None image_path
+FilePath|None audio_path
+bool control_flag
+str pose_driven_path
+int image_size
+int batch_size
+Checkpoint checkpoint
+ModelSettings check_missing_paths()
}
Class diagram for updated LiaModel
classDiagram
class LiaModel {
+int size
+int style_dim
+int motion_dim
+int channel_multiplier
+list|None blur_kernel
+str fusion_type
+void load_lightning_model(...)
}
Note: Other AI code review bot(s) detected
CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.
📝 Walkthrough
Comprehensive type annotation overhaul across the codebase, replacing implicit or missing type hints with explicit function signatures, parameter annotations, and return types. Changes span configuration, diffusion models, neural network layers, and utilities. Includes minor behavioral adjustments: new config fields, updated validation logic, and refined error message handling. No substantial logic changes; primarily enhances static type checking.
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
Pre-merge checks and finishing touches
❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
Hi @MH0386, your PR is in conflict and cannot be merged.
Summary of Changes
Hello @MH0386, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request focuses on a significant refactoring effort to enhance the codebase's type safety and overall quality. By introducing comprehensive type annotations and refining existing type handling, the changes aim to make the code more robust, easier to understand, and less prone to type-related errors. This refactor touches various modules, from configuration and diffusion models to utility functions and application settings, ensuring a consistent and improved developer experience.
Hey there - I've reviewed your changes - here's some feedback:
- There are still several “# TODO: type hint” markers scattered in the code; please fill in those missing annotations or remove the TODOs to keep the typing consistent.
- This PR adopts the Python 3.10 union syntax (e.g. `X | None`); ensure the project’s minimum Python version is bumped accordingly in CI/config or switch to `Optional[...]` for broader compatibility.
- Given the extensive new type annotations, consider adding a mypy (or similar) check to CI to catch mismatches between annotations and actual code usage early on.
Prompt for AI Agents
Please address the comments from this code review:
## Individual Comments
### Comment 1
<location> `src/visualizr/anitalker/model/blocks.py:27` </location>
<code_context>
     @abstractmethod
-    def forward(self, x, emb=None, cond=None, lateral=None):
+    def forward(self, x, emb=None, cond=None, lateral=None) -> None:
         """Apply the module to `x` given `emb` timestep embeddings."""
</code_context>
<issue_to_address>
**issue:** The forward method in TimestepBlock should return a Tensor, not None.
Returning None here could break subclasses and downstream code that expect a tensor output.
</issue_to_address>
### Comment 2
<location> `src/visualizr/anitalker/model/blocks.py:465` </location>
<code_context>
         self.n_heads = n_heads
-    def forward(self, qkv):
+    def forward(self, qkv: Tensor) -> Tensor:
         """
         Apply QKV attention.
</code_context>
<issue_to_address>
**suggestion (bug_risk):** The forward method in QKVAttentionLegacy and QKVAttention should validate input shapes more robustly.
The current validation misses cases where width or self.n_heads are zero or negative, which may cause runtime errors. Please add checks for these conditions.
</issue_to_address>
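One way the suggested checks could look, as a dependency-free sketch. The helper name `channels_per_head` is hypothetical, and the divisibility rule (width split into Q, K, V across heads) is an assumption based on the `[N x (3 x H x C) x T]` shape in the docstring.

```python
def channels_per_head(width: int, n_heads: int) -> int:
    """Validate a packed QKV width and return the per-head channel count."""
    if n_heads <= 0:
        raise ValueError(f"n_heads must be positive, got {n_heads}")
    if width <= 0:
        raise ValueError(f"width must be positive, got {width}")
    if width % (3 * n_heads) != 0:
        raise ValueError(
            f"width {width} is not divisible by 3 * n_heads ({3 * n_heads})"
        )
    # Q, K, and V each take one third of the width, split across heads.
    return width // (3 * n_heads)
```

Raising `ValueError` instead of using `assert` keeps the checks active when Python runs with `-O`.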
### Comment 3
<location> `src/visualizr/anitalker/config_base.py:63` </location>
<code_context>
 @lru_cache
-def jsonable(x: Any) -> bool:
+def jsonable(x) -> bool:
     """Check if the object x is JSON serializable."""
     try:
</code_context>
<issue_to_address>
**issue (bug_risk):** The jsonable function always returns True, even if serialization fails.
Consider updating the function to return False when serialization fails to avoid misleading results about object serializability.
</issue_to_address>
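For reference, a minimal sketch of a `jsonable` check that reports `False` on failure. It omits the `@lru_cache` decorator from the original on purpose: `lru_cache` hashes its arguments, so unhashable inputs such as dicts would raise before the body runs. Catching `ValueError` as well (for example, circular references) is an addition assumed here, not something the PR does.

```python
from json import dumps


def jsonable(x: object) -> bool:
    """Return True if `x` can be serialized with json.dumps, else False."""
    try:
        dumps(x)
    except (TypeError, ValueError):
        # TypeError: unsupported type; ValueError: e.g. circular reference.
        return False
    return True
```

Usage: `jsonable({"a": 1})` is `True`, while `jsonable({1, 2})` is `False` because sets are not JSON serializable.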
### Comment 4
<location> `src/visualizr/anitalker/model/latentnet.py:66` </location>
<code_context>
+        self.time_embed: nn.Sequential = nn.Sequential(*layers)
+        self.layers: nn.ModuleList = nn.ModuleList([])
+
+        act: Activation | None = None
+        norm: bool | None = None
+        cond: bool | None = None
</code_context>
<issue_to_address>
**nitpick:** Redundant initialization of variables that are immediately overwritten.
Consider removing the initializations of act, norm, cond, a, b, and dropout to None, as they are immediately assigned values in the loop.
</issue_to_address>
### Comment 5
<location> `src/visualizr/anitalker/model/unet.py:272` </location>
<code_context>
         )
-    def forward(self, x, t, y=None, **kwargs):
+    def forward(self, x, t, y=None, **kwargs) -> "Return":
         """Apply the model to an input batch."""
         if (y is not None) != (self.conf.num_classes is not None):
</code_context>
<issue_to_address>
**issue:** The return type 'Return' is not defined in this context.
Define or import 'Return', or use a more appropriate type annotation.
</issue_to_address>
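One possible shape for the missing type, sketched without PyTorch. The `Return` fields and `TinyModel` are hypothetical; the point is only that a string annotation such as `"Return"` needs a matching name defined or imported in the module.

```python
from typing import NamedTuple


class Return(NamedTuple):
    """Hypothetical result container; `pred` stands in for a Tensor."""

    pred: list[float]


class TinyModel:
    """Toy model whose forward() returns the named container."""

    def forward(self, x: list[float]) -> "Return":
        # The quoted annotation resolves because Return is defined above.
        return Return(pred=[2 * v for v in x])
```

With the class in scope, type checkers can resolve the forward reference instead of reporting an undefined name.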
### Comment 6
<location> `src/visualizr/anitalker/model/unet.py:85` </location>
<code_context>
def __init__(self, conf: BeatGANsUNetConfig) -> None:
    super().__init__()
    self.conf: BeatGANsUNetConfig = conf
    if conf.num_heads_upsample == -1:
        self.num_heads_upsample = conf.num_heads
    self.dtype: th.dtype = th.float32
    self.time_emb_channels: int = conf.time_embed_channels or conf.model_channels
    self.time_embed = nn.Sequential(
        nn.Linear(self.time_emb_channels, conf.embed_channels),
        nn.SiLU(),
        nn.Linear(conf.embed_channels, conf.embed_channels),
    )
    ch = input_ch = int(conf.channel_mult[0] * conf.model_channels)
    self.input_blocks = nn.ModuleList(
        [
            TimestepEmbedSequential(
                conv_nd(conf.dims, conf.in_channels, ch, 3, padding=1),
            ),
        ],
    )
    kwargs: dict = {
        "use_condition": True,
        "two_cond": conf.resnet_two_cond,
        "use_zero_module": conf.resnet_use_zero_module,
        "cond_emb_channels": conf.resnet_cond_channels,
    }
    input_block_chans: list[list[int]] = [
        [] for _ in range(len(conf.channel_mult))
    ]  # TODO
    input_block_chans[0].append(ch)
    # number of blocks at each resolution
    self.input_num_blocks: list[int] = [0 for _ in range(len(conf.channel_mult))]
    self.input_num_blocks[0] = 1
    self.output_num_blocks: list[int] = [0 for _ in range(len(conf.channel_mult))]
    resolution: int = conf.image_size
    for level, mult in enumerate(conf.input_channel_mult or conf.channel_mult):
        for _ in range(conf.num_input_res_blocks or conf.num_res_blocks):
            layers = [
                ResBlockConfig(
                    ch,
                    conf.embed_channels,
                    conf.dropout,
                    out_channels=int(mult * conf.model_channels),
                    dims=conf.dims,
                    use_checkpoint=conf.use_checkpoint,
                    **kwargs,
                ).make_model(),
            ]
            ch = int(mult * conf.model_channels)
            if resolution in conf.attention_resolutions:
                layers.append(
                    AttentionBlock(
                        ch,
                        use_checkpoint=conf.use_checkpoint or conf.attn_checkpoint,
                        num_heads=conf.num_heads,
                        num_head_channels=conf.num_head_channels,
                        use_new_attention_order=conf.use_new_attention_order,
                    ),
                )
            self.input_blocks.append(TimestepEmbedSequential(*layers))
            input_block_chans[level].append(ch)
            self.input_num_blocks[level] += 1
        if level != len(conf.channel_mult) - 1:
            resolution //= 2
            out_ch = ch
            self.input_blocks.append(
                TimestepEmbedSequential(
                    ResBlockConfig(
                        ch,
                        conf.embed_channels,
                        conf.dropout,
                        out_channels=out_ch,
                        dims=conf.dims,
                        use_checkpoint=conf.use_checkpoint,
                        down=True,
                        **kwargs,
                    ).make_model()
                    if conf.resblock_updown
                    else Downsample(
                        ch,
                        conf.conv_resample,
                        conf.dims,
                        out_ch,
                    ),
                ),
            )
            ch: int = out_ch
            input_block_chans[level + 1].append(ch)
            self.input_num_blocks[level + 1] += 1
    self.middle_block = TimestepEmbedSequential(
        ResBlockConfig(
            ch,
            conf.embed_channels,
            conf.dropout,
            dims=conf.dims,
            use_checkpoint=conf.use_checkpoint,
            **kwargs,
        ).make_model(),
        AttentionBlock(
            ch,
            use_checkpoint=conf.use_checkpoint or conf.attn_checkpoint,
            num_heads=conf.num_heads,
            num_head_channels=conf.num_head_channels,
            use_new_attention_order=conf.use_new_attention_order,
        ),
        ResBlockConfig(
            ch,
            conf.embed_channels,
            conf.dropout,
            dims=conf.dims,
            use_checkpoint=conf.use_checkpoint,
            **kwargs,
        ).make_model(),
    )
    self.output_blocks = nn.ModuleList([])
    for level, mult in list(enumerate(conf.channel_mult))[::-1]:
        for i in range(conf.num_res_blocks + 1):
            try:
                ich = input_block_chans[level].pop()
            except IndexError:
                # this happens only when `num_res_block > num_enc_res_block`
                # we will not have enough lateral (skip) connections for all decoder blocks.
                ich = 0
            # only direct channels when gated
            # lateral channels are described here when gated.
            layers = [
                ResBlockConfig(
                    channels=ch + ich,
                    emb_channels=conf.embed_channels,
                    dropout=conf.dropout,
                    out_channels=int(conf.model_channels * mult),
                    dims=conf.dims,
                    use_checkpoint=conf.use_checkpoint,
                    has_lateral=ich > 0,
                    **kwargs,
                ).make_model(),
            ]
            ch = int(conf.model_channels * mult)
            if resolution in conf.attention_resolutions:
                layers.append(
                    AttentionBlock(
                        ch,
                        use_checkpoint=conf.use_checkpoint or conf.attn_checkpoint,
                        num_heads=self.num_heads_upsample,
                        num_head_channels=conf.num_head_channels,
                        use_new_attention_order=conf.use_new_attention_order,
                    ),
                )
            if level and i == conf.num_res_blocks:
                resolution *= 2
                out_ch = ch
                layers.append(
                    ResBlockConfig(
                        ch,
                        conf.embed_channels,
                        conf.dropout,
                        out_channels=out_ch,
                        dims=conf.dims,
                        use_checkpoint=conf.use_checkpoint,
                        up=True,
                        **kwargs,
                    ).make_model()
                    if conf.resblock_updown
                    else Upsample(
                        ch,
                        conf.conv_resample,
                        dims=conf.dims,
                        out_channels=out_ch,
                    ),
                )
            self.output_blocks.append(TimestepEmbedSequential(*layers))
            self.output_num_blocks[level] += 1
    if conf.resnet_use_zero_module:
        self.out = nn.Sequential(
            normalization(ch),
            nn.SiLU(),
            zero_module(
                conv_nd(conf.dims, input_ch, conf.out_channels, 3, padding=1),
            ),
        )
    else:
        self.out = nn.Sequential(
            normalization(ch),
            nn.SiLU(),
            conv_nd(conf.dims, input_ch, conf.out_channels, 3, padding=1),
        )
</code_context>
<issue_to_address>
**issue (code-quality):** Low code quality found in BeatGANsUNetModel.\_\_init\_\_ - 11% ([`low-code-quality`](https://docs.sourcery.ai/Reference/Default-Rules/comments/low-code-quality/))
<br/><details><summary>Explanation</summary>The quality score for this function is below the quality threshold of 25%.
This score is a combination of the method length, cognitive complexity and working memory.
How can you solve this?
It might be worth refactoring this function to make it shorter and more readable.
- Reduce the function length by extracting pieces of functionality out into
their own functions. This is the most important thing you can do - ideally a
function should be less than 10 lines.
- Reduce nesting, perhaps by introducing guard clauses to return early.
- Ensure that variables are tightly scoped, so that code using related concepts
sits together within the function rather than being scattered.</details>
</issue_to_address>
### Comment 7
<location> `src/visualizr/app/builder.py:350` </location>
<code_context>
def generate_video_from_name(
    self,
    name: str,
    infer_type: InferenceType,
    audio_file: str | Path,
    pose_yaw: float,
    pose_pitch: float,
    pose_roll: float,
    face_location: float,
    face_scale: float,
    step_t: int,
    seed: int,
    face_sr: bool,
) -> Path:
    """
    Generate a video for a character by name using the provided settings and audio.
    Args:
        name (str): The base name of the character image (without extension).
        infer_type (InferenceType): The type of inference mode.
        audio_file (str | Path): Url or Path to the input audio file.
        face_sr (bool): Whether to apply a face super-resolution.
        pose_yaw (float): Yaw angle for the character's pose.
        pose_pitch (float): Pitch angle for the character's pose.
        pose_roll (float): Roll angle for the character's pose.
        face_location (float): Relative location parameter for a face positioning.
        face_scale (float): Scaling factor for the face.
        step_t (int): Number of diffusion steps.
        seed (int): Random seed for reproducibility.
    Returns:
        Path: A path to the generated video file.
    """
    if audio_file is None:
        _msg = "Audio path is required."
        logger.error(_msg)
        raise Error(_msg)
    if name not in self._get_character_names():
        _msg = f"Character '{name}' not found."
        logger.error(_msg)
        raise Error(_msg)
    if isinstance(audio_file, str) and audio_file.startswith(
        (
            "http://",
            "https://",
        ),
    ):
        audio_file = self._download_audio(URL(audio_file))
    if not isinstance(audio_file, Path):
        audio_file = Path(audio_file)
    if not audio_file.exists():
        _msg = f"Audio path '{audio_file}' does not exist or is invalid."
        logger.error(_msg)
        raise Error(_msg)
    return self.generate_video(
        infer_type,
        self._get_image_path(name),
        audio_file.as_posix(),
        pose_yaw,
        pose_pitch,
        pose_roll,
        face_location,
        face_scale,
        step_t,
        seed,
        face_sr,
    )
</code_context>
<issue_to_address>
**issue (code-quality):** Extract duplicate code into method ([`extract-duplicate-method`](https://docs.sourcery.ai/Reference/Default-Rules/refactorings/extract-duplicate-method/))
</issue_to_address>
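The repeated "build message, log, raise" pattern flagged here could be folded into one helper, roughly as below. `log_and_raise` and `require_known_character` are hypothetical names, and the local `Error` class is only a stand-in for whatever `Error` the module actually imports.

```python
import logging
from typing import NoReturn

logger = logging.getLogger(__name__)


class Error(Exception):
    """Hypothetical stand-in for the Error type used in builder.py."""


def log_and_raise(msg: str) -> NoReturn:
    """Log `msg` at error level, then raise it as an Error."""
    logger.error(msg)
    raise Error(msg)


def require_known_character(name: str, known: list[str]) -> None:
    """Example call site: one line replaces the three-line pattern."""
    if name not in known:
        log_and_raise(f"Character '{name}' not found.")
```

The `NoReturn` annotation tells type checkers that code after the call is unreachable, which keeps narrowing such as the `audio_file is None` check working.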
Pull Request Overview
This PR adds comprehensive type hints throughout the codebase to improve type safety and code maintainability. The changes include adding return type annotations, parameter type hints, and proper typing for class attributes.
Key Changes
- Added type hints to function signatures (parameters and return types)
- Introduced proper type annotations for class attributes
- Centralized `InferenceType` definition in `src/visualizr/app/types.py`
- Improved union type syntax using `|` operator
- Enhanced error message handling with proper typing
Reviewed Changes
Copilot reviewed 24 out of 24 changed files in this pull request and generated 12 comments.
| File | Description |
|---|---|
| src/visualizr/app/settings.py | Updated field types to use union syntax and renamed/documented validator method |
| src/visualizr/app/builder.py | Added type validation for Path objects and improved code formatting |
| src/visualizr/anitalker/utils.py | Added comprehensive type hints and reorganized imports with TYPE_CHECKING |
| src/visualizr/anitalker/templates.py | Added return type annotations for configuration functions |
| src/visualizr/anitalker/renderer.py | Fixed NotImplementedError usage |
| src/visualizr/anitalker/networks/styledecoder.py | Added extensive type hints to neural network classes |
| src/visualizr/anitalker/networks/encoder.py | Added type hints and improved formatting |
| src/visualizr/anitalker/networks/discriminator.py | Added type hints and removed extra import |
| src/visualizr/anitalker/model/unet_autoenc.py | Added type hints for model classes and methods |
| src/visualizr/anitalker/model/unet.py | Added type hints and fixed comment formatting |
| src/visualizr/anitalker/model/seq2seq.py | Added type hints for LSTM and model classes |
| src/visualizr/anitalker/model/nn.py | Added type hints for utility functions |
| src/visualizr/anitalker/model/latentnet.py | Added comprehensive type hints to latent network classes |
| src/visualizr/anitalker/model/blocks.py | Added type hints and improved error handling |
| src/visualizr/anitalker/liamodel.py | Added type hints and fixed string concatenation |
| src/visualizr/anitalker/face_sr/videoio.py | Changed while loop condition to use True instead of 1 |
| src/visualizr/anitalker/face_sr/face_enhancer.py | Added type hints for function signatures |
| src/visualizr/anitalker/experiment.py | Replaced copy with deepcopy and added type hints |
| src/visualizr/anitalker/diffusion/resample.py | Added type hints to sampler classes |
| src/visualizr/anitalker/diffusion/diffusion.py | Added comprehensive type hints and improved error messages |
| src/visualizr/anitalker/diffusion/base.py | Added type hints throughout diffusion base classes |
| src/visualizr/anitalker/config_base.py | Added type hints and changed logging level for extra keys |
| src/visualizr/anitalker/config.py | Centralized InferenceType and added type hints |
| src/visualizr/anitalker/choices.py | Added type annotations to Enum values |
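Several rows above mention imports reorganized under `TYPE_CHECKING`. A minimal sketch of that pattern follows; the `describe` function is hypothetical, and `Path` is used only as a convenient example type.

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Imported only while type checking; skipped at runtime.
    from pathlib import Path


def describe(path: "Path") -> str:
    """Format a path-like value; the annotation is a string forward reference."""
    return f"path: {path}"
```

Because `TYPE_CHECKING` is `False` at runtime, the guarded import costs nothing when the program runs and can help avoid circular imports, at the price of quoting (or deferring) the annotations that use it.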
Comments suppressed due to low confidence (2)
src/visualizr/anitalker/diffusion/base.py:667
- Missing space before inline comment. Should be `# TODO: np` instead of `#TODO: np`.
yield out
src/visualizr/anitalker/networks/discriminator.py:3
- Import of 'T' is not used.
from regex import T
@@ -1,41 +1,52 @@
 import math

 from regex import T
Copilot AI, Oct 28, 2025
Unused import `T` from the `regex` module. This appears to be an accidental import that should be removed.
Suggested change:
- from regex import T
 size: int = 256,
 style_dim: int = 512,
 motion_dim: int = 20,
 channel_multiplier: int = 1,
 blur_kernel: list | None = None,
 fusion_type: str = "",
Copilot AI, Oct 28, 2025
Missing spaces around `=` in default parameter assignments. Should be `size: int = 256` instead of `size: int=256` for consistency with PEP 8.
  f"Wrong parameter length: {orig_name}, "
- + f"model: {self_state[name].size()}, "
+ f"model: {self_state[name].size()}, "
  + f"loaded: {state[orig_name].size()}"
Copilot AI, Oct 28, 2025
Inconsistent string concatenation: line 49 uses implicit concatenation while line 50 uses the explicit `+` operator. Should use a consistent style (preferably implicit concatenation for both lines).
Suggested change:
- + f"loaded: {state[orig_name].size()}"
+ f"loaded: {state[orig_name].size()}"
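For clarity, this is what the consistently implicit style might look like in isolation. The function name and parameters are hypothetical; only the message wording follows the snippet above.

```python
def size_mismatch_message(name: str, model_size: object, loaded_size: object) -> str:
    """Build the mismatch message using implicit concatenation throughout."""
    return (
        f"Wrong parameter length: {name}, "
        f"model: {model_size}, "
        f"loaded: {loaded_size}"
    )
```

Adjacent string literals inside parentheses are joined at compile time, so no `+` operator is needed on any line.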
  dec_cond_emb: Tensor | None = cond_emb

- hs = [[] for _ in range(len(self.conf.channel_mult))]
+ hs: list[list[Tensor]] = [
Copilot AI, Oct 28, 2025
Missing space before inline comment. `#TODO:` should be `# TODO:` to follow PEP 8 style guidelines.
 :param qkv: An `[N x (3 x H x C) x T]` tensor of Qs, Ks, and Vs.
 :return: An `[N x (H x C) x T]` tensor after attention.
Copilot AI, Oct 28, 2025
[nitpick] Changed multiplication symbol from `×` to `x` in documentation. While both are readable, `×` (U+00D7) is the proper mathematical multiplication symbol and was more semantically correct in the original.
Suggested change:
- :param qkv: An `[N x (3 x H x C) x T]` tensor of Qs, Ks, and Vs.
- :return: An `[N x (H x C) x T]` tensor after attention.
+ :param qkv: An `[N × (3 × H × C) × T]` tensor of Qs, Ks, and Vs.
+ :return: An `[N × (H × C) × T]` tensor after attention.
  shape=None,
  noise=None,
- clip_denoised=True,
+ clip_denoised: bool = True,
Copilot AI, Oct 28, 2025
Missing space before `=` in default parameter. Should be `device: th.device | None = None`.
  try:
      dumps(x)
-     return True
  except TypeError:
      return False
+ return True
Copilot AI, Oct 28, 2025
Unreachable code due to control flow change. The `return True` at line 69 will never execute if `dumps(x)` succeeds because the function would have already returned in the original code. The correct implementation should have `return True` inside the `try` block before any exception handling.
+ _msg: str = f"loading extra '{k}'"
  if strict:
-     raise ValueError(f"loading extra '{k}'")
- _msg = f"loading extra '{k}'"
- logger.info(_msg)
- Info(_msg)
+     raise ValueError(_msg)
+ logger.warning(_msg)
Copilot AI, Oct 28, 2025
Changed from `logger.info()` to `logger.warning()` but the logic is incorrect. When `strict=False`, the warning is logged, but when `strict=True`, an exception is raised and the warning is never logged. The original code using `Info(_msg)` after logging makes more sense for the non-strict case.
| self.ema_model = copy.deepcopy(self.model) | ||
| self.ema_model.requires_grad_(False) | ||
| self.ema_model: DiffusionPredictor = deepcopy(self.model) | ||
| self.ema_model.requires_grad_(requires_grad=False) |
Copilot
AI
Oct 28, 2025
[nitpick] Using named parameter requires_grad=False is redundant. The requires_grad_() method takes the boolean value as the first positional argument, so this should be self.ema_model.requires_grad_(False) for cleaner code.
| self.ema_model.requires_grad_(requires_grad=False) | |
| self.ema_model.requires_grad_(False) |
| @@ -1,5 +1,5 @@ | |||
| from dataclasses import dataclass | |||
| from typing import Literal | |||
| from typing import Literal, Self | |||
Copilot
AI
Oct 28, 2025
Import of 'Literal' is not used.
| from typing import Literal, Self | |
| from typing import Self |
Code Review
This pull request does a great job of adding type annotations and refactoring the code for better clarity and type safety, which significantly improves maintainability. The changes are extensive and well-executed. I've provided a few suggestions to fix some incorrect or missing type hints to further align with the PR's goals. Additionally, I've identified a significant amount of duplicated code across several files in the networks directory and recommended refactoring it into a shared utility module. Addressing this will greatly improve the codebase's maintainability.
Actionable comments posted: 14
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (15)
src/visualizr/app/settings.py (1)
110-132: Redundant validation logic: Pydantic already validates file existence.
With FilePath | None field types, Pydantic automatically validates file existence during field validation (before this mode="after" validator runs). If a non-None path doesn't exist, Pydantic raises a validation error before reaching this method. The manual .exists() checks here are redundant and will never trigger.
Additionally, the Error(_msg) calls on lines 125 and 130 create Gradio Error objects but don't raise or use them, making them ineffective.
Consider either:
- Remove the redundant validator entirely (preferred if Pydantic's default validation messages suffice):
    -    @model_validator(mode="after")
    -    def check_missing_paths(self) -> "ModelSettings":
    -        """
    -        Validate that the image and audio paths exist if provided.
    -
    -        Checks the existence of the image_path and audio_path attributes.
    -
    -        Returns:
    -            ModelSettings: The validated ModelSettings instance.
    -
    -        Raises:
    -            FileNotFoundError: If the image_path or audio_path does not exist.
    -        """
    -        if self.image_path and not self.image_path.exists():
    -            _msg = f"Image path does not exist: {self.image_path}"
    -            logger.error(_msg)
    -            Error(_msg)
    -            raise FileNotFoundError(_msg)
    -        if self.audio_path and not self.audio_path.exists():
    -            _msg = f"Audio path does not exist: {self.audio_path}"
    -            logger.error(_msg)
    -            Error(_msg)
    -            raise FileNotFoundError(_msg)
    -        return self
- Or, if custom error handling is needed, use a mode="before" validator to provide custom error messages before Pydantic's field validation:
    @model_validator(mode="before")
    @classmethod
    def check_missing_paths(cls, values: dict) -> dict:
        """Validate paths with custom error messages."""
        for field_name in ["image_path", "audio_path"]:
            path = values.get(field_name)
            if path is not None:
                path_obj = Path(path)
                if not path_obj.exists():
                    _msg = f"{field_name.replace('_', ' ').title()} does not exist: {path}"
                    logger.error(_msg)
                    raise ValueError(_msg)
        return values
src/visualizr/anitalker/model/latentnet.py (1)
164-179: Add parameter and return type annotations for consistency.
The forward method lacks type annotations for its parameters and return value, which is inconsistent with the comprehensive type annotations added throughout the PR.
Apply this diff to add complete type annotations:
    -    def forward(self, x, cond=None):
    +    def forward(self, x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor:
src/visualizr/anitalker/liamodel.py (1)
41-52: Do not instantiate gradio.Error without raising; add safe state_dict handling and no_grad.
- Instantiating Error(...) without raise has no effect; errors are silently ignored.
- torch.load Lightning checkpoints often wrap weights under "state_dict".
- Use torch.no_grad() when in-place copying parameters to avoid autograd side-effects.
Suggested fix:
    -            if name not in self_state:
    -                Error(f"{orig_name} is not in the model.")
    -                # You can ignore those errors as some parameters are only used for training.
    -                continue
    +            if name not in self_state:
    +                # Some params are only used for training; log and skip.
    +                from visualizr.app.logger import logger
    +                logger.warning("Skipping unknown parameter: %s", orig_name)
    +                continue

    -            if self_state[name].size() != state[orig_name].size():
    -                Error(
    -                    f"Wrong parameter length: {orig_name}, "
    -                    f"model: {self_state[name].size()}, "
    -                    + f"loaded: {state[orig_name].size()}"
    -                )
    -                continue
    +            if self_state[name].size() != state[orig_name].size():
    +                from visualizr.app.logger import logger
    +                logger.warning(
    +                    "Skipping size-mismatched param %s (model=%s, loaded=%s)",
    +                    orig_name, self_state[name].size(), state[orig_name].size()
    +                )
    +                continue

    -            self_state[name].copy_(param)
    +            with torch.no_grad():
    +                self_state[name].copy_(param)
Additionally, at Line 37, consider:
    -        state = load(lia_pretrained_model_path, map_location="cpu")
    +        state = load(lia_pretrained_model_path, map_location="cpu")
    +        if isinstance(state, dict) and "state_dict" in state:
    +            state = state["state_dict"]
src/visualizr/anitalker/networks/styledecoder.py (1)
298-303: Crash when bias=False: self.bias * self.lr_mul on None.
Guard bias usage in both activation and non-activation paths.
    -        if self.activation:
    -            out = linear(_input, self.weight * self.scale)
    -            out = fused_leaky_relu(out, self.bias * self.lr_mul)
    -        else:
    -            out = linear(_input, self.weight * self.scale, self.bias * self.lr_mul)
    +        bias_term = None if self.bias is None else self.bias * self.lr_mul
    +        if self.activation:
    +            out = linear(_input, self.weight * self.scale)
    +            out = (
    +                fused_leaky_relu(out, bias_term)
    +                if bias_term is not None
    +                else leaky_relu(out)
    +            )
    +        else:
    +            out = linear(_input, self.weight * self.scale, bias=bias_term)
             return out
src/visualizr/anitalker/renderer.py (1)
48-55: Guard against None model_type to avoid AttributeError.
TrainConfig.model_type is annotated as ModelType | None. Calling .has_autoenc() when it is None will crash.
    -    if not conf.model_type.has_autoenc():
    +    if not (conf.model_type and conf.model_type.has_autoenc()):
             msg: str = (
                 "TrainMode.diffusion requires an "
                 "autoencoder-capable `model_type`; "
                 f"got {conf.model_type!r}"
             )
             raise ValueError(msg)
src/visualizr/anitalker/model/unet.py (1)
461-466: Fix AttributeError: use conf.model_channels
self.model_channels is undefined; should reference config.
    -        if self.conf.use_time_condition:
    -            emb = self.time_embed(timestep_embedding(t, self.model_channels))
    +        if self.conf.use_time_condition:
    +            emb = self.time_embed(timestep_embedding(t, self.conf.model_channels))
src/visualizr/anitalker/model/seq2seq.py (1)
229-252: Fix shape mismatch in concatenate: repeat noisy_feature across time
noisy_feature is (B, 128); others are (B, T, 128). cat will fail.
    -        noisy_feature = self.noisy_encoder(noisy_x)
    +        noisy_feature = (
    +            self.noisy_encoder(noisy_x)
    +            .unsqueeze(1)
    +            .repeat(1, x.size(1), 1)
    +        )
src/visualizr/anitalker/model/unet_autoenc.py (2)
38-58: enc_pool default conflicts with encoder support
BeatGANsEncoderModel currently only handles pool=="adaptivenonzero" and raises otherwise. BeatGANsAutoencConfig defaults enc_pool="depthconv", which will raise at init.
    -    enc_pool: str = "depthconv"
    +    enc_pool: str = "adaptivenonzero"
If "depthconv" is required, extend BeatGANsEncoderModel to support it instead.
160-165: Guard x=None path (h becomes None, then used)
When x is None, h is set to None but later passed through blocks and to self.out, causing a failure. Provide a safe fallback or prohibit x=None.
    -        else:
    -            # no lateral connections
    -            # happen when training only the autoencoder
    -            h = None
    -            hs = [[] for _ in range(len(self.conf.channel_mult))]
    +        else:
    +            msg = "`x` must be provided for autoencoder forward"
    +            raise ValueError(msg)
Alternatively, initialize h from an encoded prior if that’s the intended flow.
src/visualizr/anitalker/networks/encoder.py (3)
288-325: Blocker: fusion_type default + bool/str mismatch raises on construction and breaks Encoder
- EncoderApp expects str but Encoder passes a bool. With default "", the check always raises ValueError unless explicitly "weighted_sum". This makes Encoder unusable with default weighted_sum=False.
Apply this minimal fix to accept both bool/str and avoid hard failure:
    -    def __init__(self, size, w_dim: int = 512, fusion_type: str = "") -> None:
    +    def __init__(self, size, w_dim: int = 512, fusion_type: bool | str | None = None) -> None:
    @@
    -        self.fusion_type: str = fusion_type
    -
    -        if self.fusion_type != "weighted_sum":
    -            msg: str = (
    -                f"Unsupported `fusion_type`: {self.fusion_type}. "
    -                "Expected 'weighted_sum'."
    -            )
    -            raise ValueError(msg)
    +        # Normalize config: enable only when True or explicit string.
    +        self.fusion_type: str | None = (
    +            "weighted_sum" if (fusion_type is True or fusion_type == "weighted_sum") else None
    +        )
332-333: Blocker: WeightedSumLayer length mismatch silently drops features
pooled_h_lists length equals len(self.convs) (=log_size). WeightedSumLayer defaults to 8, so zip(strict=False) ignores extra tensors; gradients for ignored layers are lost.
    -        self.ws = WeightedSumLayer()
    +        self.ws = WeightedSumLayer(num_tensors=len(self.convs))
Optionally, validate at runtime:
    +        # Sanity check to prevent silent truncation
    +        assert len(self.convs) == self.ws.weights.numel(), "ws size must match conv count"
Also applies to: 339-353
3-3: Remove Gradio UI side-effects from core network module
Hard dependency on gradio.Info in model code couples UI and backend, breaks headless usage, and adds import overhead.
    -from gradio import Info
    @@
    -        _msg = "HAL layer is enabled!"
    -        logger.info(_msg)
    -        Info(_msg)
    +        logger.info("HAL layer is enabled!")
If user feedback is needed, emit via a callback passed from the UI layer or log only.
Also applies to: 325-327
src/visualizr/anitalker/diffusion/base.py (2)
139-144: Always enrich model_kwargs; current logic skips when a dict is provided
x_start/cond are only set when model_kwargs is None. If a caller passes {}, they’re missed.
    -        if model_kwargs is None:
    -            model_kwargs: dict = {}
    -            if self.conf.model_type.has_autoenc():
    -                model_kwargs["x_start"] = x_start
    -                model_kwargs["cond"] = cond
    +        model_kwargs = dict(model_kwargs or {})
    +        if self.conf.model_type.has_autoenc():
    +            model_kwargs.setdefault("x_start", x_start)
    +            model_kwargs.setdefault("cond", cond)
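The normalization pattern in this suggestion can be tried in isolation. The sketch below is an assumption-laden illustration: enrich_kwargs and the has_autoenc flag are hypothetical stand-ins for the method body and for conf.model_type.has_autoenc().

```python
def enrich_kwargs(model_kwargs, x_start, cond, has_autoenc=True):
    """Return a new kwargs dict that always carries x_start and cond.

    Hypothetical helper; `has_autoenc` stands in for the config check.
    """
    # Copy so the caller's dict is never mutated; None becomes {}.
    merged = dict(model_kwargs or {})
    if has_autoenc:
        # setdefault keeps any value the caller already supplied.
        merged.setdefault("x_start", x_start)
        merged.setdefault("cond", cond)
    return merged
```

Using setdefault rather than plain assignment is a design choice: it lets an explicit caller value win over the default.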
344-356: Fix crash when cond_fn is used with model_kwargs=None
Passing **model_kwargs with None raises. Normalize to {}.
    -    def condition_mean(
    -        self, cond_fn, p_mean_var, x, t: th.Tensor, model_kwargs: dict | None = None
    -    ):
    +    def condition_mean(
    +        self, cond_fn, p_mean_var, x, t: th.Tensor, model_kwargs: dict | None = None
    +    ):
    @@
    -        gradient = cond_fn(x, self._scale_timesteps(t), **model_kwargs)
    +        model_kwargs = model_kwargs or {}
    +        gradient = cond_fn(x, self._scale_timesteps(t), **model_kwargs)
             return p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
    @@
    -    def condition_score(
    -        self, cond_fn, p_mean_var, x, t: th.Tensor, model_kwargs: dict | None = None
    -    ):
    +    def condition_score(
    +        self, cond_fn, p_mean_var, x, t: th.Tensor, model_kwargs: dict | None = None
    +    ):
    @@
    -        eps = eps - (1 - alpha_bar).sqrt() * cond_fn(
    +        model_kwargs = model_kwargs or {}
    +        eps = eps - (1 - alpha_bar).sqrt() * cond_fn(
                 x, self._scale_timesteps(t),
    -            **model_kwargs,
    +            **model_kwargs,
             )
Also applies to: 358-370
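The failure mode is a plain Python one: unpacking None with ** raises TypeError before the callee runs. A small sketch, with call_cond_fn and demo_cond_fn as hypothetical names and integers standing in for tensors:

```python
def call_cond_fn(cond_fn, x, t, model_kwargs=None):
    """Call `cond_fn`, tolerating model_kwargs=None (hypothetical wrapper)."""
    # `**None` is a TypeError, so fall back to an empty mapping first.
    model_kwargs = model_kwargs or {}
    return cond_fn(x, t, **model_kwargs)


def demo_cond_fn(x, t, scale=1):
    """Toy conditioning function used only for this illustration."""
    return (x + t) * scale
```

With this guard, call_cond_fn(demo_cond_fn, 1, 2) works when no kwargs are supplied, while the unguarded form demo_cond_fn(1, 2, **None) raises TypeError.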
src/visualizr/anitalker/face_sr/face_enhancer.py (1)
171-176: Remove the invalid generic type parameter from GFPGANer.
The annotation GFPGANer[str] is incorrect. GFPGAN (the official PyPI package) does not ship a py.typed marker or bundled type hints, which means GFPGANer is not defined as a generic class. Remove the type parameter:
    -    restorer: GFPGANer[str] = GFPGANer(
    +    restorer: GFPGANer = GFPGANer(
♻️ Duplicate comments (1)
src/visualizr/anitalker/model/blocks.py (1)
27-27: TimestepBlock.forward should return Tensor, not None.
This base type signature conflicts with subclasses and prior reviews flagged it.
    -    def forward(self, x, emb=None, cond=None, lateral=None) -> None:
    +    def forward(self, x, emb=None, cond=None, lateral=None) -> Tensor:
🧹 Nitpick comments (25)
src/visualizr/app/builder.py (2)
87-112: Refactor: Move Path conversions before existence checks to avoid redundant object creation.
The Path(image_path) object is created at line 87 for the existence check, then image_path is converted to Path again at lines 111-112. This creates the Path object twice.
Consider moving the conversion to the beginning of the method:
         ) -> Path:
    +        # Convert to Path early to avoid redundant conversions
    +        if not isinstance(image_path, Path):
    +            image_path = Path(image_path)
    +        if not isinstance(audio_path, Path):
    +            audio_path = Path(audio_path)
    +
             if image_path is None or not Path(image_path).exists():
Then remove the later conversions at lines 111-112 and 281-282.
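One way to centralize the early conversion is a small normalizing helper. This is a sketch under assumptions: to_path is a hypothetical name, and handling None before the conversion is an addition here, made because Path(None) raises TypeError.

```python
from pathlib import Path


def to_path(value):
    """Normalize a str, Path, or None to Path or None (hypothetical helper)."""
    if value is None:
        # Path(None) would raise TypeError, so pass None through untouched.
        return None
    # Reuse an existing Path instead of building a second object.
    return value if isinstance(value, Path) else Path(value)
```

The method could then call to_path once per argument at the top and keep its existing None and existence checks unchanged.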
360-369: Refactor: Pass Path directly instead of converting to string.
After ensuring audio_file is a Path object (lines 360-362), it's converted back to a string with .as_posix() at line 369, only to be converted back to Path again inside generate_video at lines 281-282. This is inefficient.
Apply this diff to pass the Path directly:
             return self.generate_video(
                 infer_type,
                 self._get_image_path(name),
    -            audio_file.as_posix(),
    +            audio_file,
                 pose_yaw,
                 pose_pitch,
src/visualizr/anitalker/model/latentnet.py (1)
112-112: Consider renaming to avoid variable shadowing.
The parameter t is immediately shadowed by reassigning it to the result of timestep_embedding. While this works, using a distinct name like t_emb would improve clarity.
Apply this diff:
    -        t: torch.Tensor = timestep_embedding(t, self.conf.num_time_emb_channels)
    -        cond = self.time_embed(t)
    +        t_emb: torch.Tensor = timestep_embedding(t, self.conf.num_time_emb_channels)
    +        cond = self.time_embed(t_emb)
src/visualizr/anitalker/liamodel.py (1)
26-33: Add type hints to public methods for consistency.
Annotate inputs/outputs for get_start_direction_code and render (e.g., Tensor types, Tuple[Tensor, Tensor, list[Tensor]] as applicable).
src/visualizr/anitalker/config_base.py (1)
51-59: as_dict_jsonable return typing is fine; consider key/value typing.
Optionally narrow to dict[str, Any] for better static checking.
src/visualizr/anitalker/networks/discriminator.py (2)
96-101: Broaden make_kernel/Blur input typing (optional).
You pass lists to Blur/make_kernel but annotate kernel: Tensor. Prefer accepting Sequence[float] | Tensor for accuracy.
19-24: Annotate channel: int in FusedLeakyReLU for clarity.
Minor typing improvement: channel: int.
src/visualizr/anitalker/networks/styledecoder.py (2)
160-174: Type the kernel input as Sequence[float] | Tensor and avoid reusing the param name.
Improves correctness of hints and readability.
    -def make_kernel(k: Tensor) -> Tensor:  # TODO: type hint
    +from typing import Sequence
    +
    +def make_kernel(kernel: Sequence[float] | Tensor) -> Tensor:
    @@
    -    k: Tensor = torch.tensor(k, dtype=torch.float32)
    +    k: Tensor = torch.tensor(kernel, dtype=torch.float32)
192-203: Consider specifying align_corners in grid_sample to avoid warnings/ambiguity.
Optional, e.g., grid_sample(feat, flow, align_corners=True).
src/visualizr/anitalker/model/blocks.py (2)
141-147: Align conv type annotations with torch.nn.* (follow-up to import fix).
Where you annotate conv layers, prefer nn.Conv1d | nn.Conv2d | nn.Conv3d. Example:
    -        conv: nn.Conv1d | Conv2d | Conv3d = conv_nd(
    +        conv: nn.Conv1d | nn.Conv2d | nn.Conv3d = conv_nd(
    @@
    -        self.skip_connection: nn.Conv1d | Conv2d | Conv3d = conv_nd(
    +        self.skip_connection: nn.Conv1d | nn.Conv2d | nn.Conv3d = conv_nd(
    @@
    -        self.qkv: nn.Conv1d | Conv2d | Conv3d = conv_nd(1, channels, channels * 3, 1)
    +        self.qkv: nn.Conv1d | nn.Conv2d | nn.Conv3d = conv_nd(1, channels, channels * 3, 1)
    @@
    -        self.proj_out: nn.Conv1d | Conv2d | Conv3d = zero_module(
    -            conv_nd(1, channels, channels, 1)
    -        )
    +        self.proj_out: nn.Conv1d | nn.Conv2d | nn.Conv3d = zero_module(
    +            conv_nd(1, channels, channels, 1)
    +        )
    @@
    -        self.conv: nn.Conv1d | Conv2d | Conv3d = conv_nd(
    +        self.conv: nn.Conv1d | nn.Conv2d | nn.Conv3d = conv_nd(
Also applies to: 181-187, 433-444, 350-353, 385-393
305-312: Support list-like scale_bias in apply_conditions (optional).
Current logic only handles numeric scale_bias. If lists are intended, add else branch and length check.
    -        if isinstance(scale_bias, Number):
    -            biases = [scale_bias] * len(scale_shifts)
    +        if isinstance(scale_bias, Number):
    +            biases = [scale_bias] * len(scale_shifts)
    +        else:
    +            biases = list(scale_bias)
    +            if len(biases) != len(scale_shifts):
    +                raise ValueError("scale_bias length must match number of conditions")
src/visualizr/anitalker/renderer.py (1)
46-47: Optional: include a message in NotImplementedError.
Raising with a short message improves debuggability, e.g., why non-diffusion mode is unsupported here.
    -        raise NotImplementedError
    +        raise NotImplementedError("render_condition only supports TrainMode.diffusion")
src/visualizr/anitalker/model/nn.py (2)
22-33: Minor typing polish for conv_nd.
Annotate dims as int for clarity.
    -def conv_nd(dims, *args, **kwargs) -> nn.Conv1d | nn.Conv2d | nn.Conv3d:
    +def conv_nd(dims: int, *args, **kwargs) -> nn.Conv1d | nn.Conv2d | nn.Conv3d:
90-94: Add return type to torch_checkpoint for better static checking.
Use a broad Any to reflect passthrough behavior.
    -from torch.utils.checkpoint import checkpoint
    +from torch.utils.checkpoint import checkpoint
    +from typing import Any
    @@
    -def torch_checkpoint(func, args, flag, preserve_rng_state: bool = False):
    +def torch_checkpoint(func, args, flag, preserve_rng_state: bool = False) -> Any:
src/visualizr/anitalker/config.py (1)
141-146: Align use_timesteps type with producer (space_timesteps).
Producer returns a set[int]. Consider updating SpacedDiffusionBeatGansConfig.use_timesteps to accept a Collection[int] (see suggested change in diffusion.py) to avoid type-checker friction. No code change needed here if that file is updated.
src/visualizr/anitalker/utils.py (2)
52-55: Prefer keyword args and avoid parameterizing third‑party classes in annotations
- Use method="compose" for clarity.
- Variable annotations like AudioFileClip[Path] are non-standard; most type checkers will flag them. Keep them unparameterized.
    -    video: VideoClip | CompositeVideoClip = concatenate_videoclips(clips, "compose")
    -    audio: AudioFileClip[Path] = AudioFileClip(audio_path)
    -    final_video: VideoClip = video.set_audio(audio)
    +    video = concatenate_videoclips(clips, method="compose")
    +    audio = AudioFileClip(audio_path)
    +    final_video = video.set_audio(audio)
164-191: Close media resources deterministically (avoid file handle leaks/locks)
Use context managers for VideoFileClip/AudioFileClip; otherwise tmp file unlink can fail on Windows.
    -    video_clip: VideoFileClip[str] = VideoFileClip(
    -        tmp_predicted_video_512_path.as_posix(),
    -    )
    -    audio_clip: AudioFileClip[str] = AudioFileClip(
    -        predicted_video_256_path.as_posix(),
    -    )
    -    final_clip: VideoClip = video_clip.set_audio(audio_clip)
    -    final_clip.write_videofile(
    -        predicted_video_512_path.as_posix(),
    -        codec="libx264",
    -        audio_codec="aac",
    -    )
    +    with VideoFileClip(tmp_predicted_video_512_path.as_posix()) as video_clip, \
    +         AudioFileClip(predicted_video_256_path.as_posix()) as audio_clip:
    +        final_clip = video_clip.set_audio(audio_clip)
    +        final_clip.write_videofile(
    +            predicted_video_512_path.as_posix(),
    +            codec="libx264",
    +            audio_codec="aac",
    +        )
src/visualizr/anitalker/model/unet.py (1)
272-277: Unused kwargs triggers ARG002; either use or underscore it
Silence lint while preserving signature compatibility.
    -    def forward(self, x, t, y=None, **kwargs) -> "Return":
    +    def forward(self, x, t, y=None, **_kwargs) -> "Return":
src/visualizr/anitalker/model/unet_autoenc.py (1)
214-217: Type hints OK; consider from __future__ import annotations for safety
If runtime evaluates annotations (older Python), forward-referenced types are safer stringified. Optional.
    +from __future__ import annotations
src/visualizr/anitalker/choices.py (2)
6-11: Prefer StrEnum for string enums (cleaner typing/str interop)
Annotated enum members can trip some type checkers. StrEnum gives explicit str behavior without member annotations.
    -from enum import Enum
    +from enum import Enum, StrEnum
    ...
    -class TrainMode(Enum):
    +class TrainMode(StrEnum):
    @@
    -    manipulate: str = "manipulate"
    +    manipulate = "manipulate"
    @@
    -    diffusion: str = "diffusion"
    +    diffusion = "diffusion"
Apply similarly to other stringy enums (ModelType, ModelName, …).
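One caveat worth checking before adopting this: enum.StrEnum was added in Python 3.11, while another comment in this review reports that the project pins requires-python to 3.10. A sketch of an alternative that runs on 3.10 is the str mixin shown below; it is offered as an option under that assumption, not as what the PR does.

```python
from enum import Enum


class TrainMode(str, Enum):
    """String-valued enum via the str mixin (works before Python 3.11)."""

    # Members are left unannotated, as the review suggests.
    manipulate = "manipulate"
    diffusion = "diffusion"
```

Members of such a class are str instances and compare equal to their plain string values, so code that treats them as strings keeps working.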
72-83: get_act return union is fine; you can tighten the return Protocol to nn.Module
Keeps surface simpler and future-proof.
    -    def get_act(self) -> Identity | ReLU | LeakyReLU | SiLU | Tanh:
    +    def get_act(self) -> nn.Module:
src/visualizr/anitalker/networks/encoder.py (1)
98-104: Typing nit: accept list-like kernels without misleading Tensor-only hints
make_kernel and Blur accept Python lists at call sites. Update annotations for clarity.
    -def make_kernel(k: Tensor) -> Tensor:
    +from typing import Sequence
    +def make_kernel(k: Tensor | Sequence[float]) -> Tensor:
    @@
    -class Blur(nn.Module):
    -    def __init__(
    -        self,
    -        kernel: Tensor,
    +class Blur(nn.Module):
    +    def __init__(
    +        self,
    +        kernel: Tensor | Sequence[float],
Also applies to: 107-113
src/visualizr/anitalker/diffusion/base.py (3)
291-304: Simplify redundant eps-branch logic
Two nested checks for the same condition; keep one.
    -        if self.model_mean_type in [ModelMeanType.eps]:
    -            if self.model_mean_type == ModelMeanType.eps:  # TODO
    -                pred_xstart = process_xstart(
    -                    self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output),
    -                )
    +        if self.model_mean_type is ModelMeanType.eps:
    +            pred_xstart = process_xstart(
    +                self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output),
    +            )
             else:
                 raise NotImplementedError(self.model_mean_type)
762-763: Style: inline exception message to satisfy TRY003
Avoid pre-binding msg just to raise. Inline the f-string.
    -        msg: str = f"unknown beta schedule: {schedule_name}"
    -        raise NotImplementedError(msg)
    +        raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
As per static analysis hints (TRY003).
97-103: Optional: guard T==1 edge case in posterior_log_variance_clipped
np.append(self.posterior_variance[1], ...) will raise IndexError when num_timesteps == 1.
    -        self.posterior_log_variance_clipped = np.log(
    -            np.append(self.posterior_variance[1], self.posterior_variance[1:]),
    -        )
    +        if self.num_timesteps > 1:
    +            self.posterior_log_variance_clipped = np.log(
    +                np.append(self.posterior_variance[1], self.posterior_variance[1:]),
    +            )
    +        else:
    +            self.posterior_log_variance_clipped = np.log(self.posterior_variance.copy())
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (24)
- src/visualizr/anitalker/choices.py (2 hunks)
- src/visualizr/anitalker/config.py (5 hunks)
- src/visualizr/anitalker/config_base.py (2 hunks)
- src/visualizr/anitalker/diffusion/base.py (20 hunks)
- src/visualizr/anitalker/diffusion/diffusion.py (5 hunks)
- src/visualizr/anitalker/diffusion/resample.py (2 hunks)
- src/visualizr/anitalker/experiment.py (2 hunks)
- src/visualizr/anitalker/face_sr/face_enhancer.py (5 hunks)
- src/visualizr/anitalker/face_sr/videoio.py (1 hunks)
- src/visualizr/anitalker/liamodel.py (2 hunks)
- src/visualizr/anitalker/model/blocks.py (16 hunks)
- src/visualizr/anitalker/model/latentnet.py (5 hunks)
- src/visualizr/anitalker/model/nn.py (7 hunks)
- src/visualizr/anitalker/model/seq2seq.py (9 hunks)
- src/visualizr/anitalker/model/unet.py (13 hunks)
- src/visualizr/anitalker/model/unet_autoenc.py (8 hunks)
- src/visualizr/anitalker/networks/discriminator.py (6 hunks)
- src/visualizr/anitalker/networks/encoder.py (13 hunks)
- src/visualizr/anitalker/networks/styledecoder.py (26 hunks)
- src/visualizr/anitalker/renderer.py (1 hunks)
- src/visualizr/anitalker/templates.py (3 hunks)
- src/visualizr/anitalker/utils.py (11 hunks)
- src/visualizr/app/builder.py (4 hunks)
- src/visualizr/app/settings.py (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (15)
src/visualizr/anitalker/templates.py (1)
src/visualizr/anitalker/config.py (1)
TrainConfig(31-227)
src/visualizr/anitalker/utils.py (2)
src/visualizr/anitalker/config.py (1)
TrainConfig(31-227)
src/visualizr/anitalker/experiment.py (1)
LitModel(18-88)
src/visualizr/anitalker/model/unet_autoenc.py (1)
src/visualizr/anitalker/model/blocks.py (8)
forward(27-28), forward(37-43), forward(189-201), forward(354-364), forward(402-406), forward(446-447), forward(469-487), forward(497-520)
src/visualizr/anitalker/diffusion/diffusion.py (1)
src/visualizr/anitalker/diffusion/base.py (2)
make_sampler(48-56), GaussianDiffusionBeatGans(59-668)
src/visualizr/anitalker/config.py (6)
src/visualizr/anitalker/config_base.py (1)
BaseConfig(12-59)
src/visualizr/anitalker/diffusion/diffusion.py (1)
SpacedDiffusionBeatGansConfig(72-87)
src/visualizr/anitalker/diffusion/resample.py (1)
UniformSampler(45-50)
src/visualizr/anitalker/model/unet_autoenc.py (1)
BeatGANsAutoencConfig(16-24)
src/visualizr/anitalker/model/unet.py (1)
BeatGANsUNetConfig(25-81)
src/visualizr/anitalker/choices.py (2)
ModelName(25-29), ModelType(13-22)
src/visualizr/anitalker/model/latentnet.py (2)
src/visualizr/anitalker/choices.py (2)
get_act(72-83), Activation(65-83)
src/visualizr/anitalker/model/nn.py (1)
timestep_embedding(67-87)
src/visualizr/anitalker/experiment.py (5)
src/visualizr/anitalker/choices.py (1)
TrainMode(6-10)
src/visualizr/anitalker/config.py (4)
TrainConfig(31-227), make_diffusion_conf(141-142), make_eval_diffusion_conf(144-145), make_t_sampler(136-139)
src/visualizr/anitalker/diffusion/diffusion.py (2)
SpacedDiffusionBeatGans(90-108), make_sampler(86-87)
src/visualizr/anitalker/renderer.py (1)
render_condition(10-67)
src/visualizr/anitalker/diffusion/resample.py (1)
UniformSampler(45-50)
src/visualizr/anitalker/networks/styledecoder.py (2)
src/visualizr/anitalker/networks/discriminator.py (8)
fused_leaky_relu(9-15), upfirdn2d(75-93), make_kernel(96-101), forward(30-31), forward(118-119), forward(127-128), forward(152-159), Blur(104-119)
src/visualizr/anitalker/networks/encoder.py (15)
fused_leaky_relu(10-16), upfirdn2d(77-95), make_kernel(98-103), forward(31-32), forward(120-121), forward(129-130), forward(154-161), forward(189-195), forward(266-271), forward(279-284), forward(334-358), forward(385-388), forward(413-472), Blur(106-121), EqualLinear(170-200)
src/visualizr/anitalker/model/unet.py (1)
src/visualizr/anitalker/model/blocks.py (10)
ResBlock(79-256), make_model(75-76), forward(27-28), forward(37-43), forward(189-201), forward(354-364), forward(402-406), forward(446-447), forward(469-487), forward(497-520)
src/visualizr/anitalker/diffusion/base.py (2)
src/visualizr/anitalker/diffusion/diffusion.py (1)
make_sampler(86-87)
src/visualizr/anitalker/choices.py (3)
ModelMeanType(32-36), ModelVarType(39-50), LossType(53-55)
src/visualizr/anitalker/networks/encoder.py (2)
src/visualizr/anitalker/networks/discriminator.py (9)
fused_leaky_relu(9-15), FusedLeakyReLU(18-31), upfirdn2d_native(34-72), upfirdn2d(75-93), make_kernel(96-101), Blur(104-119), ScaledLeakyReLU(122-128), ConvLayer(162-207), EqualConv2d(131-159)
src/visualizr/anitalker/networks/styledecoder.py (24)
fused_leaky_relu(21-39), FusedLeakyReLU(42-64), forward(54-64), forward(188-189), forward(201-202), forward(228-244), forward(288-303), forward(330-340), forward(393-440), forward(448-449), forward(457-459), forward(490-494), forward(561-567), forward(590-614), forward(622-632), forward(712-748), upfirdn2d_native(67-123), upfirdn2d(126-157), make_kernel(160-174), Blur(192-202), ScaledLeakyReLU(317-340), ConvLayer(497-541), EqualConv2d(205-250), EqualLinear(253-314)
src/visualizr/anitalker/model/seq2seq.py (1)
src/visualizr/anitalker/networks/encoder.py (1)
Encoder(391-472)
src/visualizr/anitalker/model/blocks.py (1)
src/visualizr/anitalker/model/nn.py (5)
GroupNorm32(17-19), forward(18-19), normalization(57-64), conv_nd(22-32), zero_module(50-54)
src/visualizr/anitalker/model/nn.py (2)
src/visualizr/anitalker/model/blocks.py (8)
forward(27-28), forward(37-43), forward(189-201), forward(354-364), forward(402-406), forward(446-447), forward(469-487), forward(497-520)
src/visualizr/anitalker/model/unet_autoenc.py (2)
forward(67-186), forward(214-217)
src/visualizr/anitalker/networks/discriminator.py (1)
src/visualizr/anitalker/networks/encoder.py (7)
fused_leaky_relu(10-16), FusedLeakyReLU(19-32), upfirdn2d_native(35-74), upfirdn2d(77-95), make_kernel(98-103), Blur(106-121), ScaledLeakyReLU(124-130)
🪛 Ruff (0.14.2)
src/visualizr/anitalker/model/unet.py
272-272: Unused method argument: kwargs
(ARG002)
src/visualizr/anitalker/diffusion/base.py
326-326: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: Sourcery review
- GitHub Check: Mergify Merge Protections
- GitHub Check: Summary
🔇 Additional comments (20)
src/visualizr/anitalker/face_sr/videoio.py (1)
8-8: Good pythonic improvement!
Replacing while 1: with while True: makes the loop condition more explicit and aligns with Python style guidelines.
src/visualizr/app/settings.py (1)
96-97: LGTM! Type annotations now properly allow None.
The change from FilePath to FilePath | None correctly reflects that these fields can be None, aligning with the default=None values. This improves type safety and clarity.
src/visualizr/app/builder.py (1)
380-425: LGTM: Return type change handled correctly.
The return type change from Path to str is properly implemented with the docstring updated (line 411) and the result converted using .as_posix() (line 425). The direct passing of audio_file at line 416 avoids unnecessary conversions.
src/visualizr/anitalker/model/latentnet.py (6)
14-17: LGTM: Enum type annotations are correct.
The explicit string type annotations for enum members follow best practices and are consistent with the Activation enum pattern used elsewhere in the codebase.
20-21: LGTM: Union type annotation is correct.
The optional Tensor type is properly expressed using modern union syntax.
41-42: LGTM: Forward reference correctly used.
The string-quoted return type annotation properly handles the forward reference to the MLPSkipNet class defined later in the file.
52-109: LGTM: Comprehensive type annotations.
The type annotations throughout __init__ are thorough and correct. The union types for activation attributes properly match the return type of Activation.get_act().
Note: The redundant None initializations at lines 66-71 have already been flagged in a previous review.
125-154: LGTM: Type annotations are thorough and correct.
The type annotations in __init__ are comprehensive, with proper union types for activation-related attributes and appropriate conditional types for norm and dropout.
156-163: LGTM: Return type annotation is correct.
The None return type is properly specified for this initialization method.
src/visualizr/anitalker/liamodel.py (1)
11-17: Constructor typing looks good.
Defaults and annotations for size/style/motion dims, channel_multiplier, blur_kernel, and fusion_type are appropriate.
src/visualizr/anitalker/config_base.py (2)
20-27: Typing and set logic look good.
inherit/propagate annotations and common_keys usage are clear and correct.
33-45: from_dict strict/lenient handling is reasonable.
ValueError in strict mode and warning otherwise matches typical expectations.
src/visualizr/anitalker/networks/styledecoder.py (1)
724-732: No action needed: code is compatible with project requirements.
The project declares requires-python = "==3.10.*", which fully supports the strict parameter in zip(). The code is correct as-is.
src/visualizr/anitalker/templates.py (3)
5-27: LGTM on return typing for template base.Clear
-> TrainConfigreturn improves callers and static checks. No behavior change.
30-41: LGTM: typed conf and explicit return.The local
conf: TrainConfigand explicit return are clean and consistent.
44-53: LGTM: ffhq256 template typing.Return typing and config construction look correct.
src/visualizr/anitalker/utils.py (1)
133-161: InferenceType is a Literal type alias (string-based), not an Enum—all current usages are safe.Verification confirms InferenceType is defined as
Literal["mfcc_full_control", "mfcc_pose_only", ...]in src/visualizr/app/types.py, not an Enum. Runtime values are strings, so both.startswith()calls in builder.py and seq2seq.py and thematchstatement here will work correctly. No changes needed.src/visualizr/anitalker/networks/encoder.py (1)
279-284: No changes needed—project requires Python 3.10The code is correct as-is. The
pyproject.tomlexplicitly declaresrequires-python = "==3.10.*"with classifiers listing only Python 3.10 support. Sincezip(strict=False)was introduced in Python 3.10 and the project exclusively targets 3.10, there is no runtime compatibility issue.Likely an incorrect or invalid review comment.
src/visualizr/anitalker/face_sr/face_enhancer.py (2)
65-89: LGTM! Return type annotation is accurate.The return type
tuple[int, str, str, str]correctly matches the values returned at line 89 (channel_multiplier, model_name, url, arch).
151-162: LGTM! Local variable type annotations are correct.The type annotations for
_msgand_bg_upsamplerare accurate. While somewhat redundant in cases where the type is obvious from the right-hand side (like the f-string), they contribute to overall type clarity.
| @lru_cache | ||
| def jsonable(x: Any) -> bool: | ||
| def jsonable(x) -> bool: | ||
| """Check if the object x is JSON serializable.""" | ||
| try: | ||
| dumps(x) | ||
| return True | ||
| except TypeError: | ||
| return False | ||
| return True |
jsonable breaks with @lru_cache on unhashable inputs and misses OverflowError.
- lru_cache requires hashable args; passing dict/list/set will raise TypeError before dumps.
- json.dumps may raise OverflowError in addition to TypeError.
Apply:
-@lru_cache
-def jsonable(x) -> bool:
+def jsonable(x: Any) -> bool:
     """Check if the object x is JSON serializable."""
     try:
         dumps(x)
-    except TypeError:
+    except (TypeError, OverflowError):
         return False
     return True
Committable suggestion skipped: line range outside the PR's diff.
🤖 Prompt for AI Agents
In src/visualizr/anitalker/config_base.py around lines 62 to 69, the jsonable
function is decorated with @lru_cache which fails for unhashable inputs
(dict/list/set) and it only catches TypeError but json.dumps can also raise
OverflowError; remove the @lru_cache decorator so unhashable inputs are
accepted, and update the exception handling to catch both TypeError and
OverflowError from dumps and return False in those cases, otherwise return True.
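The two points above can be reproduced in isolation. The following is a minimal sketch, not the project's code: `cached_jsonable` is a hypothetical name for the decorated variant, and `jsonable` follows the suggested diff.

```python
from functools import lru_cache
from json import dumps
from typing import Any


def jsonable(x: Any) -> bool:
    """Return True if x can be serialized with json.dumps."""
    try:
        dumps(x)
    except (TypeError, OverflowError):
        return False
    return True


@lru_cache
def cached_jsonable(x: Any) -> bool:
    """Hypothetical cached variant, kept only to show the hashing problem."""
    return jsonable(x)


def cache_rejects(value: Any) -> bool:
    """Return True if lru_cache raises TypeError while hashing the argument."""
    try:
        cached_jsonable(value)
    except TypeError:
        return True
    return False
```

With the decorator, the `TypeError` is raised while the cache hashes the argument, before the function body and its `try` block run, so a plain `dict` never reaches `dumps`. Note that `json.dumps` raises `ValueError` for circular references; this sketch leaves that uncaught, as the suggestion does.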
| from typing import Literal, Self | ||
|
|
🧩 Analysis chain
Fix Optional typing for fields that default to None.
These fields are annotated as non-Optional but initialized with None, which breaks strict type checking.
-from typing import Literal, Self
+from typing import Literal, Self
@@
-    infer_type: InferenceType = None
+    infer_type: InferenceType | None = None
@@
-    model_conf: ModelConfig = None
+    model_conf: ModelConfig | None = None
@@
-    decoder_layers: int = None
-    motion_dim: int = None
-    mfcc: bool = None
-    face_scale: bool = None
-    face_location: bool = None
+    decoder_layers: int | None = None
+    motion_dim: int | None = None
+    mfcc: bool | None = None
+    face_scale: bool | None = None
+    face_location: bool | None = None
Also, confirm project Python version; `typing.Self` requires Python 3.11+. If supporting 3.10, replace with `-> "TrainConfig"`.
Also applies to: 30-33, 50-53, 92-97
---
🏁 Script executed:
```shell
#!/bin/bash
# Verify the config.py file content and field type annotations
cat src/visualizr/anitalker/config.py
```
Length of output: 9377
🏁 Script executed:
```shell
#!/bin/bash
# Check Python version requirement in project config
fd -e pyproject.toml -e setup.py -e setup.cfg | head -5 | xargs cat 2>/dev/null
```
Length of output: 0
🏁 Script executed:
```shell
#!/bin/bash
# Check if there's a Python version constraint in README or .python-version
fd -e README.md -e .python-version | xargs cat 2>/dev/null | grep -i python
```
Length of output: 0
🏁 Script executed:
```shell
#!/bin/bash
# Search for pyproject.toml or setup.py anywhere in repo
find . -maxdepth 3 -type f \( -name "pyproject.toml" -o -name "setup.py" -o -name "setup.cfg" \) 2>/dev/null | head -10
```
Length of output: 16
🏁 Script executed:
```shell
#!/bin/bash
# Check common root files for Python version info
ls -la | grep -E "^-.*\.(toml|py|cfg|txt|yml|yaml)$"
```
Length of output: 324
🏁 Script executed:
```shell
#!/bin/bash
# Attempt to find any Python version constraints via broader search
rg -i "python.*3\.[0-9]+" -B2 -A2 | head -30
```
Length of output: 1510
Import Self from typing_extensions or use string literal for Python 3.10 compatibility; add Optional types to 7 fields defaulting to None.
The project requires Python 3.10 (requires-python = "==3.10.*"), but typing.Self was introduced in Python 3.11 (PEP 673). This will cause an ImportError at runtime. Additionally, seven fields are annotated as non-Optional but initialized with None, breaking strict type checking:
-from typing import Literal, Self
+from typing import Literal
+from typing_extensions import Self
Then update these field annotations:
-    infer_type: InferenceType = None
+    infer_type: InferenceType | None = None
-    model_conf: ModelConfig = None
+    model_conf: ModelConfig | None = None
-    decoder_layers: int = None
-    motion_dim: int = None
-    mfcc: bool = None
-    face_scale: bool = None
-    face_location: bool = None
+    decoder_layers: int | None = None
+    motion_dim: int | None = None
+    mfcc: bool | None = None
+    face_scale: bool | None = None
+    face_location: bool | None = None
Also applies to any other occurrences in lines 30–33, 50–53, 92–97.
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| from typing import Literal, Self | |
| from typing import Literal | |
| from typing_extensions import Self | |
🤖 Prompt for AI Agents
In src/visualizr/anitalker/config.py around lines 2-3 (and also update
occurrences at lines 30–33, 50–53, 92–97): the file currently imports Self from
typing (Python 3.11-only) and declares seven fields as non-Optional while
defaulting them to None; replace the import by importing Self from
typing_extensions (or change annotations to the string literal "Self") to retain
Python 3.10 compatibility, and update the seven field type hints that are
assigned None to Optional[...] (e.g., change Type -> Optional[Type]) so their
annotations match their default values; apply the same Optional fixes to any
similar field annotations at the other listed line ranges.
| def space_timesteps( | ||
| num_timesteps: int, | ||
| section_counts: list[int] | str, | ||
| ) -> set[int | float]: | ||
| """ |
🛠️ Refactor suggestion | 🟠 Major
Fix type annotations and a misleading error message in spaced timesteps.
- `space_timesteps` returns integer steps; align types to `set[int]`.
- Use `list[int]` for `all_steps`/`taken_steps`.
- The DDIM error should reference `desired_count`, not `num_timesteps`.
- Prefer simple `float` for `frac_stride`; `Literal[1]` adds noise.
- Align `use_timesteps` type with the actual producer and add a constructor guard.
-from typing import Literal
+from typing import Collection
@@
-def space_timesteps(
-    num_timesteps: int,
-    section_counts: list[int] | str,
-) -> set[int | float]:
+def space_timesteps(
+    num_timesteps: int,
+    section_counts: list[int] | str,
+) -> set[int]:
@@
-            msg: str = (
-                f"cannot create exactly {num_timesteps} steps with an integer stride"
-            )
+            msg: str = f"cannot create exactly {desired_count} steps with an integer stride"
             raise ValueError(msg)
@@
-    start_idx: int = 0
-    all_steps: list[float] = []
+    start_idx: int = 0
+    all_steps: list[int] = []
@@
-        size: int = size_per + (1 if i < extra else 0)
+        size: int = size_per + (1 if i < extra else 0)
         if size < section_count:
             msg: str = f"cannot divide section of {size} steps into {section_count}"
             raise ValueError(msg)
-        frac_stride: float | Literal[1] = (
-            1 if section_count <= 1 else (size - 1) / (section_count - 1)
-        )
-        cur_idx: float = 0.0
-        taken_steps: list[float] = []
+        frac_stride: float = 1.0 if section_count <= 1 else (size - 1) / (section_count - 1)
+        cur_idx: float = 0.0
+        taken_steps: list[int] = []
         for _ in range(section_count):
-            taken_steps.append(start_idx + round(cur_idx))
+            taken_steps.append(start_idx + int(round(cur_idx)))
             cur_idx += frac_stride
         all_steps += taken_steps
         start_idx += size
-    return set(all_steps)
+    return set(all_steps)
@@
-    use_timesteps: tuple[int] | None = None
+    use_timesteps: Collection[int] | None = None
@@
-    def __init__(self, conf: SpacedDiffusionBeatGansConfig) -> None:
+    def __init__(self, conf: SpacedDiffusionBeatGansConfig) -> None:
         self.conf = conf
-        self.use_timesteps: set[int] = set(conf.use_timesteps)
+        if not conf.use_timesteps:
+            raise ValueError("use_timesteps must be a non-empty collection of indices")
+        self.use_timesteps: set[int] = set(conf.use_timesteps)
         # how the new t's mapped to the old t's
-        self.timestep_map = []
+        self.timestep_map: list[int] = []
Also applies to: 44-47, 51-63, 84-87, 93-96
| @abstractmethod | ||
| def weights(self): | ||
| def weights(self) -> None: | ||
| """ | ||
| Get a numpy array of weights, one per diffusion step. | ||
| The weights needn't be normalized but must be positive. | ||
| """ | ||
|
|
Incorrect return type for abstract weights(). Breaks the contract used by sample().
sample() expects a NumPy array; annotating -> None is wrong and will confuse type checkers.
-from abc import ABC, abstractmethod
+from abc import ABC, abstractmethod
+from numpy.typing import NDArray
@@
-    @abstractmethod
-    def weights(self) -> None:
+    @abstractmethod
+    def weights(self) -> NDArray[np.floating] | np.ndarray:
@@
-    def sample(self, batch_size: int, device: th.device) -> tuple[th.Tensor, th.Tensor]:
+    def sample(self, batch_size: int, device: th.device) -> tuple[th.Tensor, th.Tensor]:
@@
 class UniformSampler(ScheduleSampler):
-    def __init__(self, num_timesteps) -> None:
+    def __init__(self, num_timesteps) -> None:
         self._weights = np.ones([num_timesteps])
-    def weights(self):
+    def weights(self) -> NDArray[np.floating] | np.ndarray:
         return self._weights
Optionally, clamp/validate w to avoid zero-probability divisions when custom samplers override weights().
Also applies to: 26-43, 45-49
| if TYPE_CHECKING: | ||
| from visualizr.anitalker.diffusion.resample import UniformSampler | ||
|
|
🧩 Analysis chain
Runtime NameError risk: annotation uses UniformSampler without importing it at runtime.
Variable annotations are evaluated at runtime unless from __future__ import annotations is enabled. Since UniformSampler is only imported under TYPE_CHECKING, self.T_sampler: UniformSampler = ... will raise NameError.
-if TYPE_CHECKING:
-    from visualizr.anitalker.diffusion.resample import UniformSampler
+if TYPE_CHECKING:
+    from visualizr.anitalker.diffusion.resample import UniformSampler
@@
-        self.T_sampler: UniformSampler = conf.make_t_sampler()
+        # Avoid runtime evaluation of the annotation without importing the type:
+        self.T_sampler: "UniformSampler" = conf.make_t_sampler()
Alternatively, import UniformSampler unconditionally or enable `from __future__ import annotations` at module top.
Also applies to: 39-40
🏁 Script executed:
```shell
#!/bin/bash
# Check the experiment.py file content
cat src/visualizr/anitalker/experiment.py
```
Length of output: 2830
🏁 Script executed:
```shell
#!/bin/bash
# Search for all UniformSampler references in the codebase
rg -n "UniformSampler" --type=py -C 3
```
Length of output: 2759
🏁 Script executed:
```shell
#!/bin/bash
# Check if __future__ annotations is imported
rg -n "from __future__ import" src/visualizr/anitalker/experiment.py
```
Length of output: 0
Fix runtime NameError on type annotation and undefined attribute in forward() method.
The review comment is accurate. Line 39 uses UniformSampler in a variable annotation without importing it at runtime (it's only imported under TYPE_CHECKING). Since the module lacks from __future__ import annotations, this annotation is evaluated at runtime and will raise NameError.
Additionally, the forward() method has ema_model: bool = False parameter that is never used, and instead references undefined self.disable_ema.
Fix line 39 by either:
- Quoting the annotation: `self.T_sampler: "UniformSampler" = conf.make_t_sampler()`
- Adding `from __future__ import annotations` at module top
- Importing `UniformSampler` unconditionally
And fix the forward() method to either use the ema_model parameter or define self.disable_ema as an instance attribute.
🤖 Prompt for AI Agents
In src/visualizr/anitalker/experiment.py around lines 14 and 39 (and the forward
definition), a runtime NameError can occur because UniformSampler is only
imported under TYPE_CHECKING while annotations are evaluated at runtime, and
forward() uses an undefined self.disable_ema while accepting ema_model parameter
that is never used. Fix the annotation by either adding from __future__ import
annotations at the top of the module, importing UniformSampler unconditionally,
or quoting the annotation (e.g., "UniformSampler") so it isn't evaluated at
runtime; and fix forward() by either using the ema_model parameter (e.g.,
reference ema_model instead of self.disable_ema) or by initializing
self.disable_ema as an instance attribute (set in __init__) so the attribute
exists.
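The quoted-annotation option can be sketched as follows. `Holder` and `make_value` are hypothetical, and `decimal.Decimal` stands in for `UniformSampler` so the example needs nothing outside the standard library.

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by static type checkers; this import never runs at runtime.
    from decimal import Decimal


def make_value():
    """Build the value without exposing Decimal in the module namespace."""
    from decimal import Decimal as _Decimal

    return _Decimal("1.5")


class Holder:
    def __init__(self) -> None:
        # The quoted annotation is a plain string, so it is safe in any scope.
        self.value: "Decimal" = make_value()
```

As far as the Python language reference describes annotated assignments, annotations inside a function body are not evaluated at all, so the unquoted form is most likely to fail at module or class scope and in function signatures. Quoting, or `from __future__ import annotations`, is harmless either way and keeps the behaviour independent of where the annotation sits.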
| if conf.num_heads_upsample == -1: | ||
| self.num_heads_upsample = conf.num_heads | ||
| self.dtype = th.float32 | ||
| self.time_emb_channels = conf.time_embed_channels or conf.model_channels | ||
| self.dtype: th.dtype = th.float32 | ||
| self.time_emb_channels: int = conf.time_embed_channels or conf.model_channels |
Initialize num_heads_upsample in both branches
If conf.num_heads_upsample != -1, attribute remains undefined and later access crashes.
-        if conf.num_heads_upsample == -1:
-            self.num_heads_upsample = conf.num_heads
+        if conf.num_heads_upsample == -1:
+            self.num_heads_upsample = conf.num_heads
+        else:
+            self.num_heads_upsample = conf.num_heads_upsample
📝 Committable suggestion
| if conf.num_heads_upsample == -1: | |
| self.num_heads_upsample = conf.num_heads | |
| self.dtype = th.float32 | |
| self.time_emb_channels = conf.time_embed_channels or conf.model_channels | |
| self.dtype: th.dtype = th.float32 | |
| self.time_emb_channels: int = conf.time_embed_channels or conf.model_channels | |
| if conf.num_heads_upsample == -1: | |
| self.num_heads_upsample = conf.num_heads | |
| else: | |
| self.num_heads_upsample = conf.num_heads_upsample | |
| self.dtype: th.dtype = th.float32 | |
| self.time_emb_channels: int = conf.time_embed_channels or conf.model_channels |
🤖 Prompt for AI Agents
In src/visualizr/anitalker/model/unet.py around lines 88 to 91,
self.num_heads_upsample is only set when conf.num_heads_upsample == -1 which
leaves the attribute undefined in the other branch; set self.num_heads_upsample
to conf.num_heads_upsample when it's not -1 so the attribute is always
initialized (i.e., assign self.num_heads_upsample = conf.num_heads when
conf.num_heads_upsample == -1 else self.num_heads_upsample =
conf.num_heads_upsample).
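A compact way to guarantee the attribute exists is a single conditional expression. The sketch below uses hypothetical `UNetConf` and `UNetSketch` classes that carry only the two fields discussed here.

```python
from dataclasses import dataclass


@dataclass
class UNetConf:
    """Hypothetical stand-in for the real model config."""

    num_heads: int = 1
    num_heads_upsample: int = -1  # -1 means "fall back to num_heads"


class UNetSketch:
    def __init__(self, conf: UNetConf) -> None:
        # Assigned on every path, so later reads can never hit AttributeError.
        self.num_heads_upsample: int = (
            conf.num_heads
            if conf.num_heads_upsample == -1
            else conf.num_heads_upsample
        )
```

This is equivalent to the `if`/`else` in the suggestion; the single assignment also gives type checkers one place to infer the attribute's type.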
| @@ -1,41 +1,52 @@ | |||
| import math | |||
|
|
|||
| from regex import T | |||
Invalid import: from regex import T will fail at runtime.
The regex package does not export T. Remove this import. It’s unused.
-from regex import T
📝 Committable suggestion
| from regex import T |
🤖 Prompt for AI Agents
In src/visualizr/anitalker/networks/discriminator.py around line 3, remove the
invalid and unused import "from regex import T" because the regex package does
not export T; simply delete that import line and ensure there are no references
to T elsewhere in the file (if any exist, replace or remove them accordingly).
| bias: bool = True, | ||
| bias_init: int = 0, | ||
| lr_mul: int = 1, | ||
| activation=None, | ||
| ): | ||
| ) -> None: | ||
| super().__init__() | ||
|
|
||
| self.weight = nn.Parameter(randn(out_dim, in_dim).div_(lr_mul)) | ||
| self.bias = nn.Parameter(zeros(out_dim).fill_(bias_init)) if bias else None | ||
| self.bias: nn.Parameter | None = ( | ||
| nn.Parameter(zeros(out_dim).fill_(bias_init)) if bias else None | ||
| ) | ||
| self.activation = activation | ||
| self.scale: float = (1 / sqrt(in_dim)) * lr_mul | ||
| self.lr_mul: int = lr_mul |
EqualLinear: wrong types and bias=None crash when activation is set
- bias_init and lr_mul should be float (docs and math expect float).
- If bias=False and activation is not None, fused_leaky_relu receives None and crashes.
-        bias: bool = True,
-        bias_init: int = 0,
-        lr_mul: int = 1,
+        bias: bool = True,
+        bias_init: float = 0.0,
+        lr_mul: float = 1.0,
@@
-        self.scale: float = (1 / sqrt(in_dim)) * lr_mul
-        self.lr_mul: int = lr_mul
+        self.scale: float = (1 / sqrt(in_dim)) * lr_mul
+        self.lr_mul: float = lr_mul
@@
-        if self.activation:
-            out = linear(_input, self.weight * self.scale)
-            out = fused_leaky_relu(out, self.bias * self.lr_mul)
+        if self.activation:
+            out = linear(
+                _input,
+                self.weight * self.scale,
+                bias=(self.bias * self.lr_mul) if self.bias is not None else None,
+            )
+            if self.bias is not None:
+                out = fused_leaky_relu(out, self.bias * self.lr_mul)
+            else:
+                # Preserve activation behavior without bias.
+                out = leaky_relu(out, negative_slope=0.2) * (2**0.5)
         else:
-            out = linear(_input, self.weight * self.scale, bias=self.bias * self.lr_mul)
+            out = linear(
+                _input,
+                self.weight * self.scale,
+                bias=(self.bias * self.lr_mul) if self.bias is not None else None,
+            )
Also applies to: 189-195
🤖 Prompt for AI Agents
In src/visualizr/anitalker/networks/encoder.py around lines 175-187 (and also
apply the same change at 189-195): bias_init and lr_mul are typed and used as
floats, so change their annotations from int to float; also avoid passing None
to fused_leaky_relu by ensuring self.bias is a valid tensor when activation is
set — either create a zeros-initialized nn.Parameter for bias even when
bias=False if activation is provided, or in the forward path substitute a zeros
tensor for self.bias when calling the fused activation. Update the constructor
signatures to use float for bias_init and lr_mul, adjust any initialization math
accordingly, and add the conditional construction or forward-time fallback so
fused_leaky_relu never receives None.
| # image encoder | ||
| self.net_app = EncoderApp(size, dim, weighted_sum) | ||
|
|
||
| # decouping network | ||
| self.net_decouping = DecouplingModel(dim, dim, dim) | ||
| self.net_decouping: DecouplingModel = DecouplingModel(dim, dim, dim) | ||
|
|
Align call site with normalized fusion_type
Pass a normalized value instead of a bare bool so EncoderApp won’t misinterpret it.
-        self.net_app = EncoderApp(size, dim, weighted_sum)
+        self.net_app = EncoderApp(size, dim, "weighted_sum" if weighted_sum else None)
📝 Committable suggestion
| # image encoder | |
| self.net_app = EncoderApp(size, dim, weighted_sum) | |
| # decouping network | |
| self.net_decouping = DecouplingModel(dim, dim, dim) | |
| self.net_decouping: DecouplingModel = DecouplingModel(dim, dim, dim) | |
| # image encoder | |
| self.net_app = EncoderApp(size, dim, "weighted_sum" if weighted_sum else None) | |
| # decouping network | |
| self.net_decouping: DecouplingModel = DecouplingModel(dim, dim, dim) | |
🤖 Prompt for AI Agents
In src/visualizr/anitalker/networks/encoder.py around lines 401 to 406, the
EncoderApp is being constructed with a bare bool (weighted_sum) which can be
misinterpreted; convert/normalize that value to the expected fusion_type before
calling EncoderApp (e.g., derive the canonical fusion_type string/enum from the
boolean via the existing normalize_fusion_type or FusionType helper, then call
EncoderApp with fusion_type=<normalized_value> or the named parameter
fusion_type=...), so the app always receives a normalized fusion_type rather
than a raw bool.
| clips: list[ImageClip] = [ | ||
| ImageClip(m.as_posix()).set_duration(1 / fps) | ||
| for m in sorted(input_path.iterdir()) | ||
| ] | ||
| video = concatenate_videoclips(clips, "compose") | ||
| audio = AudioFileClip(audio_path) | ||
| final_video = video.set_audio(audio) | ||
| video: VideoClip | CompositeVideoClip = concatenate_videoclips(clips, "compose") | ||
| audio: AudioFileClip[Path] = AudioFileClip(audio_path) | ||
| final_video: VideoClip = video.set_audio(audio) | ||
| final_video.write_videofile( |
Enforce empty-frames check and filter inputs
Docstring promises FileNotFoundError but code never raises. Also filter to image files to avoid non-image artifacts.
-    clips: list[ImageClip] = [
-        ImageClip(m.as_posix()).set_duration(1 / fps)
-        for m in sorted(input_path.iterdir())
-    ]
+    frames = sorted(p for p in input_path.iterdir() if p.suffix.lower() in {".png", ".jpg", ".jpeg"} )
+    if not frames:
+        raise FileNotFoundError(f"No frames found in {input_path}")
+    clips: list[ImageClip] = [ImageClip(p.as_posix()).set_duration(1 / fps) for p in frames]
📝 Committable suggestion
| clips: list[ImageClip] = [ | |
| ImageClip(m.as_posix()).set_duration(1 / fps) | |
| for m in sorted(input_path.iterdir()) | |
| ] | |
| video = concatenate_videoclips(clips, "compose") | |
| audio = AudioFileClip(audio_path) | |
| final_video = video.set_audio(audio) | |
| video: VideoClip | CompositeVideoClip = concatenate_videoclips(clips, "compose") | |
| audio: AudioFileClip[Path] = AudioFileClip(audio_path) | |
| final_video: VideoClip = video.set_audio(audio) | |
| final_video.write_videofile( | |
| frames = sorted(p for p in input_path.iterdir() if p.suffix.lower() in {".png", ".jpg", ".jpeg"} ) | |
| if not frames: | |
| raise FileNotFoundError(f"No frames found in {input_path}") | |
| clips: list[ImageClip] = [ImageClip(p.as_posix()).set_duration(1 / fps) for p in frames] | |
| video: VideoClip | CompositeVideoClip = concatenate_videoclips(clips, "compose") | |
| audio: AudioFileClip[Path] = AudioFileClip(audio_path) | |
| final_video: VideoClip = video.set_audio(audio) | |
| final_video.write_videofile( |
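The frame filtering and the empty-directory check can be exercised without moviepy. In this sketch `collect_frames` and `IMAGE_SUFFIXES` are hypothetical names, and the `is_file()` test is an extra assumption beyond the suggested diff.

```python
import tempfile
from pathlib import Path

# Suffixes taken from the suggestion; extend as needed.
IMAGE_SUFFIXES = {".png", ".jpg", ".jpeg"}


def collect_frames(input_path: Path) -> list[Path]:
    """Return image files in input_path, sorted by path; fail if none exist."""
    frames = sorted(
        p
        for p in input_path.iterdir()
        if p.is_file() and p.suffix.lower() in IMAGE_SUFFIXES
    )
    if not frames:
        msg = f"No frames found in {input_path}"
        raise FileNotFoundError(msg)
    return frames


# Small demonstration in a throwaway directory.
with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    for name in ("002.png", "001.PNG", "notes.txt"):
        (root / name).touch()
    print([p.name for p in collect_frames(root)])
```

Sorting is lexicographic, so frame names need zero padding (`001`, `002`, ...) to keep their numeric order in the resulting video.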


This pull request introduces type annotations and minor refactoring across several core files to improve code clarity, maintainability, and type safety. The most significant changes are the addition of type hints to class members and method signatures, updates to configuration and model-related logic, and some improvements to error handling and logging.
Type Annotations and Enum Refactoring
- `choices.py`, ensuring consistent typing for configuration options such as `TrainMode`, `ModelType`, `ModelName`, `ModelMeanType`, `ModelVarType`, `LossType`, `GenerativeType`, and `Activation`. Also annotated methods like `has_autoenc` and `get_act` with return types. [1] [2]
Configuration and Model Logic Improvements
- `TrainConfig` in `config.py` to use the new `InferenceType` for `infer_type`, added type annotations to fields and methods, and improved method signatures for configuration creation and scaling. Cleaned up logic for model configuration selection. [1] [2] [3] [4]
- `BaseConfig` in `config_base.py` by adding type annotations to methods and variables, refining the inheritance and propagation logic, and updating logging for strict key checking in configuration loading.
- `jsonable` utility function to fix the logic so that it only returns `True` if an object is JSON serializable.
Diffusion Model Type Safety and Error Handling
- `diffusion/base.py`, including sampler creation, model initialization, sampling loops, and helper methods. Improved error handling by using descriptive messages for exceptions. [1] [2] [3] [4] [5] [6] [7] [8] [9] [10] [11] [12] [13] [14] [15]
These changes collectively enhance code readability and reliability, making it easier to maintain and extend the codebase.