diff --git a/src/visualizr/anitalker/choices.py b/src/visualizr/anitalker/choices.py index cbafc45e..4e02df71 100644 --- a/src/visualizr/anitalker/choices.py +++ b/src/visualizr/anitalker/choices.py @@ -5,35 +5,35 @@ class TrainMode(Enum): # manipulate mode = training the classifier - manipulate = "manipulate" + manipulate: str = "manipulate" # default training mode! - diffusion = "diffusion" + diffusion: str = "diffusion" class ModelType(Enum): """Kinds of the backbone models.""" # unconditional ddpm - ddpm = "ddpm" + ddpm: str = "ddpm" # autoencoding ddpm cannot do unconditional generation - autoencoder = "autoencoder" + autoencoder: str = "autoencoder" - def has_autoenc(self): + def has_autoenc(self) -> bool: return self in [ModelType.autoencoder] class ModelName(Enum): """List of all supported model classes.""" - beatgans_ddpm = "beatgans_ddpm" - beatgans_autoenc = "beatgans_autoenc" + beatgans_ddpm: str = "beatgans_ddpm" + beatgans_autoenc: str = "beatgans_autoenc" class ModelMeanType(Enum): """Which type of output the model predicts.""" # the model predicts epsilon - eps = "eps" + eps: str = "eps" class ModelVarType(Enum): @@ -45,29 +45,29 @@ class ModelVarType(Enum): """ # posterior beta_t - fixed_small = "fixed_small" + fixed_small: str = "fixed_small" # beta_t - fixed_large = "fixed_large" + fixed_large: str = "fixed_large" class LossType(Enum): # use raw MSE loss and KL when learning variances - mse = "mse" + mse: str = "mse" class GenerativeType(Enum): """where how a sample is generated.""" - ddpm = "ddpm" - ddim = "ddim" + ddpm: str = "ddpm" + ddim: str = "ddim" class Activation(Enum): - none = "none" - relu = "relu" - lrelu = "lrelu" - silu = "silu" - tanh = "tanh" + none: str = "none" + relu: str = "relu" + lrelu: str = "lrelu" + silu: str = "silu" + tanh: str = "tanh" def get_act(self) -> Identity | ReLU | LeakyReLU | SiLU | Tanh: match self: diff --git a/src/visualizr/anitalker/config.py b/src/visualizr/anitalker/config.py index d6170720..09113435 100644 --- a/src/visualizr/anitalker/config.py +++ b/src/visualizr/anitalker/config.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Literal +from typing import Literal, Self from visualizr.anitalker.choices import ( Activation, @@ -24,17 +24,12 @@ ModelConfig, ) from visualizr.anitalker.model.latentnet import LatentNetType, MLPSkipNetConfig +from visualizr.app.types import InferenceType @dataclass class TrainConfig(BaseConfig): - infer_type: Literal[ - "mfcc_full_control", - "mfcc_pose_only", - "hubert_pose_only", - "hubert_audio_only", - "hubert_full_control", - ] = None + infer_type: InferenceType = None # random seed seed: int = 0 train_mode: TrainMode = TrainMode.diffusion @@ -94,19 +89,22 @@ class TrainConfig(BaseConfig): T: int = 1_000 # to be overridden name: str = "" - decoder_layers = None - motion_dim = None + decoder_layers: int = None + motion_dim: int = None + mfcc: bool = None + face_scale: bool = None + face_location: bool = None - def __post_init__(self): + def __post_init__(self) -> None: self.batch_size_eval = self.batch_size_eval or self.batch_size self.data_val_name = self.data_val_name or self.data_name - def scale_up_gpus(self, num_gpus, num_nodes=1): + def scale_up_gpus(self, num_gpus: int, num_nodes: int = 1) -> Self: self.batch_size *= num_gpus * num_nodes self.batch_size_eval *= num_gpus * num_nodes return self - def _make_diffusion_conf(self, t: int): + def _make_diffusion_conf(self, t: int) -> SpacedDiffusionBeatGansConfig: if self.diffusion_type != "beatgans": raise 
NotImplementedError # can use t < `self.t` for evaluation @@ -132,7 +130,7 @@ def _make_diffusion_conf(self, t: int): ) @property - def model_out_channels(self): + def model_out_channels(self) -> int: return 3 def make_t_sampler(self) -> UniformSampler: @@ -140,13 +138,13 @@ def make_t_sampler(self) -> UniformSampler: raise NotImplementedError return UniformSampler(self.T) - def make_diffusion_conf(self): + def make_diffusion_conf(self) -> SpacedDiffusionBeatGansConfig: return self._make_diffusion_conf(self.T) - def make_eval_diffusion_conf(self): + def make_eval_diffusion_conf(self) -> SpacedDiffusionBeatGansConfig: return self._make_diffusion_conf(self.T_eval) - def make_model_conf(self): + def make_model_conf(self) -> BeatGANsAutoencConfig | BeatGANsUNetConfig: if self.model_name == ModelName.beatgans_ddpm: self.model_type = ModelType.ddpm self.model_conf = BeatGANsUNetConfig( @@ -172,16 +170,9 @@ def make_model_conf(self): resnet_two_cond=self.net_beatgans_resnet_two_cond, resnet_use_zero_module=self.net_beatgans_resnet_use_zero_module, ) - elif self.model_name in [ - ModelName.beatgans_autoenc, - ]: + elif self.model_name == ModelName.beatgans_autoenc: cls = BeatGANsAutoencConfig - # supports both autoenc and vaeddpm - if self.model_name == ModelName.beatgans_autoenc: - self.model_type = ModelType.autoencoder - else: - raise NotImplementedError - + self.model_type = ModelType.autoencoder if self.net_latent_net_type == LatentNetType.none: latent_net_conf = None elif self.net_latent_net_type == LatentNetType.skip: diff --git a/src/visualizr/anitalker/config_base.py b/src/visualizr/anitalker/config_base.py index 55df6987..7c141b7a 100644 --- a/src/visualizr/anitalker/config_base.py +++ b/src/visualizr/anitalker/config_base.py @@ -17,20 +17,20 @@ class BaseConfig: and serialize/deserialize configurations to/from JSON. """ - def inherit(self, another): + def inherit(self, another) -> None: """Inherit common keys from a given config.""" - common_keys = set(self.__dict__.keys()) & set(another.__dict__.keys()) + common_keys: set[str] = set(self.__dict__.keys()) & set(another.__dict__.keys()) for k in common_keys: setattr(self, k, getattr(another, k)) - def propagate(self): + def propagate(self) -> None: """Push down the configuration to all members.""" for _, v in self.__dict__.items(): if isinstance(v, BaseConfig): v.inherit(self) v.propagate() - def from_dict(self, config_dict, strict=False): + def from_dict(self, config_dict: dict, strict: bool = False) -> None: """ Populate configuration attributes from a dictionary. 
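# A minimal sketch of the `typing.Self` return annotation now used on
# `TrainConfig.scale_up_gpus`: returning `Self` from a `return self` method lets
# type checkers keep the concrete (sub)class type across chained calls.
# `ScalableConfig` below is an illustrative stand-in, not part of this patch.
from dataclasses import dataclass
from typing import Self


@dataclass
class ScalableConfig:
    batch_size: int = 32
    batch_size_eval: int = 32

    def scale_up_gpus(self, num_gpus: int, num_nodes: int = 1) -> Self:
        # scale both batch sizes by the total device count, then allow chaining
        self.batch_size *= num_gpus * num_nodes
        self.batch_size_eval *= num_gpus * num_nodes
        return self


conf = ScalableConfig().scale_up_gpus(num_gpus=4, num_nodes=2)
assert conf.batch_size == 32 * 8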
@@ -38,20 +38,19 @@ def from_dict(self, config_dict, strict=False): """ for k, v in config_dict.items(): if not hasattr(self, k): + _msg: str = f"loading extra '{k}'" if strict: - raise ValueError(f"loading extra '{k}'") - _msg = f"loading extra '{k}'" - logger.info(_msg) - Info(_msg) + raise ValueError(_msg) + logger.warning(_msg) continue if isinstance(self.__dict__[k], BaseConfig): self.__dict__[k].from_dict(v) else: self.__dict__[k] = v - def as_dict_jsonable(self): + def as_dict_jsonable(self) -> dict: """Convert the configuration to a JSON-serializable dictionary.""" - conf = {} + conf: dict = {} for k, v in self.__dict__.items(): if isinstance(v, BaseConfig): conf[k] = v.as_dict_jsonable() @@ -61,10 +60,10 @@ def as_dict_jsonable(self): @lru_cache -def jsonable(x: Any) -> bool: +def jsonable(x) -> bool: """Check if the object x is JSON serializable.""" try: dumps(x) - return True except TypeError: return False + return True diff --git a/src/visualizr/anitalker/diffusion/base.py b/src/visualizr/anitalker/diffusion/base.py index 5ad232c6..8ded8ffd 100644 --- a/src/visualizr/anitalker/diffusion/base.py +++ b/src/visualizr/anitalker/diffusion/base.py @@ -45,7 +45,7 @@ class GaussianDiffusionBeatGansConfig(BaseConfig): rescale_timesteps: bool fp16: bool - def make_sampler(self): + def make_sampler(self) -> "GaussianDiffusionBeatGans": """ Create a `GaussianDiffusionBeatGans` sampler based on this configuration. @@ -59,20 +59,22 @@ def make_sampler(self): class GaussianDiffusionBeatGans: """Utilities for training and sampling diffusion models.""" - def __init__(self, conf: GaussianDiffusionBeatGansConfig): - self.conf = conf - self.model_mean_type = conf.model_mean_type - self.model_var_type = conf.model_var_type - self.loss_type = conf.loss_type - self.rescale_timesteps = conf.rescale_timesteps + def __init__(self, conf: GaussianDiffusionBeatGansConfig) -> None: + self.conf: GaussianDiffusionBeatGansConfig = conf + self.model_mean_type: ModelMeanType = conf.model_mean_type + self.model_var_type: ModelVarType = conf.model_var_type + self.loss_type: LossType = conf.loss_type + self.rescale_timesteps: bool = conf.rescale_timesteps # Use float64 for accuracy. betas = np.array(conf.betas, dtype=np.float64) self.betas = betas if len(betas.shape) != 1: - raise ValueError("betas must be 1D") + msg = "betas must be 1D" + raise ValueError(msg) if not ((betas > 0).all() and (betas <= 1).all()): - raise ValueError("betas must be positive and less than or equal to 1") + msg = "betas must be positive and less than or equal to 1" + raise ValueError(msg) self.num_timesteps = int(betas.shape[0]) alphas = 1.0 - betas @@ -111,9 +113,9 @@ def sample( noise=None, cond=None, x_start=None, - clip_denoised=True, + clip_denoised: bool = True, model_kwargs=None, - progress=False, + progress: bool = False, ): """ Generate samples from the diffusion model using either DDPM or DDIM sampling. @@ -135,7 +137,7 @@ def sample( torch.Tensor: The generated samples from the model. 
""" if model_kwargs is None: - model_kwargs = {} + model_kwargs: dict = {} if self.conf.model_type.has_autoenc(): model_kwargs["x_start"] = x_start model_kwargs["cond"] = cond @@ -205,12 +207,12 @@ def q_posterior_mean_variance(self, x_start, x_t, t): def p_mean_variance( self, - model, + model: Model, x, t: th.Tensor, clip_denoised=True, denoised_fn=None, - model_kwargs=None, + model_kwargs: dict | None = None, ): """ Apply the model to get p(x_{t-1} | x_t), as well as a prediction of @@ -319,7 +321,7 @@ def process_xstart(_x): "model_forward": model_forward, } - def _predict_xstart_from_eps(self, x_t, t, eps): + def _predict_xstart_from_eps(self, x_t, t: th.Tensor, eps): if x_t.shape != eps.shape: raise ValueError(f"Shape mismatch: {x_t.shape} vs {eps.shape}") return ( @@ -327,19 +329,21 @@ def _predict_xstart_from_eps(self, x_t, t, eps): - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps ) - def _predict_eps_from_xstart(self, x_t, t, pred_xstart): + def _predict_eps_from_xstart(self, x_t, t: th.Tensor, pred_xstart): return ( _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) - def _scale_timesteps(self, t): + def _scale_timesteps(self, t: th.Tensor) -> th.Tensor: if self.rescale_timesteps: # scale t to be maxed out at 1000 steps return t.float() * (1000.0 / self.num_timesteps) return t - def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None): + def condition_mean( + self, cond_fn, p_mean_var, x, t: th.Tensor, model_kwargs: dict | None = None + ): """ Compute the mean for the previous step, given a function `cond_fn` that computes the gradient of a conditional log probability about @@ -351,7 +355,9 @@ def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None): gradient = cond_fn(x, self._scale_timesteps(t), **model_kwargs) return p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float() - def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None): + def condition_score( + self, cond_fn, p_mean_var, x, t: th.Tensor, model_kwargs: dict | None = None + ): """ Compute what the p_mean_variance output would have been, should the model's score function be conditioned by `cond_fn`. @@ -414,8 +420,8 @@ def p_sample( denoised_fn=denoised_fn, model_kwargs=model_kwargs, ) - noise = th.randn_like(x) - nonzero_mask = ( + noise: th.Tensor = th.randn_like(x) + nonzero_mask: th.Tensor = ( (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) ) # no noise when t == 0 if cond_fn is not None: @@ -434,12 +440,12 @@ def p_sample_loop( model: Model, shape=None, noise=None, - clip_denoised=True, + clip_denoised: bool = True, denoised_fn=None, cond_fn=None, model_kwargs=None, - device=None, - progress=False, + device: th.device | None = None, + progress: bool = False, ): """ Generate samples from the model. 
@@ -480,12 +486,12 @@ def p_sample_loop_progressive( model: Model, shape=None, noise=None, - clip_denoised=True, + clip_denoised: bool = True, denoised_fn=None, cond_fn=None, - model_kwargs=None, - device=None, - progress=False, + model_kwargs: dict | None = None, + device: th.device | None = None, + progress: bool = False, ): """ Generate samples from the model and yield intermediate samples from @@ -502,15 +508,15 @@ def p_sample_loop_progressive( else: if not isinstance(shape, (tuple, list)): raise TypeError(f"Shape must be a tuple or list, not a {type(shape)}") - img = th.randn(*shape, device=device) - indices = list(range(self.num_timesteps))[::-1] + img: th.Tensor = th.randn(*shape, device=device) + indices: list[int] = list(range(self.num_timesteps))[::-1] if progress: # Lazy import so that we don't depend on tqdm. indices = tqdm(indices) for i in indices: - t = th.tensor([i] * len(img), device=device) + t: th.Tensor = th.tensor([i] * len(img), device=device) with th.no_grad(): out = self.p_sample( model, @@ -529,11 +535,11 @@ def ddim_sample( model: Model, x, t: th.Tensor, - clip_denoised=True, + clip_denoised: bool = True, denoised_fn=None, cond_fn=None, model_kwargs=None, - eta=0.0, + eta: float = 0.0, ): """ Sample `x_{t-1}` from the model using DDIM. @@ -557,18 +563,18 @@ def ddim_sample( alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape) - sigma = ( + sigma: th.Tensor = ( eta * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) * th.sqrt(1 - alpha_bar / alpha_bar_prev) ) # Equation 12. - noise = th.randn_like(x) + noise: th.Tensor = th.randn_like(x) mean_pred = ( out["pred_xstart"] * th.sqrt(alpha_bar_prev) + th.sqrt(1 - alpha_bar_prev - sigma**2) * eps ) - nonzero_mask = ( + nonzero_mask: th.Tensor = ( (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) ) # no noise when t == 0 sample = mean_pred + nonzero_mask * sigma * noise @@ -579,13 +585,13 @@ def ddim_sample_loop( model: Model, shape=None, noise=None, - clip_denoised=True, + clip_denoised: bool = True, denoised_fn=None, cond_fn=None, model_kwargs=None, device=None, - progress=False, - eta=0.0, + progress: bool = False, + eta: float = 0.0, ): """ Generate samples from the model using DDIM. @@ -613,13 +619,13 @@ def ddim_sample_loop_progressive( model: Model, shape=None, noise=None, - clip_denoised=True, + clip_denoised: bool = True, denoised_fn=None, cond_fn=None, - model_kwargs=None, - device=None, - progress=False, - eta=0.0, + model_kwargs: dict | None = None, + device: th.device | None = None, + progress: bool = False, + eta: float = 0.0, ): """ Use DDIM to sample from the model and yield intermediate samples from @@ -634,8 +640,8 @@ def ddim_sample_loop_progressive( else: if not isinstance(shape, (tuple, list)): raise TypeError(f"Shape must be a tuple or list, not a {type(shape)}") - img = th.randn(*shape, device=device) - indices = list(range(self.num_timesteps))[::-1] + img: th.Tensor = th.randn(*shape, device=device) + indices: list[int] = list(range(self.num_timesteps))[::-1] if progress: # Lazy import so that we don't depend on tqdm. 
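# Sketch of the DDIM update computed in `ddim_sample` above ("Equation 12"):
# with eta == 0 the step is deterministic, while eta == 1 restores DDPM-like
# stochasticity. `pred_xstart`, `eps`, and the alpha_bar values are stand-ins.
import torch as th

eta = 0.0
alpha_bar, alpha_bar_prev = th.tensor(0.5), th.tensor(0.7)
pred_xstart = th.randn(2, 3)  # model's x_0 prediction (stand-in)
eps = th.randn(2, 3)          # model's noise prediction (stand-in)

sigma = (
    eta
    * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
    * th.sqrt(1 - alpha_bar / alpha_bar_prev)
)
mean_pred = (
    pred_xstart * th.sqrt(alpha_bar_prev)
    + th.sqrt(1 - alpha_bar_prev - sigma**2) * eps
)
sample = mean_pred + sigma * th.randn_like(mean_pred)  # noise vanishes at eta == 0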
@@ -645,7 +651,7 @@ def ddim_sample_loop_progressive( _kwargs = ( model_kwargs[i] if isinstance(model_kwargs, list) else model_kwargs ) - t = th.tensor([i] * len(img), device=device) + t: th.Tensor = th.tensor([i] * len(img), device=device) with th.no_grad(): out = self.ddim_sample( model, @@ -662,7 +668,9 @@ def ddim_sample_loop_progressive( img = out["sample"] -def _extract_into_tensor(arr, timesteps, broadcast_shape): +def _extract_into_tensor( + arr: np.ndarray, timesteps: th.Tensor, broadcast_shape +) -> th.Tensor: # TODO: np """ Extract values from a 1D numpy array for a batch of indexes. @@ -672,7 +680,7 @@ def _extract_into_tensor(arr, timesteps, broadcast_shape): dimension equal to the length of timesteps. :return: A tensor of shape [batch_size, 1, ...] Where the shape has K dims. """ - res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float() + res: th.Tensor = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float() while len(res.shape) < len(broadcast_shape): res = res[..., None] return res.expand(broadcast_shape) @@ -751,7 +759,8 @@ def get_named_beta_schedule(schedule_name, num_diffusion_timesteps): dtype=np.float64, ) case _: - raise NotImplementedError(f"unknown beta schedule: {schedule_name}") + msg: str = f"unknown beta schedule: {schedule_name}" + raise NotImplementedError(msg) def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): diff --git a/src/visualizr/anitalker/diffusion/diffusion.py b/src/visualizr/anitalker/diffusion/diffusion.py index 27a00636..8510de40 100644 --- a/src/visualizr/anitalker/diffusion/diffusion.py +++ b/src/visualizr/anitalker/diffusion/diffusion.py @@ -1,4 +1,5 @@ from dataclasses import dataclass +from typing import Literal import numpy as np @@ -8,7 +9,10 @@ ) -def space_timesteps(num_timesteps, section_counts): +def space_timesteps( + num_timesteps: int, + section_counts: list[int] | str, +) -> set[int | float]: """ Create a list of timesteps to use from an original diffusion process. @@ -37,21 +41,25 @@ def space_timesteps(num_timesteps, section_counts): for i in range(1, num_timesteps): if len(range(0, num_timesteps, i)) == desired_count: return set(range(0, num_timesteps, i)) - msg = f"cannot create exactly {num_timesteps} steps with an integer stride" + msg: str = ( + f"cannot create exactly {num_timesteps} steps with an integer stride" + ) raise ValueError(msg) section_counts = [int(x) for x in section_counts.split(",")] size_per = num_timesteps // len(section_counts) extra = num_timesteps % len(section_counts) - start_idx = 0 - all_steps = [] + start_idx: int = 0 + all_steps: list[float] = [] for i, section_count in enumerate(section_counts): - size = size_per + (1 if i < extra else 0) + size: int = size_per + (1 if i < extra else 0) if size < section_count: - msg = f"cannot divide section of {size} steps into {section_count}" + msg: str = f"cannot divide section of {size} steps into {section_count}" raise ValueError(msg) - frac_stride = 1 if section_count <= 1 else (size - 1) / (section_count - 1) - cur_idx = 0.0 - taken_steps = [] + frac_stride: float | Literal[1] = ( + 1 if section_count <= 1 else (size - 1) / (section_count - 1) + ) + cur_idx: float = 0.0 + taken_steps: list[float] = [] for _ in range(section_count): taken_steps.append(start_idx + round(cur_idx)) cur_idx += frac_stride @@ -65,7 +73,8 @@ class SpacedDiffusionBeatGansConfig(GaussianDiffusionBeatGansConfig): """ Configuration for a spaced diffusion process. 
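# Sketch of the even spacing `space_timesteps` produces for a single section:
# a fractional stride of (size - 1) / (section_count - 1) is accumulated and
# rounded, so 1000 base steps reduced to 10 land on 0, 111, ..., 999.
def evenly_spaced(num_timesteps: int, section_count: int) -> set[int]:
    frac_stride = (
        1 if section_count <= 1 else (num_timesteps - 1) / (section_count - 1)
    )
    cur_idx = 0.0
    taken: list[int] = []
    for _ in range(section_count):
        taken.append(round(cur_idx))
        cur_idx += frac_stride
    return set(taken)


assert evenly_spaced(1000, 10) == {0, 111, 222, 333, 444, 555, 666, 777, 888, 999}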
- This class holds the parameters for creating a spaced diffusion sampler, including the timesteps to use. + This class holds the parameters for creating a spaced diffusion sampler, including + the timesteps to use. Args: use_timesteps: A collection (sequence or set) of timesteps from the @@ -74,16 +83,16 @@ class SpacedDiffusionBeatGansConfig(GaussianDiffusionBeatGansConfig): use_timesteps: tuple[int] | None = None - def make_sampler(self): + def make_sampler(self) -> "SpacedDiffusionBeatGans": return SpacedDiffusionBeatGans(self) class SpacedDiffusionBeatGans(GaussianDiffusionBeatGans): """A diffusion process, which can skip steps in a base diffusion process.""" - def __init__(self, conf: SpacedDiffusionBeatGansConfig): + def __init__(self, conf: SpacedDiffusionBeatGansConfig) -> None: self.conf = conf - self.use_timesteps = set(conf.use_timesteps) + self.use_timesteps: set[int] = set(conf.use_timesteps) # how the new t's mapped to the old t's self.timestep_map = [] base_diffusion = GaussianDiffusionBeatGans(conf) diff --git a/src/visualizr/anitalker/diffusion/resample.py b/src/visualizr/anitalker/diffusion/resample.py index fab6400b..b0d3e2f1 100644 --- a/src/visualizr/anitalker/diffusion/resample.py +++ b/src/visualizr/anitalker/diffusion/resample.py @@ -16,14 +16,14 @@ class ScheduleSampler(ABC): """ @abstractmethod - def weights(self): + def weights(self) -> None: """ Get a numpy array of weights, one per diffusion step. The weights needn't be normalized but must be positive. """ - def sample(self, batch_size, device): + def sample(self, batch_size: int, device: th.device) -> tuple[th.Tensor, th.Tensor]: """ Importance-sample timesteps for a batch. @@ -36,14 +36,14 @@ def sample(self, batch_size, device): w = self.weights() p = w / np.sum(w) indices_np = np.random.choice(len(p), size=(batch_size,), p=p) - indices = th.from_numpy(indices_np).long().to(device) + indices: th.Tensor = th.from_numpy(indices_np).long().to(device) weights_np = 1 / (len(p) * p[indices_np]) - weights = th.from_numpy(weights_np).float().to(device) + weights: th.Tensor = th.from_numpy(weights_np).float().to(device) return indices, weights class UniformSampler(ScheduleSampler): - def __init__(self, num_timesteps): + def __init__(self, num_timesteps) -> None: self._weights = np.ones([num_timesteps]) def weights(self): diff --git a/src/visualizr/anitalker/experiment.py b/src/visualizr/anitalker/experiment.py index 8d092bea..6752ae42 100644 --- a/src/visualizr/anitalker/experiment.py +++ b/src/visualizr/anitalker/experiment.py @@ -1,36 +1,46 @@ -import copy +from copy import deepcopy +from typing import TYPE_CHECKING -import torch from pytorch_lightning import LightningModule, seed_everything +from torch import randn from torch.cuda import amp from visualizr.anitalker.choices import TrainMode from visualizr.anitalker.config import TrainConfig +from visualizr.anitalker.diffusion.diffusion import SpacedDiffusionBeatGans from visualizr.anitalker.model.seq2seq import DiffusionPredictor from visualizr.anitalker.renderer import render_condition +if TYPE_CHECKING: + from visualizr.anitalker.diffusion.resample import UniformSampler + class LitModel(LightningModule): - def __init__(self, conf: TrainConfig): + def __init__(self, conf: TrainConfig) -> None: super().__init__() if conf.train_mode == TrainMode.manipulate: - raise ValueError("`conf.train_mode` cannot be `manipulate`") + msg = "`conf.train_mode` cannot be `manipulate`" + raise ValueError(msg) if conf.seed is not None: seed_everything(conf.seed) 
self.save_hyperparameters(conf.as_dict_jsonable()) - self.conf = conf + self.conf: TrainConfig = conf self.model = DiffusionPredictor(conf) - self.ema_model = copy.deepcopy(self.model) - self.ema_model.requires_grad_(False) + self.ema_model: DiffusionPredictor = deepcopy(self.model) + self.ema_model.requires_grad_(requires_grad=False) self.ema_model.eval() - self.sampler = conf.make_diffusion_conf().make_sampler() - self.eval_sampler = conf.make_eval_diffusion_conf().make_sampler() + self.sampler: SpacedDiffusionBeatGans = ( + conf.make_diffusion_conf().make_sampler() + ) + self.eval_sampler: SpacedDiffusionBeatGans = ( + conf.make_eval_diffusion_conf().make_sampler() + ) # this is shared for both model and latent - self.T_sampler = conf.make_t_sampler() + self.T_sampler: UniformSampler = conf.make_t_sampler() # initial variables for consistent sampling self.register_buffer( "x_T", - torch.randn( + randn( conf.sample_size, 3, conf.img_size, @@ -72,5 +82,7 @@ def render( def forward(self, noise=None, x_start=None, ema_model: bool = False): with amp.autocast(False): - model = self.model if self.disable_ema else self.ema_model + model: DiffusionPredictor = ( + self.model if self.disable_ema else self.ema_model + ) return self.eval_sampler.sample(model=model, noise=noise, x_start=x_start) diff --git a/src/visualizr/anitalker/face_sr/face_enhancer.py b/src/visualizr/anitalker/face_sr/face_enhancer.py index 2c2674e1..eae21c94 100644 --- a/src/visualizr/anitalker/face_sr/face_enhancer.py +++ b/src/visualizr/anitalker/face_sr/face_enhancer.py @@ -62,7 +62,7 @@ def enhancer_list( return list(gen) -def setup_gfpgan_restorer(method: str): +def setup_gfpgan_restorer(method: str) -> tuple[int, str, str, str]: channel_multiplier: int | None = None model_name: str | None = None url: str | None = None @@ -103,7 +103,7 @@ def setup_background_upsampler(bg_upsampler: str) -> RealESRGANer | None: grWarning(_msg) _bg_upsampler = None else: - model = RRDBNet(num_in_ch=3, num_out_ch=3, scale=2) + model: RRDBNet[int, int] = RRDBNet(num_in_ch=3, num_out_ch=3, scale=2) # need to set False in CPU mode _bg_upsampler = RealESRGANer( scale=2, @@ -148,7 +148,7 @@ def enhancer_generator_no_len( "Expected one of: gfpgan, RestoreFormer, codeformer." 
) raise ValueError(msg) - _msg = f"face enhancer: {method}" + _msg: str = f"face enhancer: {method}" logger.info(_msg) Info(_msg) if not isinstance(images, list) and images.is_file(): @@ -159,7 +159,7 @@ def enhancer_generator_no_len( channel_multiplier, model_name, url, arch = setup_gfpgan_restorer(method) # Setup background upsampler - _bg_upsampler = setup_background_upsampler(bg_upsampler) + _bg_upsampler: RealESRGANer | None = setup_background_upsampler(bg_upsampler) # determine model paths model_path: Path = GFPGAN_WEIGHTS / f"{model_name}.pth" @@ -168,7 +168,7 @@ def enhancer_generator_no_len( # download pre-trained models from URL model_path: str = url - restorer = GFPGANer( + restorer: GFPGANer[str] = GFPGANer( model_path=model_path if isinstance(model_path, str) else model_path.as_posix(), arch=arch, channel_multiplier=channel_multiplier, diff --git a/src/visualizr/anitalker/face_sr/videoio.py b/src/visualizr/anitalker/face_sr/videoio.py index 0ce3672b..801214be 100644 --- a/src/visualizr/anitalker/face_sr/videoio.py +++ b/src/visualizr/anitalker/face_sr/videoio.py @@ -5,7 +5,7 @@ def load_video_to_cv2(input_path: str) -> list[np.ndarray]: video_stream = cv2.VideoCapture(input_path) full_frames: list[np.ndarray] = [] - while 1: + while True: still_reading, frame = video_stream.read() if not still_reading: video_stream.release() diff --git a/src/visualizr/anitalker/liamodel.py b/src/visualizr/anitalker/liamodel.py index 3f9aacd1..5c3308e8 100644 --- a/src/visualizr/anitalker/liamodel.py +++ b/src/visualizr/anitalker/liamodel.py @@ -8,13 +8,13 @@ class LiaModel(nn.Module): def __init__( self, - size=256, - style_dim=512, - motion_dim=20, - channel_multiplier=1, - blur_kernel: list = None, - fusion_type="", - ): + size: int = 256, + style_dim: int = 512, + motion_dim: int = 20, + channel_multiplier: int = 1, + blur_kernel: list | None = None, + fusion_type: str = "", + ) -> None: if blur_kernel is None: blur_kernel = [1, 3, 3, 1] super().__init__() @@ -46,7 +46,7 @@ def load_lightning_model(self, lia_pretrained_model_path): if self_state[name].size() != state[orig_name].size(): Error( f"Wrong parameter length: {orig_name}, " - + f"model: {self_state[name].size()}, " + f"model: {self_state[name].size()}, " + f"loaded: {state[orig_name].size()}" ) continue diff --git a/src/visualizr/anitalker/model/blocks.py b/src/visualizr/anitalker/model/blocks.py index c57f18e9..8f2bd303 100644 --- a/src/visualizr/anitalker/model/blocks.py +++ b/src/visualizr/anitalker/model/blocks.py @@ -2,13 +2,16 @@ from abc import ABC, abstractmethod from dataclasses import dataclass from numbers import Number +from typing import Literal import torch as th -from torch import nn +from torch import Tensor, nn +from torch.ao.nn.quantized.dynamic.modules.conv import Conv2d, Conv3d from torch.nn.functional import interpolate from visualizr.anitalker.config_base import BaseConfig from visualizr.anitalker.model.nn import ( + GroupNorm32, avg_pool_nd, conv_nd, normalization, @@ -21,7 +24,7 @@ class TimestepBlock(nn.Module, ABC): """Any module where forward() takes timestep embeddings as a second argument.""" @abstractmethod - def forward(self, x, emb=None, cond=None, lateral=None): + def forward(self, x, emb=None, cond=None, lateral=None) -> None: """Apply the module to `x` given `emb` timestep embeddings.""" @@ -65,11 +68,11 @@ class ResBlockConfig(BaseConfig): # this is defaulted from BeatGANs and seems to help learning. 
use_zero_module: bool = True - def __post_init__(self): + def __post_init__(self) -> None: self.out_channels = self.out_channels or self.channels self.cond_emb_channels = self.cond_emb_channels or self.emb_channels - def make_model(self): + def make_model(self) -> "ResBlock": return ResBlock(self) @@ -90,21 +93,21 @@ class ResBlock(TimestepBlock): """ - def __init__(self, conf: ResBlockConfig): + def __init__(self, conf: ResBlockConfig) -> None: super().__init__() - self.conf = conf + self.conf: ResBlockConfig = conf ############################# # IN LAYERS ############################# - layers = [ + layers: list[nn.Module] = [ normalization(conf.channels), nn.SiLU(), conv_nd(conf.dims, conf.channels, conf.out_channels, 3, padding=1), ] self.in_layers = nn.Sequential(*layers) - self.updown = conf.up or conf.down + self.updown: bool = conf.up or conf.down if conf.up: self.h_upd = Upsample(conf.channels, False, conf.dims) @@ -134,7 +137,7 @@ def __init__(self, conf: ResBlockConfig): # OUT LAYERS (ignored when there is no condition) ############################# # original version - conv = conv_nd( + conv: nn.Conv1d | Conv2d | Conv3d = conv_nd( conf.dims, conf.out_channels, conf.out_channels, @@ -175,7 +178,7 @@ def __init__(self, conf: ResBlockConfig): kernel_size = 1 padding = 0 - self.skip_connection = conv_nd( + self.skip_connection: nn.Conv1d | Conv2d | Conv3d = conv_nd( conf.dims, conf.channels, conf.out_channels, @@ -210,8 +213,9 @@ def _forward( # lateral may be supplied even if it doesn't require # the model will take the lateral only if `has_lateral` if lateral is None: - raise ValueError("`lateral` is required") - x = th.cat([x, lateral], dim=1) + msg = "`lateral` is required" + raise ValueError(msg) + x: th.Tensor = th.cat([x, lateral], dim=1) if self.updown: in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] @@ -256,10 +260,10 @@ def apply_conditions( h, emb=None, cond=None, - layers: nn.Sequential = None, + layers: nn.Sequential | None = None, scale_bias: float = 1, in_channels: int = 512, - up_down_layer: nn.Module = None, + up_down_layer: nn.Module | None = None, ): """ Apply conditions on the feature maps. @@ -268,7 +272,7 @@ def apply_conditions( emb: time conditional (ready to scale + shift) cond: encoder's conditional (read to scale + shift) """ - two_cond = emb is not None and cond is not None + two_cond: bool = emb is not None and cond is not None if emb is not None: # adjusting shapes @@ -336,14 +340,16 @@ class Upsample(nn.Module): upsampling occurs in the inner-two dimensions. """ - def __init__(self, channels, use_conv, dims=2, out_channels=None): + def __init__(self, channels, use_conv, dims: int = 2, out_channels=None) -> None: super().__init__() self.channels = channels self.out_channels = out_channels or channels self.use_conv = use_conv - self.dims = dims + self.dims: int = dims if use_conv: - self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1) + self.conv: nn.Conv1d | Conv2d | Conv3d = conv_nd( + dims, self.channels, self.out_channels, 3, padding=1 + ) def forward(self, x): if x.shape[1] != self.channels: @@ -368,15 +374,15 @@ class Downsample(nn.Module): downsampling occurs in the inner-two dimensions. 
""" - def __init__(self, channels, use_conv, dims=2, out_channels=None): + def __init__(self, channels, use_conv, dims: int = 2, out_channels=None) -> None: super().__init__() self.channels = channels self.out_channels = out_channels or channels self.use_conv = use_conv - self.dims = dims - stride = 2 if dims != 3 else (1, 2, 2) + self.dims: int = dims + stride: tuple[int] | Literal[2] = 2 if dims != 3 else (1, 2, 2) if use_conv: - self.op = conv_nd( + self.op: nn.Conv1d | Conv2d | Conv3d = conv_nd( dims, self.channels, self.out_channels, @@ -395,7 +401,7 @@ def __init__(self, channels, use_conv, dims=2, out_channels=None): def forward(self, x): if x.shape[1] != self.channels: - msg = f"Input has {x.shape[1]} channels but layer has {self.channels}" + msg: str = f"Input has {x.shape[1]} channels but layer has {self.channels}" raise ValueError(msg) return self.op(x) @@ -406,11 +412,11 @@ class AttentionBlock(nn.Module): def __init__( self, channels, - num_heads=1, - num_head_channels=-1, - use_checkpoint=False, - use_new_attention_order=False, - ): + num_heads: int = 1, + num_head_channels: int = -1, + use_checkpoint: bool = False, + use_new_attention_order: bool = False, + ) -> None: super().__init__() self.channels = channels if num_head_channels == -1: @@ -423,9 +429,9 @@ def __init__( f"divisible by `num_head_channels` {num_head_channels}" ) raise ValueError(msg) - self.use_checkpoint = use_checkpoint - self.norm = normalization(channels) - self.qkv = conv_nd(1, channels, channels * 3, 1) + self.use_checkpoint: bool = use_checkpoint + self.norm: GroupNorm32 = normalization(channels) + self.qkv: nn.Conv1d | Conv2d | Conv3d = conv_nd(1, channels, channels * 3, 1) if use_new_attention_order: # split qkv before split heads self.attention = QKVAttention(self.num_heads) @@ -433,7 +439,9 @@ def __init__( # split heads before split qkv self.attention = QKVAttentionLegacy(self.num_heads) - self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) + self.proj_out: nn.Conv1d | Conv2d | Conv3d = zero_module( + conv_nd(1, channels, channels, 1) + ) def forward(self, x): return torch_checkpoint(self._forward, (x,), self.use_checkpoint) @@ -454,11 +462,11 @@ class QKVAttentionLegacy(nn.Module): Matches legacy QKVAttention + input/output heads shaping. """ - def __init__(self, n_heads): + def __init__(self, n_heads) -> None: super().__init__() self.n_heads = n_heads - def forward(self, qkv): + def forward(self, qkv: Tensor) -> Tensor: """ Apply QKV attention. 
@@ -467,42 +475,46 @@ def forward(self, qkv): """ bs, width, length = qkv.shape if width % (3 * self.n_heads) != 0: - raise ValueError(f"Invalid qkv shape {qkv.shape}") + msg: str = f"Invalid qkv shape {qkv.shape}" + raise ValueError(msg) ch = width // (3 * self.n_heads) q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) - scale = 1 / math.sqrt(math.sqrt(ch)) + scale: float = 1 / math.sqrt(math.sqrt(ch)) # More stable with f16 than dividing afterward - weight = th.einsum("bct,bcs->bts", q * scale, k * scale) + weight: Tensor = th.einsum("bct,bcs->bts", q * scale, k * scale) weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) - a = th.einsum("bts,bcs->bct", weight, v) + a: Tensor = th.einsum("bts,bcs->bct", weight, v) return a.reshape(bs, -1, length) class QKVAttention(nn.Module): """A module, which performs QKV attention and splits in a different order.""" - def __init__(self, n_heads): + def __init__(self, n_heads) -> None: super().__init__() self.n_heads = n_heads - def forward(self, qkv): + def forward(self, qkv: Tensor) -> Tensor: """ Apply QKV attention. - :param qkv: An `[N x (3 × H × C) x T]` tensor of Qs, Ks, and Vs. - :return: An `[N x (H × C) x T]` tensor after attention. + :param qkv: An `[N x (3 x H x C) x T]` tensor of Qs, Ks, and Vs. + :return: An `[N x (H x C) x T]` tensor after attention. """ bs, width, length = qkv.shape if width % (3 * self.n_heads) != 0: - raise ValueError(f"Invalid qkv shape {qkv.shape}") - ch = width // (3 * self.n_heads) + msg: str = f"Invalid qkv shape {qkv.shape}" + raise ValueError(msg) + ch: int = width // (3 * self.n_heads) q, k, v = qkv.chunk(3, dim=1) - scale = 1 / math.sqrt(math.sqrt(ch)) - weight = th.einsum( + scale: float = 1 / math.sqrt(math.sqrt(ch)) + weight: Tensor = th.einsum( "bct,bcs->bts", (q * scale).view(bs * self.n_heads, ch, length), (k * scale).view(bs * self.n_heads, ch, length), ) # More stable with f16 than dividing afterward weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) - a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length)) + a: Tensor = th.einsum( + "bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length) + ) return a.reshape(bs, -1, length) diff --git a/src/visualizr/anitalker/model/latentnet.py b/src/visualizr/anitalker/model/latentnet.py index a1de7031..b9997353 100644 --- a/src/visualizr/anitalker/model/latentnet.py +++ b/src/visualizr/anitalker/model/latentnet.py @@ -12,13 +12,13 @@ class LatentNetType(Enum): - none = "none" + none: str = "none" # injecting inputs into the hidden layers - skip = "skip" + skip: str = "skip" class LatentNetReturn(NamedTuple): - pred: torch.Tensor = None + pred: torch.Tensor | None = None @dataclass @@ -38,7 +38,7 @@ class MLPSkipNetConfig(BaseConfig): num_time_layers: int = 2 time_last_act: bool = False - def make_model(self): + def make_model(self) -> "MLPSkipNet": return MLPSkipNet(self) @@ -49,20 +49,26 @@ class MLPSkipNet(nn.Module): Default MLP for the latent DPM in the paper. 
""" - def __init__(self, conf: MLPSkipNetConfig): + def __init__(self, conf: MLPSkipNetConfig) -> None: super().__init__() - self.conf = conf + self.conf: MLPSkipNetConfig = conf - layers = [] + layers: list[nn.Module] = [] for i in range(conf.num_time_layers): - a = conf.num_time_emb_channels if i == 0 else conf.num_channels - b = conf.num_channels + a: int = conf.num_time_emb_channels if i == 0 else conf.num_channels + b: int = conf.num_channels layers.append(nn.Linear(a, b)) if i < conf.num_time_layers - 1 or conf.time_last_act: layers.append(conf.activation.get_act()) - self.time_embed = nn.Sequential(*layers) - - self.layers = nn.ModuleList([]) + self.time_embed: nn.Sequential = nn.Sequential(*layers) + self.layers: nn.ModuleList = nn.ModuleList([]) + + act: Activation | None = None + norm: bool | None = None + cond: bool | None = None + a: int | None = None + b: int | None = None + dropout: float | None = None for i in range(conf.num_layers): if i == 0: act = conf.activation @@ -98,10 +104,12 @@ def __init__(self, conf: MLPSkipNetConfig): dropout=dropout, ), ) - self.last_act = conf.last_act.get_act() + self.last_act: nn.Identity | nn.ReLU | nn.LeakyReLU | nn.SiLU | nn.Tanh = ( + conf.last_act.get_act() + ) - def forward(self, x, t): - t = timestep_embedding(t, self.conf.num_time_emb_channels) + def forward(self, x, t) -> LatentNetReturn: + t: torch.Tensor = timestep_embedding(t, self.conf.num_time_emb_channels) cond = self.time_embed(t) h = x for i in range(len(self.layers)): @@ -124,22 +132,28 @@ def __init__( cond_channels: int, condition_bias: float = 0, dropout: float = 0, - ): + ) -> None: super().__init__() - self.activation = activation - self.condition_bias = condition_bias - self.use_cond = use_cond + self.activation: Activation = activation + self.condition_bias: float = condition_bias + self.use_cond: bool = use_cond self.linear = nn.Linear(in_channels, out_channels) - self.act = activation.get_act() + self.act: nn.Identity | nn.ReLU | nn.LeakyReLU | nn.SiLU | nn.Tanh = ( + activation.get_act() + ) if self.use_cond: self.linear_emb = nn.Linear(cond_channels, out_channels) self.cond_layers = nn.Sequential(self.act, self.linear_emb) - self.norm = nn.LayerNorm(out_channels) if norm else nn.Identity() - self.dropout = nn.Dropout(p=dropout) if dropout > 0 else nn.Identity() + self.norm: nn.LayerNorm | nn.Identity = ( + nn.LayerNorm(out_channels) if norm else nn.Identity() + ) + self.dropout: nn.Dropout | nn.Identity = ( + nn.Dropout(p=dropout) if dropout > 0 else nn.Identity() + ) self.init_weights() - def init_weights(self): + def init_weights(self) -> None: for module in self.modules(): if isinstance(module, nn.Linear): if self.activation in [Activation.relu, Activation.silu]: diff --git a/src/visualizr/anitalker/model/nn.py b/src/visualizr/anitalker/model/nn.py index c4626e84..a8836ceb 100644 --- a/src/visualizr/anitalker/model/nn.py +++ b/src/visualizr/anitalker/model/nn.py @@ -1,6 +1,7 @@ from math import log from torch import ( + Tensor, arange, cat, cos, @@ -14,11 +15,11 @@ class GroupNorm32(nn.GroupNorm): - def forward(self, x): + def forward(self, x: Tensor) -> Tensor: return super().forward(x.float()).type(x.dtype) -def conv_nd(dims, *args, **kwargs): +def conv_nd(dims, *args, **kwargs) -> nn.Conv1d | nn.Conv2d | nn.Conv3d: """Create a 1D, 2D, or 3D convolution module.""" match dims: case 1: @@ -27,11 +28,13 @@ def conv_nd(dims, *args, **kwargs): return nn.Conv2d(*args, **kwargs) case 3: return nn.Conv3d(*args, **kwargs) - msg = f"unsupported dimensions: {dims}" + 
msg: str = f"unsupported dimensions: {dims}" raise ValueError(msg) -def avg_pool_nd(dims, *args, **kwargs): +def avg_pool_nd( + dims: int, *args, **kwargs +) -> nn.AvgPool1d | nn.AvgPool2d | nn.AvgPool3d: """Create a 1D, 2D, or 3D average pooling module.""" match dims: case 1: @@ -40,7 +43,7 @@ def avg_pool_nd(dims, *args, **kwargs): return nn.AvgPool2d(*args, **kwargs) case 3: return nn.AvgPool3d(*args, **kwargs) - msg = f"unsupported dimensions: {dims}" + msg: str = f"unsupported dimensions: {dims}" raise ValueError(msg) @@ -51,7 +54,7 @@ def zero_module(module): return module -def normalization(channels): +def normalization(channels) -> GroupNorm32: """ Make a standard normalization layer. @@ -61,7 +64,7 @@ def normalization(channels): return GroupNorm32(min(32, channels), channels) -def timestep_embedding(timesteps, dim, max_period=10000): +def timestep_embedding(timesteps, dim, max_period: int = 10000) -> Tensor: """ Create sinusoidal timestep embeddings. @@ -72,17 +75,19 @@ def timestep_embedding(timesteps, dim, max_period=10000): :return: An [N x dim] Tensor of positional embeddings. """ half = dim // 2 - freqs = exp(-log(max_period) * arange(start=0, end=half, dtype=float32) / half).to( + freqs: Tensor = exp( + -log(max_period) * arange(start=0, end=half, dtype=float32) / half + ).to( device=timesteps.device, ) args = timesteps[:, None].float() * freqs[None] - embedding = cat([cos(args), sin(args)], dim=-1) + embedding: Tensor = cat([cos(args), sin(args)], dim=-1) if dim % 2: embedding = cat([embedding, zeros_like(embedding[:, :1])], dim=-1) return embedding -def torch_checkpoint(func, args, flag, preserve_rng_state=False): +def torch_checkpoint(func, args, flag, preserve_rng_state: bool = False): # torch's gradient checkpoint works with automatic mixed precision, given `torch>=1.8` if flag: return checkpoint(func, *args, preserve_rng_state=preserve_rng_state) diff --git a/src/visualizr/anitalker/model/seq2seq.py b/src/visualizr/anitalker/model/seq2seq.py index c2216d34..5b7f4295 100644 --- a/src/visualizr/anitalker/model/seq2seq.py +++ b/src/visualizr/anitalker/model/seq2seq.py @@ -1,6 +1,6 @@ from espnet.nets.pytorch_backend.conformer.encoder import Encoder from gradio import Error, Info -from torch import cat, nn, zeros +from torch import Tensor, cat, nn, zeros from torch.nn import Module from torch.nn.functional import softmax @@ -9,7 +9,13 @@ class LSTM(Module): - def __init__(self, motion_dim, output_dim, num_layers=2, hidden_dim=128): + def __init__( + self, + motion_dim, + output_dim, + num_layers: int = 2, + hidden_dim: int = 128, + ) -> None: super().__init__() self.lstm = nn.LSTM( input_size=motion_dim, @@ -35,7 +41,7 @@ def __init__( decoder_dim: int = 1024, motion_start_dim: int = 512, hal_layers: int = 25, - ): + ) -> None: super().__init__() self.conf: TrainConfig = conf # Speech downsampling @@ -66,7 +72,7 @@ def __init__( ) self.weights = nn.Parameter(zeros(hal_layers)) - self.speech_encoder = self.create_conformer_encoder( + self.speech_encoder: Encoder = self.create_conformer_encoder( speech_dim, speech_layers, ) @@ -80,7 +86,7 @@ def __init__( Error(_msg) raise ValueError(_msg) # Encoders & Decoders - self.coarse_decoder = self.create_conformer_encoder( + self.coarse_decoder: Encoder = self.create_conformer_encoder( decoder_dim, conf.decoder_layers, ) @@ -137,7 +143,7 @@ def forward( yaw_pitch_roll, noisy_x, t_emb, - control_flag=False, + control_flag: bool = False, ): x = None if self.conf.infer_type.startswith("mfcc"): @@ -187,7 +193,7 @@ def adjust_features( 
face_location, face_scale, yaw_pitch_roll, - control_flag, + control_flag: bool, ): predicted_location, predicted_scale = 0, 0 if "full_control" in self.conf.infer_type: @@ -220,7 +226,9 @@ def adjust_pose(self, x, yaw_pitch_roll, control_flag): predicted_pose = yaw_pitch_roll if control_flag else self.pose_predictor(x) return self.pose_encoder(predicted_pose), predicted_pose - def combine_features(self, x, initial_code, direction_code, noisy_x, t_emb): + def combine_features( + self, x, initial_code, direction_code, noisy_x, t_emb + ) -> Tensor: init_code_proj = ( self.init_code_proj(initial_code).unsqueeze(1).repeat(1, x.size(1), 1) ) @@ -240,6 +248,6 @@ def combine_features(self, x, initial_code, direction_code, noisy_x, t_emb): dim=-1, ) - def decode_features(self, concatenated_features): + def decode_features(self, concatenated_features: Tensor): outputs, _ = self.coarse_decoder(concatenated_features, masks=None) return self.out_proj(outputs) diff --git a/src/visualizr/anitalker/model/unet.py b/src/visualizr/anitalker/model/unet.py index d20dbf5a..1ad9eca7 100644 --- a/src/visualizr/anitalker/model/unet.py +++ b/src/visualizr/anitalker/model/unet.py @@ -8,6 +8,7 @@ from visualizr.anitalker.model.blocks import ( AttentionBlock, Downsample, + ResBlock, ResBlockConfig, TimestepEmbedSequential, Upsample, @@ -27,23 +28,23 @@ class BeatGANsUNetConfig(BaseConfig): # base channels will be multiplied model_channels: int = 64 # output of the unet - # suggest: 3 + # suggest: `3` # you only need 6 if you also model the variance of the noise prediction # (usually we use an analytical variance hence 3) out_channels: int = 3 # how many repeating resblocks per resolution # the decoding side would have "one more" resblock - # default: 2 + # default: `2` num_res_blocks: int = 2 # you can also set the number of resblocks specifically for the input blocks - # default: None = above + # default: `None` = above num_input_res_blocks: int | None = None # number of time embed channels and style channels embed_channels: int = 512 # at what resolutions you want to do self-attention of the feature maps # attentions improve performance - # default: [16] - # beatgans: [32, 16, 8] + # default: `[16]` + # beatgans: `[32, 16, 8]` attention_resolutions: tuple[int] = (16,) # number of time embed channels time_embed_channels: int | None = None @@ -64,30 +65,30 @@ class BeatGANsUNetConfig(BaseConfig): # what's this? 
num_heads_upsample: int = -1 # use resblock for upscale/downscale blocks (expensive) - # default: True (BeatGANs) + # default: `True` (BeatGANs) resblock_updown: bool = True # never tried use_new_attention_order: bool = False resnet_two_cond: bool = False resnet_cond_channels: int | None = None # init the decoding conv layers with zero weights, this speeds up training - # default: True (BeattGANs) + # default: `True` (BeattGANs) resnet_use_zero_module: bool = True # a gradient checkpoint the attention operation attn_checkpoint: bool = False - def make_model(self): + def make_model(self) -> "BeatGANsUNetModel": return BeatGANsUNetModel(self) class BeatGANsUNetModel(nn.Module): - def __init__(self, conf: BeatGANsUNetConfig): + def __init__(self, conf: BeatGANsUNetConfig) -> None: super().__init__() - self.conf = conf + self.conf: BeatGANsUNetConfig = conf if conf.num_heads_upsample == -1: self.num_heads_upsample = conf.num_heads - self.dtype = th.float32 - self.time_emb_channels = conf.time_embed_channels or conf.model_channels + self.dtype: th.dtype = th.float32 + self.time_emb_channels: int = conf.time_embed_channels or conf.model_channels self.time_embed = nn.Sequential( nn.Linear(self.time_emb_channels, conf.embed_channels), nn.SiLU(), @@ -107,13 +108,15 @@ def __init__(self, conf: BeatGANsUNetConfig): "use_zero_module": conf.resnet_use_zero_module, "cond_emb_channels": conf.resnet_cond_channels, } - input_block_chans = [[] for _ in range(len(conf.channel_mult))] + input_block_chans: list[list[int]] = [ + [] for _ in range(len(conf.channel_mult)) + ] # TODO input_block_chans[0].append(ch) # number of blocks at each resolution - self.input_num_blocks = [0 for _ in range(len(conf.channel_mult))] + self.input_num_blocks: list[int] = [0 for _ in range(len(conf.channel_mult))] self.input_num_blocks[0] = 1 - self.output_num_blocks = [0 for _ in range(len(conf.channel_mult))] - resolution = conf.image_size + self.output_num_blocks: list[int] = [0 for _ in range(len(conf.channel_mult))] + resolution: int = conf.image_size for level, mult in enumerate(conf.input_channel_mult or conf.channel_mult): for _ in range(conf.num_input_res_blocks or conf.num_res_blocks): layers = [ @@ -165,7 +168,7 @@ def __init__(self, conf: BeatGANsUNetConfig): ), ), ) - ch = out_ch + ch: int = out_ch input_block_chans[level + 1].append(ch) self.input_num_blocks[level + 1] += 1 self.middle_block = TimestepEmbedSequential( @@ -266,12 +269,11 @@ def __init__(self, conf: BeatGANsUNetConfig): conv_nd(conf.dims, input_ch, conf.out_channels, 3, padding=1), ) - def forward(self, x, t, y=None, **kwargs): + def forward(self, x, t, y=None, **kwargs) -> "Return": """Apply the model to an input batch.""" if (y is not None) != (self.conf.num_classes is not None): - raise ValueError( - "must specify y if and only if the model is class-conditional", - ) + msg = "must specify y if and only if the model is class-conditional" + raise ValueError(msg) hs = [[] for _ in range(len(self.conf.channel_mult))] emb = self.time_embed(timestep_embedding(t, self.time_emb_channels)) if self.conf.num_classes is not None: @@ -285,7 +287,8 @@ def forward(self, x, t, y=None, **kwargs): hs[i].append(h) k += 1 if k != len(self.input_blocks): - raise ValueError(f"expected {len(self.input_blocks)} blocks, got {k}") + msg: str = f"expected {len(self.input_blocks)} blocks, got {k}" + raise ValueError(msg) h = self.middle_block(h, emb=emb) k = 0 @@ -329,7 +332,7 @@ class BeatGANsEncoderConfig(BaseConfig): use_new_attention_order: bool = False pool: str = 
"adaptivenonzero" - def make_model(self): + def make_model(self) -> "BeatGANsEncoderModel": return BeatGANsEncoderModel(self) @@ -340,11 +343,11 @@ class BeatGANsEncoderModel(nn.Module): For usage, see UNet. """ - def __init__(self, conf: BeatGANsEncoderConfig): + def __init__(self, conf: BeatGANsEncoderConfig) -> None: super().__init__() - self.conf = conf - self.dtype = th.float32 - + self.conf: BeatGANsEncoderConfig = conf + self.dtype: th.dtype = th.float32 + time_embed_dim: int | None = None if conf.use_time_condition: time_embed_dim = conf.model_channels * 4 self.time_embed = nn.Sequential( @@ -363,8 +366,8 @@ def __init__(self, conf: BeatGANsEncoderConfig): ), ], ) - input_block_chans = [ch] - resolution = conf.image_size + input_block_chans: list[int] = [ch] + resolution: int = conf.image_size for level, mult in enumerate(conf.channel_mult): for _ in range(conf.num_res_blocks): layers = [ @@ -393,7 +396,7 @@ def __init__(self, conf: BeatGANsEncoderConfig): input_block_chans.append(ch) if level != len(conf.channel_mult) - 1: resolution //= 2 - out_ch = ch + out_ch: int = ch self.input_blocks.append( TimestepEmbedSequential( ResBlockConfig( @@ -415,7 +418,7 @@ def __init__(self, conf: BeatGANsEncoderConfig): ), ), ) - ch = out_ch + ch: int = out_ch input_block_chans.append(ch) self.middle_block = TimestepEmbedSequential( @@ -452,9 +455,10 @@ def __init__(self, conf: BeatGANsEncoderConfig): nn.Flatten(), ) else: - raise NotImplementedError(f"Unexpected {conf.pool} pooling") + msg: str = f"Unexpected {conf.pool} pooling" + raise NotImplementedError(msg) - def forward(self, x, t=None, return_2d_feature=False): + def forward(self, x, t=None, return_2d_feature: bool = False): """Apply the model to an input batch.""" if self.conf.use_time_condition: emb = self.time_embed(timestep_embedding(t, self.model_channels)) diff --git a/src/visualizr/anitalker/model/unet_autoenc.py b/src/visualizr/anitalker/model/unet_autoenc.py index 93b05708..78fdd133 100644 --- a/src/visualizr/anitalker/model/unet_autoenc.py +++ b/src/visualizr/anitalker/model/unet_autoenc.py @@ -25,12 +25,12 @@ class BeatGANsAutoencConfig(BeatGANsUNetConfig): class BeatGANsAutoencModel(BeatGANsUNetModel): - def __init__(self, conf: BeatGANsAutoencConfig): + def __init__(self, conf: BeatGANsAutoencConfig) -> None: super().__init__(conf) - self.conf = conf + self.conf: BeatGANsAutoencConfig = conf # having only time, cond - self.time_embed = TimeStyleSeperateEmbed( + self.time_embed: TimeStyleSeperateEmbed = TimeStyleSeperateEmbed( time_channels=conf.model_channels, time_out_channels=conf.embed_channels, ) @@ -75,7 +75,7 @@ def forward( noise=None, t_cond=None, **kwargs, - ): + ) -> "AutoencReturn": """ Apply the model to an input batch. 
@@ -91,10 +91,16 @@ def forward( cond = self.noise_to_cond(noise) if cond is None: if x is not None and len(x) != len(x_start): - raise ValueError(f"{len(x)} != {len(x_start)}") + msg: str = f"{len(x)} != {len(x_start)}" + raise ValueError(msg) tmp = self.encode(x_start) cond = tmp["cond"] + + _t_emb: Tensor | None = None + _t_cond_emb: Tensor | None = None + emb: Tensor | None = None + cond_emb: Tensor | None = None if t is not None: _t_emb = timestep_embedding(t, self.conf.model_channels) _t_cond_emb = timestep_embedding(t_cond, self.conf.model_channels) @@ -103,7 +109,7 @@ def forward( _t_emb = None _t_cond_emb = None if self.conf.resnet_two_cond: - res = self.time_embed.forward(time_emb=_t_emb, cond=cond) + res: EmbedReturn = self.time_embed.forward(time_emb=_t_emb, cond=cond) else: raise NotImplementedError if self.conf.resnet_two_cond: @@ -123,15 +129,17 @@ def forward( raise NotImplementedError # where in the model to supply time conditions - enc_time_emb = emb - mid_time_emb = emb - dec_time_emb = emb + enc_time_emb: Tensor | None = emb + mid_time_emb: Tensor | None = emb + dec_time_emb: Tensor | None = emb # where in the model to supply style conditions - enc_cond_emb = cond_emb - mid_cond_emb = cond_emb - dec_cond_emb = cond_emb + enc_cond_emb: Tensor | None = cond_emb + mid_cond_emb: Tensor | None = cond_emb + dec_cond_emb: Tensor | None = cond_emb - hs = [[] for _ in range(len(self.conf.channel_mult))] + hs: list[list[Tensor]] = [ + [] for _ in range(len(self.conf.channel_mult)) + ] # TODO: type if x is not None: h = x.type(self.dtype) @@ -160,6 +168,7 @@ def forward( for _ in range(self.output_num_blocks[i]): # take the lateral connection from the same layer (in reserve) # until there is no more, use None. + lateral: Tensor | None = None try: lateral = hs[-i - 1].pop() except IndexError: @@ -179,21 +188,21 @@ def forward( class AutoencReturn(NamedTuple): pred: Tensor - cond: Tensor = None + cond: Tensor | None = None class EmbedReturn(NamedTuple): # style and time - emb: Tensor = None + emb: Tensor | None = None # time only - time_emb: Tensor = None + time_emb: Tensor | None = None # style only (but could depend on time) - style: Tensor = None + style: Tensor | None = None class TimeStyleSeperateEmbed(nn.Module): # embed only style - def __init__(self, time_channels, time_out_channels): + def __init__(self, time_channels, time_out_channels) -> None: super().__init__() self.time_embed = nn.Sequential( nn.Linear(time_channels, time_out_channels), @@ -202,7 +211,7 @@ def __init__(self, time_channels, time_out_channels): ) self.style = nn.Identity() - def forward(self, time_emb=None, cond=None): + def forward(self, time_emb=None, cond=None) -> EmbedReturn: time_emb = None if time_emb is None else self.time_embed(time_emb) style = self.style(cond) return EmbedReturn(emb=style, time_emb=time_emb, style=style) diff --git a/src/visualizr/anitalker/networks/discriminator.py b/src/visualizr/anitalker/networks/discriminator.py index d2ecd2ea..37605f86 100644 --- a/src/visualizr/anitalker/networks/discriminator.py +++ b/src/visualizr/anitalker/networks/discriminator.py @@ -1,41 +1,52 @@ import math +from regex import T import torch -from torch import nn +from torch import Tensor, nn from torch.nn.functional import conv2d, leaky_relu, pad -def fused_leaky_relu(_input, bias, negative_slope=0.2, scale=2**0.5): +def fused_leaky_relu( + _input: Tensor, + bias: nn.Parameter, + negative_slope: float = 0.2, + scale: float = 2**0.5, +) -> Tensor: return leaky_relu(_input + bias, negative_slope) * 
scale class FusedLeakyReLU(nn.Module): - def __init__(self, channel, negative_slope=0.2, scale=2**0.5): + def __init__( + self, + channel, + negative_slope: float = 0.2, + scale: float = 2**0.5, + ) -> None: super().__init__() self.bias = nn.Parameter(torch.zeros(1, channel, 1, 1)) - self.negative_slope = negative_slope - self.scale = scale + self.negative_slope: float = negative_slope + self.scale: float = scale - def forward(self, _input): + def forward(self, _input: Tensor) -> Tensor: return fused_leaky_relu(_input, self.bias, self.negative_slope, self.scale) def upfirdn2d_native( - _input, - kernel, - up_x, - up_y, - down_x, - down_y, - pad_x0, - pad_x1, - pad_y0, - pad_y1, -): + _input: Tensor, + kernel: Tensor, + up_x: int, + up_y: int, + down_x: int, + down_y: int, + pad_x0: int, + pad_x1: int, + pad_y0: int, + pad_y1: int, +) -> Tensor: _, minor, in_h, in_w = _input.shape kernel_h, kernel_w = kernel.shape - out = _input.view(-1, minor, in_h, 1, in_w, 1) + out: Tensor = _input.view(-1, minor, in_h, 1, in_w, 1) out = pad(out, [0, up_x - 1, 0, 0, 0, up_y - 1, 0, 0]) out = out.view(-1, minor, in_h * up_y, in_w * up_x) @@ -50,7 +61,7 @@ def upfirdn2d_native( out = out.reshape( [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1], ) - w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) + w: Tensor = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) out = conv2d(out, w) out = out.reshape( -1, @@ -61,7 +72,13 @@ def upfirdn2d_native( return out[:, :, ::down_y, ::down_x] -def upfirdn2d(_input, kernel, up=1, down=1, _pad=(0, 0)): +def upfirdn2d( + _input: Tensor, + kernel: Tensor, + up: int = 1, + down: int = 1, + _pad: tuple[int] = (0, 0), +) -> Tensor: return upfirdn2d_native( _input, kernel, @@ -76,41 +93,38 @@ def upfirdn2d(_input, kernel, up=1, down=1, _pad=(0, 0)): ) -def make_kernel(k): - k = torch.tensor(k, dtype=torch.float32) - +def make_kernel(k: Tensor) -> Tensor: + k: Tensor = torch.tensor(k, dtype=torch.float32) if k.ndim == 1: k = k[None, :] * k[:, None] - k /= k.sum() - return k class Blur(nn.Module): - def __init__(self, kernel, _pad, upsample_factor=1): + def __init__( + self, + kernel: Tensor, + _pad: tuple[int], + upsample_factor: int = 1, + ) -> None: super().__init__() - - kernel = make_kernel(kernel) - + kernel: Tensor = make_kernel(kernel) if upsample_factor > 1: kernel = kernel * (upsample_factor**2) - self.register_buffer("kernel", kernel) + self.pad: tuple[int] = _pad - self.pad = _pad - - def forward(self, _input): + def forward(self, _input: Tensor) -> Tensor: return upfirdn2d(_input, self.kernel, _pad=self.pad) class ScaledLeakyReLU(nn.Module): - def __init__(self, negative_slope=0.2): + def __init__(self, negative_slope: float = 0.2) -> None: super().__init__() + self.negative_slope: float = negative_slope - self.negative_slope = negative_slope - - def forward(self, _input): + def forward(self, _input: Tensor) -> Tensor: return leaky_relu(_input, negative_slope=self.negative_slope) @@ -120,23 +134,22 @@ def __init__( in_channel, out_channel, kernel_size, - stride=1, - padding=0, - bias=True, - ): + stride: int = 1, + padding: int = 0, + bias: bool = True, + ) -> None: super().__init__() - self.weight = nn.Parameter( torch.randn(out_channel, in_channel, kernel_size, kernel_size), ) - self.scale = 1 / math.sqrt(in_channel * kernel_size**2) - - self.stride = stride - self.padding = padding - - self.bias = nn.Parameter(torch.zeros(out_channel)) if bias else None + self.scale: float = 1 / math.sqrt(in_channel * kernel_size**2) + 
self.stride: int = stride + self.padding: int = padding + self.bias: nn.Parameter | None = ( + nn.Parameter(torch.zeros(out_channel)) if bias else None + ) - def forward(self, _input): + def forward(self, _input: Tensor) -> Tensor: return conv2d( _input, self.weight * self.scale, @@ -152,23 +165,21 @@ def __init__( in_channel, out_channel, kernel_size, - downsample=False, - blur_kernel: list = None, - bias=True, - activate=True, - ): + downsample: bool = False, + blur_kernel: list | None = None, + bias: bool = True, + activate: bool = True, + ) -> None: if blur_kernel is None: blur_kernel = [1, 3, 3, 1] - layers = [] + layers: list[nn.Module] = [] if downsample: factor = 2 p = (len(blur_kernel) - factor) + (kernel_size - 1) pad0 = (p + 1) // 2 pad1 = p // 2 - layers.append(Blur(blur_kernel, _pad=(pad0, pad1))) - stride = 2 self.padding = 0 diff --git a/src/visualizr/anitalker/networks/encoder.py b/src/visualizr/anitalker/networks/encoder.py index 24cf9385..68c9b723 100644 --- a/src/visualizr/anitalker/networks/encoder.py +++ b/src/visualizr/anitalker/networks/encoder.py @@ -7,37 +7,47 @@ from visualizr.app.logger import logger -def fused_leaky_relu(_input, bias, negative_slope=0.2, scale=2**0.5): +def fused_leaky_relu( + _input: Tensor, + bias: nn.Parameter, + negative_slope: float = 0.2, + scale: float = 2**0.5, +) -> Tensor: return leaky_relu(_input + bias, negative_slope) * scale class FusedLeakyReLU(nn.Module): - def __init__(self, channel, negative_slope=0.2, scale=2**0.5): + def __init__( + self, + channel, + negative_slope: float = 0.2, + scale: float = 2**0.5, + ) -> None: super().__init__() self.bias = nn.Parameter(zeros(1, channel, 1, 1)) - self.negative_slope = negative_slope - self.scale = scale + self.negative_slope: float = negative_slope + self.scale: float = scale - def forward(self, _input): + def forward(self, _input: Tensor) -> Tensor: return fused_leaky_relu(_input, self.bias, self.negative_slope, self.scale) def upfirdn2d_native( - _input, - kernel, - up_x, - up_y, - down_x, - down_y, - pad_x0, - pad_x1, - pad_y0, - pad_y1, -): + _input: Tensor, + kernel: Tensor, + up_x: int, + up_y: int, + down_x: int, + down_y: int, + pad_x0: int, + pad_x1: int, + pad_y0: int, + pad_y1: int, +) -> Tensor: _, minor, in_h, in_w = _input.shape kernel_h, kernel_w = kernel.shape - out = _input.view(-1, minor, in_h, 1, in_w, 1) + out: Tensor = _input.view(-1, minor, in_h, 1, in_w, 1) out = pad(out, [0, up_x - 1, 0, 0, 0, up_y - 1, 0, 0]) out = out.view(-1, minor, in_h * up_y, in_w * up_x) @@ -52,7 +62,7 @@ def upfirdn2d_native( out = out.reshape( [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1], ) - w = flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) + w: Tensor = flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) out = conv2d(out, w) out = out.reshape( -1, @@ -64,7 +74,13 @@ def upfirdn2d_native( return out[:, :, ::down_y, ::down_x] -def upfirdn2d(_input, kernel, up=1, down=1, _pad=(0, 0)): +def upfirdn2d( + _input: Tensor, + kernel: Tensor, + up: int = 1, + down: int = 1, + _pad: tuple[int] = (0, 0), +) -> Tensor: return upfirdn2d_native( _input, kernel, @@ -79,41 +95,38 @@ def upfirdn2d(_input, kernel, up=1, down=1, _pad=(0, 0)): ) -def make_kernel(k): - k = tensor(k, dtype=float32) - +def make_kernel(k: Tensor) -> Tensor: + k: Tensor = tensor(k, dtype=float32) if k.ndim == 1: k = k[None, :] * k[:, None] - k /= k.sum() - return k class Blur(nn.Module): - def __init__(self, kernel, _pad, upsample_factor=1): + def __init__( + self, + kernel: Tensor, + _pad: 
tuple[int], + upsample_factor: int = 1, + ) -> None: super().__init__() - - kernel = make_kernel(kernel) - + kernel: Tensor = make_kernel(kernel) if upsample_factor > 1: kernel = kernel * (upsample_factor**2) - self.register_buffer("kernel", kernel) + self.pad: tuple[int] = _pad - self.pad = _pad - - def forward(self, _input): + def forward(self, _input: Tensor) -> Tensor: return upfirdn2d(_input, self.kernel, _pad=self.pad) class ScaledLeakyReLU(nn.Module): - def __init__(self, negative_slope=0.2): + def __init__(self, negative_slope: float = 0.2) -> None: super().__init__() + self.negative_slope: float = negative_slope - self.negative_slope = negative_slope - - def forward(self, _input): + def forward(self, _input: Tensor) -> Tensor: return leaky_relu(_input, negative_slope=self.negative_slope) @@ -123,23 +136,22 @@ def __init__( in_channel, out_channel, kernel_size, - stride=1, - padding=0, - bias=True, - ): + stride: int = 1, + padding: int = 0, + bias: bool = True, + ) -> None: super().__init__() - self.weight = nn.Parameter( randn(out_channel, in_channel, kernel_size, kernel_size), ) - self.scale = 1 / sqrt(in_channel * kernel_size**2) - - self.stride = stride - self.padding = padding - - self.bias = nn.Parameter(zeros(out_channel)) if bias else None + self.scale: float = 1 / sqrt(in_channel * kernel_size**2) + self.stride: int = stride + self.padding: int = padding + self.bias: nn.Parameter | None = ( + nn.Parameter(zeros(out_channel)) if bias else None + ) - def forward(self, _input): + def forward(self, _input: Tensor) -> Tensor: return conv2d( _input, self.weight * self.scale, @@ -148,7 +160,7 @@ def forward(self, _input): padding=self.padding, ) - def __repr__(self): + def __repr__(self) -> str: return ( f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]}," f" {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})" @@ -160,30 +172,29 @@ def __init__( self, in_dim, out_dim, - bias=True, - bias_init=0, - lr_mul=1, + bias: bool = True, + bias_init: int = 0, + lr_mul: int = 1, activation=None, - ): + ) -> None: super().__init__() - self.weight = nn.Parameter(randn(out_dim, in_dim).div_(lr_mul)) - self.bias = nn.Parameter(zeros(out_dim).fill_(bias_init)) if bias else None + self.bias: nn.Parameter | None = ( + nn.Parameter(zeros(out_dim).fill_(bias_init)) if bias else None + ) self.activation = activation + self.scale: float = (1 / sqrt(in_dim)) * lr_mul + self.lr_mul: int = lr_mul - self.scale = (1 / sqrt(in_dim)) * lr_mul - self.lr_mul = lr_mul - - def forward(self, _input): + def forward(self, _input: Tensor) -> Tensor: if self.activation: out = linear(_input, self.weight * self.scale) out = fused_leaky_relu(out, self.bias * self.lr_mul) else: out = linear(_input, self.weight * self.scale, bias=self.bias * self.lr_mul) - return out - def __repr__(self): + def __repr__(self) -> str: return ( f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})" ) @@ -195,26 +206,23 @@ def __init__( in_channel, out_channel, kernel_size, - downsample=False, - blur_kernel: list = None, - bias=True, - activate=True, - ): + downsample: bool = False, + blur_kernel: list | None = None, + bias: bool = True, + activate: bool = True, + ) -> None: if blur_kernel is None: blur_kernel = [1, 3, 3, 1] - layers = [] + layers: list[nn.Module] = [] if downsample: factor = 2 p = (len(blur_kernel) - factor) + (kernel_size - 1) pad0 = (p + 1) // 2 pad1 = p // 2 - layers.append(Blur(blur_kernel, _pad=(pad0, pad1))) - stride = 2 self.padding = 0 - else: stride 
= 1 self.padding = kernel_size // 2 @@ -240,7 +248,7 @@ def __init__( class ResBlock(nn.Module): - def __init__(self, in_channel, out_channel): + def __init__(self, in_channel, out_channel) -> None: super().__init__() self.conv1 = ConvLayer(in_channel, in_channel, 3) @@ -264,12 +272,12 @@ def forward(self, _input): class WeightedSumLayer(nn.Module): - def __init__(self, num_tensors=8): + def __init__(self, num_tensors: int = 8) -> None: super().__init__() self.weights = nn.Parameter(randn(num_tensors)) - def forward(self, tensor_list): - weights = softmax(self.weights, dim=0) + def forward(self, tensor_list: list[Tensor]) -> Tensor: + weights: Tensor = softmax(self.weights, dim=0) weighted_sum: Tensor = zeros_like(tensor_list[0]) for _tensor, weight in zip(tensor_list, weights, strict=False): weighted_sum += _tensor * weight @@ -277,10 +285,10 @@ def forward(self, tensor_list): class EncoderApp(nn.Module): - def __init__(self, size, w_dim=512, fusion_type=""): + def __init__(self, size, w_dim: int = 512, fusion_type: str = "") -> None: super().__init__() - channels = { + channels: dict[int, int] = { 4: 512, 8: 512, 16: 512, @@ -292,27 +300,28 @@ def __init__(self, size, w_dim=512, fusion_type=""): 1024: 16, } - self.w_dim = w_dim + self.w_dim: int = w_dim log_size = int(log(size, 2)) self.convs = nn.ModuleList() self.convs.append(ConvLayer(3, channels[size], 1)) - in_channel = channels[size] + in_channel: int = channels[size] for i in range(log_size, 2, -1): - out_channel = channels[2 ** (i - 1)] + out_channel: int = channels[2 ** (i - 1)] self.convs.append(ResBlock(in_channel, out_channel)) in_channel = out_channel self.convs.append(EqualConv2d(in_channel, self.w_dim, 4, bias=False)) - self.fusion_type = fusion_type + self.fusion_type: str = fusion_type if self.fusion_type != "weighted_sum": - raise ValueError( + msg: str = ( f"Unsupported `fusion_type`: {self.fusion_type}. " - "Expected 'weighted_sum'.", + "Expected 'weighted_sum'." ) + raise ValueError(msg) _msg = "HAL layer is enabled!" 
logger.info(_msg) Info(_msg) @@ -350,7 +359,7 @@ def forward(self, x): class DecouplingModel(nn.Module): - def __init__(self, input_dim, hidden_dim, output_dim): + def __init__(self, input_dim, hidden_dim, output_dim) -> None: super().__init__() # identity_excluded_net is called identity encoder in the paper @@ -380,17 +389,23 @@ def forward(self, x): class Encoder(nn.Module): - def __init__(self, size, dim=512, dim_motion=20, weighted_sum=False): + def __init__( + self, + size, + dim: int = 512, + dim_motion: int = 20, + weighted_sum: bool = False, + ) -> None: super().__init__() # image encoder self.net_app = EncoderApp(size, dim, weighted_sum) # decouping network - self.net_decouping = DecouplingModel(dim, dim, dim) + self.net_decouping: DecouplingModel = DecouplingModel(dim, dim, dim) # part of the motion encoder - fc = [EqualLinear(dim, dim)] + fc: list[EqualLinear] = [EqualLinear(dim, dim)] fc.extend(EqualLinear(dim, dim) for _ in range(3)) fc.append(EqualLinear(dim, dim_motion)) self.fc = nn.Sequential(*fc) diff --git a/src/visualizr/anitalker/networks/styledecoder.py b/src/visualizr/anitalker/networks/styledecoder.py index 9b54f418..c84b7cd0 100644 --- a/src/visualizr/anitalker/networks/styledecoder.py +++ b/src/visualizr/anitalker/networks/styledecoder.py @@ -18,7 +18,12 @@ ) -def fused_leaky_relu(_input, bias, negative_slope=0.2, scale=2**0.5): +def fused_leaky_relu( + _input: Tensor, + bias: nn.Parameter, + negative_slope: float = 0.2, + scale: float = 2**0.5, +) -> Tensor: """ Apply fused leaky ReLU activation with bias and scaling. @@ -35,11 +40,16 @@ def fused_leaky_relu(_input, bias, negative_slope=0.2, scale=2**0.5): class FusedLeakyReLU(nn.Module): - def __init__(self, channel, negative_slope=0.2, scale=2**0.5): + def __init__( + self, + channel, + negative_slope: float = 0.2, + scale: float = 2**0.5, + ) -> None: super().__init__() self.bias = nn.Parameter(torch.zeros(1, channel, 1, 1)) - self.negative_slope = negative_slope - self.scale = scale + self.negative_slope: float = negative_slope + self.scale: float = scale def forward(self, _input): """ @@ -55,17 +65,17 @@ def forward(self, _input): def upfirdn2d_native( - _input, - kernel, - up_x, - up_y, - down_x, - down_y, - pad_x0, - pad_x1, - pad_y0, - pad_y1, -): + _input: Tensor, + kernel: Tensor, + up_x: int, + up_y: int, + down_x: int, + down_y: int, + pad_x0: int, + pad_x1: int, + pad_y0: int, + pad_y1: int, +) -> Tensor: """ Perform upsample, FIR filter, and downsample on 2D input tensor. @@ -87,7 +97,7 @@ def upfirdn2d_native( _, minor, in_h, in_w = _input.shape kernel_h, kernel_w = kernel.shape - out = _input.view(-1, minor, in_h, 1, in_w, 1) + out: Tensor = _input.view(-1, minor, in_h, 1, in_w, 1) out = pad(out, [0, up_x - 1, 0, 0, 0, up_y - 1, 0, 0]) out = out.view(-1, minor, in_h * up_y, in_w * up_x) @@ -102,7 +112,7 @@ def upfirdn2d_native( out = out.reshape( [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1], ) - w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) + w: Tensor = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) out = conv2d(out, w) out = out.reshape( -1, @@ -113,7 +123,13 @@ def upfirdn2d_native( return out[:, :, ::down_y, ::down_x] -def upfirdn2d(_input, kernel, up=1, down=1, _pad=(0, 0)): +def upfirdn2d( + _input: Tensor, + kernel: Tensor, + up: int = 1, + down: int = 1, + _pad: tuple[int] = (0, 0), +) -> Tensor: """ Wrapper for upfirdn2d_native with same up/down and symmetric padding. 
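A minimal usage sketch (not part of the patch; the standalone helper name `make_kernel_sketch` and the driver below are illustrative assumptions): it shows how the 1-D `[1, 3, 3, 1]` coefficients passed to `Blur` and `Upsample` become the normalised 2-D FIR kernel that `upfirdn2d` applies.

import torch
from torch import Tensor


def make_kernel_sketch(k: list[int]) -> Tensor:
    kt: Tensor = torch.tensor(k, dtype=torch.float32)
    if kt.ndim == 1:
        kt = kt[None, :] * kt[:, None]  # outer product -> separable 2-D kernel
    kt /= kt.sum()                      # normalise so filtering preserves overall brightness
    return kt


if __name__ == "__main__":
    kernel = make_kernel_sketch([1, 3, 3, 1])
    print(kernel.shape)         # torch.Size([4, 4])
    print(float(kernel.sum()))  # 1.0
    # Blur registers this kernel as a buffer and applies it through upfirdn2d;
    # for upsampling it is additionally scaled by upsample_factor ** 2.

In the actual callers the kernel argument is a plain Python list such as `[1, 3, 3, 1]`, which is why the function starts by wrapping it in `torch.tensor`.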
@@ -141,59 +157,48 @@ def upfirdn2d(_input, kernel, up=1, down=1, _pad=(0, 0)): ) -def make_kernel(k): +def make_kernel(k: Tensor) -> Tensor: # TODO: type hint """ Create a 2D convolution kernel tensor from 1D or 2D input. Args: - k (list or Tensor): 1D or 2D kernel coefficients. + k (Tensor): 1D or 2D kernel coefficients. Returns: Tensor: Normalized 2D kernel tensor. """ - k = torch.tensor(k, dtype=torch.float32) - + k: Tensor = torch.tensor(k, dtype=torch.float32) if k.ndim == 1: k = k[None, :] * k[:, None] - k /= k.sum() - return k class Upsample(nn.Module): - def __init__(self, kernel, factor=2): + def __init__(self, kernel, factor: int = 2) -> None: super().__init__() - - self.factor = factor - kernel = make_kernel(kernel) * (factor**2) + self.factor: int = factor + kernel: Tensor = make_kernel(kernel) * (factor**2) self.register_buffer("kernel", kernel) + p: int = kernel.shape[0] - factor + pad0: int = (p + 1) // 2 + factor - 1 + pad1: int = p // 2 + self.pad: tuple[int] = (pad0, pad1) - p = kernel.shape[0] - factor - - pad0 = (p + 1) // 2 + factor - 1 - pad1 = p // 2 - - self.pad = (pad0, pad1) - - def forward(self, _input): + def forward(self, _input: Tensor) -> Tensor: return upfirdn2d(_input, self.kernel, up=self.factor, _pad=self.pad) class Blur(nn.Module): - def __init__(self, kernel, _pad, upsample_factor=1): + def __init__(self, kernel, _pad, upsample_factor: int = 1) -> None: super().__init__() - - kernel = make_kernel(kernel) - + kernel: Tensor = make_kernel(kernel) if upsample_factor > 1: kernel = kernel * (upsample_factor**2) - self.register_buffer("kernel", kernel) - self.pad = _pad - def forward(self, _input): + def forward(self, _input: Tensor) -> Tensor: return upfirdn2d(_input, self.kernel, _pad=self.pad) @@ -203,23 +208,24 @@ def __init__( in_channel, out_channel, kernel_size, - stride=1, - padding=0, - bias=True, - ): + stride: int = 1, + padding: int = 0, + bias: bool = True, + ) -> None: super().__init__() self.weight = nn.Parameter( torch.randn(out_channel, in_channel, kernel_size, kernel_size), ) - self.scale = 1 / math.sqrt(in_channel * kernel_size**2) - - self.stride = stride - self.padding = padding + self.scale: float = 1 / math.sqrt(in_channel * kernel_size**2) + self.stride: int = stride + self.padding: int = padding - self.bias = nn.Parameter(torch.zeros(out_channel)) if bias else None + self.bias: nn.Parameter | None = ( + nn.Parameter(torch.zeros(out_channel)) if bias else None + ) - def forward(self, _input): + def forward(self, _input: Tensor) -> Tensor: """ Apply equalized learning rate 2D convolution to the input tensor. @@ -249,11 +255,11 @@ def __init__( self, in_dim, out_dim, - bias=True, - bias_init=0, - lr_mul=1, + bias: bool = True, + bias_init: int = 0, + lr_mul: int = 1, activation=None, - ): + ) -> None: """ Initialize the EqualLinear layer. @@ -276,10 +282,10 @@ def __init__( self.activation = activation - self.scale = (1 / math.sqrt(in_dim)) * lr_mul - self.lr_mul = lr_mul + self.scale: float = (1 / math.sqrt(in_dim)) * lr_mul + self.lr_mul: int = lr_mul - def forward(self, _input): + def forward(self, _input: Tensor) -> Tensor: """ Perform the forward pass of the EqualLinear layer. @@ -296,7 +302,7 @@ def forward(self, _input): out = linear(_input, self.weight * self.scale, self.bias * self.lr_mul) return out - def __repr__(self): + def __repr__(self) -> str: """ Return a string representation of the EqualLinear layer. 
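A minimal sketch of the equalized-learning-rate trick behind `EqualLinear` (the standalone class name and driver below are assumptions for illustration): the weight is stored divided by `lr_mul` and rescaled by `(1 / sqrt(in_dim)) * lr_mul` at every forward pass, so activations start near unit variance regardless of `in_dim` or `lr_mul`.

import math

import torch
from torch import Tensor, nn
from torch.nn.functional import linear


class EqualLinearSketch(nn.Module):
    def __init__(self, in_dim: int, out_dim: int, lr_mul: float = 1.0) -> None:
        super().__init__()
        # Stored at 1/lr_mul scale; the runtime rescale restores the effective magnitude.
        self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
        self.bias = nn.Parameter(torch.zeros(out_dim))
        self.scale: float = (1 / math.sqrt(in_dim)) * lr_mul
        self.lr_mul: float = lr_mul

    def forward(self, x: Tensor) -> Tensor:
        # Gradients flow to the stored (rescaled) weight, which keeps the
        # effective per-layer learning rate balanced across layer widths.
        return linear(x, self.weight * self.scale, bias=self.bias * self.lr_mul)


if __name__ == "__main__":
    layer = EqualLinearSketch(in_dim=512, out_dim=512, lr_mul=0.01)
    y = layer(torch.randn(8, 512))
    print(float(y.std()))  # ~1.0 at init, independent of in_dim and lr_mul

Because `lr_mul` is commonly fractional in practice, the sketch types it as `float` rather than `int`.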
@@ -309,7 +315,7 @@ def __repr__(self): class ScaledLeakyReLU(nn.Module): - def __init__(self, negative_slope=0.2): + def __init__(self, negative_slope: float = 0.2) -> None: """ Initialize the ScaledLeakyReLU activation module. @@ -319,9 +325,9 @@ def __init__(self, negative_slope=0.2): """ super().__init__() - self.negative_slope = negative_slope + self.negative_slope: float = negative_slope - def forward(self, _input): + def forward(self, _input: Tensor) -> Tensor: """ Apply the scaled LeakyReLU activation to the input tensor. @@ -341,21 +347,21 @@ def __init__( out_channel, kernel_size, style_dim, - demodulate=True, - upsample=False, - downsample=False, - blur_kernel: list = None, - ): + demodulate: bool = True, + upsample: bool = False, + downsample: bool = False, + blur_kernel: list | None = None, + ) -> None: if blur_kernel is None: blur_kernel = [1, 3, 3, 1] super().__init__() - self.eps = 1e-8 + self.eps: float = 1e-8 self.kernel_size = kernel_size self.in_channel = in_channel self.out_channel = out_channel - self.upsample = upsample - self.downsample = downsample + self.upsample: bool = upsample + self.downsample: bool = downsample if upsample: factor = 2 @@ -374,7 +380,7 @@ def __init__( self.blur = Blur(blur_kernel, _pad=(pad0, pad1)) fan_in = in_channel * kernel_size**2 - self.scale = 1 / math.sqrt(fan_in) + self.scale: float = 1 / math.sqrt(fan_in) self.padding = kernel_size // 2 self.weight = nn.Parameter( @@ -382,16 +388,15 @@ def __init__( ) self.modulation = EqualLinear(style_dim, in_channel, bias_init=1) - self.demodulate = demodulate + self.demodulate: bool = demodulate - def forward(self, _input, style): + def forward(self, _input: Tensor, style: Tensor) -> Tensor: batch, in_channel, height, width = _input.shape - style = self.modulation(style).view(batch, 1, in_channel, 1, 1) weight: Tensor = self.scale * self.weight * style if self.demodulate: - demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8) + demod: Tensor = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8) weight = weight * demod.view(batch, self.out_channel, 1, 1, 1) weight = weight.view( @@ -432,14 +437,12 @@ def forward(self, _input, style): out = conv2d(_input, weight, padding=self.padding, groups=batch) _, _, height, width = out.shape out = out.view(batch, self.out_channel, height, width) - return out class NoiseInjection(nn.Module): - def __init__(self): + def __init__(self) -> None: super().__init__() - self.weight = nn.Parameter(torch.zeros(1)) def forward(self, image, noise=None): @@ -447,12 +450,11 @@ def forward(self, image, noise=None): class ConstantInput(nn.Module): - def __init__(self, channel, size=4): + def __init__(self, channel, size: int = 4) -> None: super().__init__() - self.input = nn.Parameter(torch.randn(1, channel, size, size)) - def forward(self, _input): + def forward(self, _input: Tensor) -> Tensor: batch = _input.shape[0] return self.input.repeat(batch, 1, 1, 1) @@ -464,10 +466,10 @@ def __init__( out_channel, kernel_size, style_dim, - upsample=False, - blur_kernel: list = None, - demodulate=True, - ): + upsample: bool = False, + blur_kernel: list | None = None, + demodulate: bool = True, + ) -> None: if blur_kernel is None: blur_kernel = [1, 3, 3, 1] super().__init__() @@ -485,7 +487,7 @@ def __init__( self.noise = NoiseInjection() self.activate = FusedLeakyReLU(out_channel) - def forward(self, _input, style, noise=None): + def forward(self, _input, style, noise=None): # TODO: type hint out = self.conv(_input, style) out = self.noise(out, noise=noise) out = self.activate(out) 
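A minimal sketch of the modulate/demodulate step in `ModulatedConv2d.forward` (shapes and the grouped-convolution driver are illustrative assumptions): the per-sample style rescales the input channels of a shared weight, demodulation renormalises each per-sample output-channel kernel to roughly unit L2 norm, and a single grouped `conv2d` then applies all batch kernels at once.

import torch
from torch import Tensor
from torch.nn.functional import conv2d

batch, in_ch, out_ch, k = 2, 8, 16, 3
weight: Tensor = torch.randn(1, out_ch, in_ch, k, k)  # shared convolution weight
style: Tensor = torch.randn(batch, 1, in_ch, 1, 1)    # per-sample modulation (from EqualLinear)

modulated: Tensor = weight * style                     # scale input channels per sample
demod: Tensor = torch.rsqrt(modulated.pow(2).sum([2, 3, 4]) + 1e-8)
demodulated: Tensor = modulated * demod.view(batch, out_ch, 1, 1, 1)

# Each (sample, output-channel) kernel now has ~unit L2 norm, which stands in
# for an explicit normalisation of the activations.
norms: Tensor = demodulated.pow(2).sum([2, 3, 4]).sqrt()
print(torch.allclose(norms, torch.ones_like(norms), atol=1e-4))  # True

# All per-sample kernels are applied in one call via a grouped convolution.
x = torch.randn(batch, in_ch, 32, 32)
w = demodulated.view(batch * out_ch, in_ch, k, k)
out = conv2d(x.view(1, batch * in_ch, 32, 32), w, padding=k // 2, groups=batch)
print(out.view(batch, out_ch, 32, 32).shape)           # torch.Size([2, 16, 32, 32])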
@@ -498,26 +500,23 @@ def __init__( in_channel, out_channel, kernel_size, - downsample=False, - blur_kernel: list = None, - bias=True, - activate=True, - ): + downsample: bool = False, + blur_kernel: list | None = None, + bias: bool = True, + activate: bool = True, + ) -> None: if blur_kernel is None: blur_kernel = [1, 3, 3, 1] - layers = [] + layers: list[nn.Module] = [] if downsample: factor = 2 p = (len(blur_kernel) - factor) + (kernel_size - 1) pad0 = (p + 1) // 2 pad1 = p // 2 - layers.append(Blur(blur_kernel, _pad=(pad0, pad1))) - stride = 2 self.padding = 0 - else: stride = 1 self.padding = kernel_size // 2 @@ -543,7 +542,12 @@ def __init__( class ToRGB(nn.Module): - def __init__(self, in_channel, upsample=True, blur_kernel: list = None): + def __init__( + self, + in_channel, + upsample: bool = True, + blur_kernel: list | None = None, + ) -> None: if blur_kernel is None: blur_kernel = [1, 3, 3, 1] super().__init__() @@ -564,7 +568,13 @@ def forward(self, _input, skip=None): class ToFlow(nn.Module): - def __init__(self, in_channel, style_dim, upsample=True, blur_kernel: list = None): + def __init__( + self, + in_channel, + style_dim, + upsample: bool = True, + blur_kernel: list | None = None, + ) -> None: if blur_kernel is None: blur_kernel = [1, 3, 3, 1] super().__init__() @@ -583,7 +593,6 @@ def forward(self, _input, style, feat, skip=None): # warping xs = np.linspace(-1, 1, _input.size(2)) - xs = np.meshgrid(xs, xs) xs = np.stack(xs, 2) @@ -598,28 +607,27 @@ def forward(self, _input, style, feat, skip=None): skip = self.upsample(skip) out = out + skip - sampler = torch.tanh(out[:, 0:2, :, :]) - mask = torch.sigmoid(out[:, 2:3, :, :]) - flow = sampler.permute(0, 2, 3, 1) + xs - feat_warp = grid_sample(feat, flow) * mask + sampler: Tensor = torch.tanh(out[:, 0:2, :, :]) + mask: Tensor = torch.sigmoid(out[:, 2:3, :, :]) + flow: Tensor = sampler.permute(0, 2, 3, 1) + xs + feat_warp: Tensor = grid_sample(feat, flow) * mask return feat_warp, feat_warp + _input * (1.0 - mask), out class Direction(nn.Module): - def __init__(self, motion_dim): + def __init__(self, motion_dim) -> None: super().__init__() - self.weight = nn.Parameter(torch.randn(512, motion_dim)) - def forward(self, _input): + def forward(self, _input: Tensor) -> Tensor: # input: (bs*t) x 512 - weight = self.weight + 1e-8 + weight: Tensor = self.weight + 1e-8 # get eigenvector, orthogonal [n1, n2, n3, n4] q, _ = torch.linalg.qr(weight) if _input is None: return q - input_diag = torch.diag_embed(_input) # alpha, diagonal matrix - out = torch.matmul(input_diag, q.T) + input_diag: Tensor = torch.diag_embed(_input) # alpha, diagonal matrix + out: Tensor = torch.matmul(input_diag, q.T) out = torch.sum(out, dim=1) return out @@ -630,9 +638,9 @@ def __init__( size, style_dim, motion_dim, - blur_kernel: list = None, - channel_multiplier=1, - ): + blur_kernel: list | None = None, + channel_multiplier: int = 1, + ) -> None: if blur_kernel is None: blur_kernel = [1, 3, 3, 1] super().__init__() @@ -641,9 +649,9 @@ def __init__( self.style_dim = style_dim self.motion_dim = motion_dim # Linear Motion Decomposition (LMD) from LIA - self.direction = Direction(motion_dim) + self.direction: Direction = Direction(motion_dim) - self.channels = { + self.channels: dict[int, int] = { 4: 512, 8: 512, 16: 512, @@ -665,16 +673,16 @@ def __init__( ) self.log_size = int(math.log(size, 2)) - self.num_layers = (self.log_size - 2) * 2 + 1 + self.num_layers: int = (self.log_size - 2) * 2 + 1 self.convs = nn.ModuleList() self.to_rgbs = nn.ModuleList() 
self.to_flows = nn.ModuleList() - in_channel = self.channels[4] + in_channel: int = self.channels[4] for i in range(3, self.log_size + 1): - out_channel = self.channels[2**i] + out_channel: int = self.channels[2**i] self.convs.append( StyledConv( @@ -696,12 +704,10 @@ def __init__( ), ) self.to_rgbs.append(ToRGB(out_channel)) - self.to_flows.append(ToFlow(out_channel, style_dim)) - in_channel = out_channel - self.n_latent = self.log_size * 2 - 2 + self.n_latent: int = self.log_size * 2 - 2 def forward(self, source_before_decoupling, target_motion, feats): skip_flow = None @@ -709,7 +715,7 @@ def forward(self, source_before_decoupling, target_motion, feats): directions = self.direction(target_motion) latent = source_before_decoupling + directions # wa + directions - inject_index = self.n_latent + inject_index: int = self.n_latent latent = latent.unsqueeze(1).repeat(1, inject_index, 1) out = self.input(latent) diff --git a/src/visualizr/anitalker/renderer.py b/src/visualizr/anitalker/renderer.py index 8648d10f..70fb5c9b 100644 --- a/src/visualizr/anitalker/renderer.py +++ b/src/visualizr/anitalker/renderer.py @@ -44,7 +44,7 @@ def render_condition( ValueError: If the model type is not autoencoder-capable. """ if conf.train_mode != TrainMode.diffusion: - raise NotImplementedError() + raise NotImplementedError if not conf.model_type.has_autoenc(): msg: str = ( "TrainMode.diffusion requires an " diff --git a/src/visualizr/anitalker/templates.py b/src/visualizr/anitalker/templates.py index 60418a6d..f1d7f4d1 100644 --- a/src/visualizr/anitalker/templates.py +++ b/src/visualizr/anitalker/templates.py @@ -2,7 +2,7 @@ from visualizr.anitalker.config import TrainConfig -def autoenc_base(): +def autoenc_base() -> TrainConfig: """Return base configuration for all Diff-AE models.""" conf = TrainConfig() conf.batch_size = 32 @@ -27,8 +27,8 @@ def autoenc_base(): return conf -def ffhq128_autoenc_base(): - conf = autoenc_base() +def ffhq128_autoenc_base() -> TrainConfig: + conf: TrainConfig = autoenc_base() conf.data_name = "ffhqlmdb256" conf.scale_up_gpus(4) conf.img_size = 128 @@ -41,8 +41,8 @@ def ffhq128_autoenc_base(): return conf -def ffhq256_autoenc(): - conf = ffhq128_autoenc_base() +def ffhq256_autoenc() -> TrainConfig: + conf: TrainConfig = ffhq128_autoenc_base() conf.img_size = 256 conf.net_ch = 128 conf.net_ch_mult = (1, 1, 2, 2, 4, 4) diff --git a/src/visualizr/anitalker/utils.py b/src/visualizr/anitalker/utils.py index 1bdd7998..e11d84f9 100644 --- a/src/visualizr/anitalker/utils.py +++ b/src/visualizr/anitalker/utils.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Literal +from typing import TYPE_CHECKING from gradio import Error, Info from imageio import mimsave @@ -10,7 +10,7 @@ concatenate_videoclips, ) from numpy import asarray, ndarray, transpose -from PIL import Image +from PIL.Image import Image, open as pil_open from torch import Tensor, from_numpy, load as torch_load from torchvision.transforms import ToPILImage @@ -19,6 +19,11 @@ from visualizr.anitalker.face_sr.face_enhancer import enhancer_list from visualizr.anitalker.templates import ffhq256_autoenc from visualizr.app.logger import logger +from visualizr.app.types import InferenceType + +if TYPE_CHECKING: + from moviepy.video.compositing.CompositeVideoClip import CompositeVideoClip + from moviepy.video.VideoClip import VideoClip def frames_to_video( @@ -40,13 +45,13 @@ def frames_to_video( FileNotFoundError: If no frames are found. OSError: On I/O errors while reading/writing media. 
""" - clips = [ + clips: list[ImageClip] = [ ImageClip(m.as_posix()).set_duration(1 / fps) for m in sorted(input_path.iterdir()) ] - video = concatenate_videoclips(clips, "compose") - audio = AudioFileClip(audio_path) - final_video = video.set_audio(audio) + video: VideoClip | CompositeVideoClip = concatenate_videoclips(clips, "compose") + audio: AudioFileClip[Path] = AudioFileClip(audio_path) + final_video: VideoClip = video.set_audio(audio) final_video.write_videofile( output_path.as_posix(), fps, @@ -56,8 +61,8 @@ def frames_to_video( def load_image(img_path: Path, size: int) -> ndarray: - img: Image.Image = Image.open(img_path).convert("RGB") - img_resized: Image.Image = img.resize((size, size)) + img: Image = pil_open(img_path).convert("RGB") + img_resized: Image = img.resize((size, size)) img_np: ndarray = asarray(img_resized) img_transposed: ndarray = transpose(img_np, (2, 0, 1)) # 3 x 256 x 256 return img_transposed / 255.0 @@ -72,13 +77,15 @@ def img_preprocessing(img_path: Path, size: int) -> Tensor: def saved_image(img_tensor: Tensor, img_path: Path) -> None: pil_image_converter: ToPILImage = ToPILImage() - img = pil_image_converter(img_tensor.detach().cpu().squeeze(0)) + img: Image = pil_image_converter(img_tensor.detach().cpu().squeeze(0)) img.save(img_path) def remove_frames(frames_path: Path) -> None: try: - _msg = f"Deleting {len(list(frames_path.iterdir()))} frames at {frames_path}" + _msg: str = ( + f"Deleting {len(list(frames_path.iterdir()))} frames at {frames_path}" + ) logger.info(_msg) Info(_msg) for frame in frames_path.iterdir(): @@ -87,23 +94,23 @@ def remove_frames(frames_path: Path) -> None: logger.info(_msg) Info(_msg) except OSError as e: - _msg = f"Failed to delete frames: {e}" + _msg: str = f"Failed to delete frames: {e}" logger.exception(_msg) Error(_msg) def load_stage_2_model(conf: TrainConfig, stage_2_checkpoint_path: Path) -> LitModel: - _msg = f"Stage 2 checkpoint path: {stage_2_checkpoint_path}" + _msg: str = f"Stage 2 checkpoint path: {stage_2_checkpoint_path}" logger.info(_msg) Info(_msg) if not stage_2_checkpoint_path.exists(): - msg = f"Checkpoint not found: {stage_2_checkpoint_path}" + msg: str = f"Checkpoint not found: {stage_2_checkpoint_path}" raise FileNotFoundError(msg) model: LitModel = LitModel(conf) try: state = torch_load(stage_2_checkpoint_path, map_location="cpu") except Exception as e: - msg = f"Failed to load checkpoint: {e}" + msg: str = f"Failed to load checkpoint: {e}" raise RuntimeError(msg) from e model.load_state_dict(state) model.ema_model.eval() @@ -124,13 +131,7 @@ def _init_configuration_param( def init_configuration( - infer_type: Literal[ - "mfcc_full_control", - "mfcc_pose_only", - "hubert_pose_only", - "hubert_audio_only", - "hubert_full_control", - ], + infer_type: InferenceType, seed: int, decoder_layers: int, motion_dim: int, @@ -143,7 +144,7 @@ def init_configuration( conf.decoder_layers = decoder_layers conf.motion_dim = motion_dim conf.infer_type = infer_type - _msg = f"infer_type: {infer_type}" + _msg: str = f"infer_type: {infer_type}" logger.info(_msg) Info(_msg) match infer_type: @@ -163,8 +164,8 @@ def super_resolution( tmp_predicted_video_512_path: Path, predicted_video_256_path: Path, predicted_video_512_path: Path, -): - _msg = f"Saving video at {tmp_predicted_video_512_path}" +) -> None: + _msg: str = f"Saving video at {tmp_predicted_video_512_path}" logger.info(_msg) Info(_msg) mimsave( @@ -176,9 +177,13 @@ def super_resolution( fps=25.0, ) # Merge audio and video - video_clip = 
VideoFileClip(tmp_predicted_video_512_path.as_posix()) - audio_clip = AudioFileClip(predicted_video_256_path.as_posix()) - final_clip = video_clip.set_audio(audio_clip) + video_clip: VideoFileClip[str] = VideoFileClip( + tmp_predicted_video_512_path.as_posix(), + ) + audio_clip: AudioFileClip[str] = AudioFileClip( + predicted_video_256_path.as_posix(), + ) + final_clip: VideoClip = video_clip.set_audio(audio_clip) final_clip.write_videofile( predicted_video_512_path.as_posix(), codec="libx264", diff --git a/src/visualizr/app/builder.py b/src/visualizr/app/builder.py index 744b3bcb..ff3d5ee9 100644 --- a/src/visualizr/app/builder.py +++ b/src/visualizr/app/builder.py @@ -108,6 +108,9 @@ def generate_video( self.settings.model.motion_dim, ) + if not isinstance(image_path, Path): + image_path = Path(image_path) + img_source: Tensor = img_preprocessing(image_path, 256).to("cuda") one_shot_lia_start, one_shot_lia_direction, feats = ( lia.get_start_direction_code( @@ -275,6 +278,9 @@ def generate_video( logger.info(_msg) Info(_msg) + if not isinstance(audio_path, Path): + audio_path = Path(audio_path) + frames_to_video( self.settings.directory.frames, audio_path, @@ -344,10 +350,12 @@ def generate_video_from_name( _msg = f"Character '{name}' not found." logger.error(_msg) raise Error(_msg) - if isinstance(audio_file, str) and audio_file.startswith(( - "http://", - "https://", - )): + if isinstance(audio_file, str) and audio_file.startswith( + ( + "http://", + "https://", + ), + ): audio_file = self._download_audio(URL(audio_file)) if not isinstance(audio_file, Path): audio_file = Path(audio_file) @@ -400,12 +408,12 @@ def generate_video_mcp( seed (int): Random seed for reproducibility. Returns: - Path: A path to the generated video file. + str: A path to the generated video file. """ return self.generate_video_from_name( name, infer_type, - audio_file.as_posix(), + audio_file, pose_yaw, pose_pitch, pose_roll, diff --git a/src/visualizr/app/settings.py b/src/visualizr/app/settings.py index 8cca0d4b..937aed2f 100644 --- a/src/visualizr/app/settings.py +++ b/src/visualizr/app/settings.py @@ -93,8 +93,8 @@ class ModelSettings(BaseModel): step_t: int = 50 seed: int = 0 motion_dim: int = 20 - image_path: FilePath = Field(default=None) - audio_path: FilePath = Field(default=None) + image_path: FilePath | None = Field(default=None) + audio_path: FilePath | None = Field(default=None) control_flag: bool = True pose_driven_path: str = "not_supported_in_this_mode" image_size: int = 256 @@ -107,7 +107,18 @@ class ModelSettings(BaseModel): checkpoint: Checkpoint = Checkpoint() @model_validator(mode="after") - def check_image_path(self) -> "ModelSettings": + def check_missing_paths(self) -> "ModelSettings": + """ + Validate that the image and audio paths exist if provided. + + Checks the existence of the image_path and audio_path attributes. + + Returns: + ModelSettings: The validated ModelSettings instance. + + Raises: + FileNotFoundError: If the image_path or audio_path does not exist. + """ if self.image_path and not self.image_path.exists(): _msg = f"Image path does not exist: {self.image_path}" logger.error(_msg)