diff --git a/vectorbt/indicators/__init__.py b/vectorbt/indicators/__init__.py index 9f734b14..bb618f7b 100644 --- a/vectorbt/indicators/__init__.py +++ b/vectorbt/indicators/__init__.py @@ -24,6 +24,9 @@ def talib(*args, **kwargs) -> tp.Type[IndicatorBase]: """Shortcut for `vectorbt.indicators.factory.IndicatorFactory.from_talib`.""" return IndicatorFactory.from_talib(*args, **kwargs) +def mtalib(*args, **kwargs) -> tp.Type[IndicatorBase]: + """Shortcut for `vectorbt.indicators.factory.IndicatorFactory.from_talib`.""" + return IndicatorFactory.from_mtalib(*args, **kwargs) def pandas_ta(*args, **kwargs) -> tp.Type[IndicatorBase]: """Shortcut for `vectorbt.indicators.factory.IndicatorFactory.from_pandas_ta`.""" @@ -38,6 +41,7 @@ def ta(*args, **kwargs) -> tp.Type[IndicatorBase]: __all__ = [ 'IndicatorFactory', 'talib', + 'mtalib', 'pandas_ta', 'ta', 'MA', @@ -50,6 +54,7 @@ def ta(*args, **kwargs) -> tp.Type[IndicatorBase]: 'OBV' ] __whitelist__ = [ + 'mtalib', 'talib', 'pandas_ta', 'ta' diff --git a/vectorbt/indicators/factory.py b/vectorbt/indicators/factory.py index 5948eef0..75451425 100644 --- a/vectorbt/indicators/factory.py +++ b/vectorbt/indicators/factory.py @@ -3490,6 +3490,136 @@ def apply_func(input_list: tp.List[tp.AnyArray], ) return TALibIndicator + + @classmethod + def from_mtalib(cls, func_name: str, timescale:int, init_kwargs: tp.KwargsLike = None, **kwargs) -> tp.Type[IndicatorBase]: + """Build an indicator class around a TA-Lib function. + + Requires [TA-Lib](https://github.com/mrjbq7/ta-lib) installed. + + For input, parameter and output names, see [docs](https://github.com/mrjbq7/ta-lib/blob/master/docs/index.md). + + Args: + func_name (str): Function name. + timescale: + init_kwargs (dict): Keyword arguments passed to `IndicatorFactory`. + **kwargs: Keyword arguments passed to `IndicatorFactory.from_custom_func`. + + Returns: + Indicator + + Usage: + ```pycon + >>> SMA = vbt.IndicatorFactory.from_talib('SMA') + + >>> sma = SMA.run(price, timeperiod=[2, 3]) + >>> sma.real + sma_timeperiod 2 3 + a b a b + 2020-01-01 NaN NaN NaN NaN + 2020-01-02 1.5 4.5 NaN NaN + 2020-01-03 2.5 3.5 2.0 4.0 + 2020-01-04 3.5 2.5 3.0 3.0 + 2020-01-05 4.5 1.5 4.0 2.0 + ``` + + * To get help on running the indicator, use the `help` command: + + ```pycon + >>> help(SMA.run) + Help on method run: + + run(close, timeperiod=30, short_name='sma', hide_params=None, hide_default=True, **kwargs) method of builtins.type instance + Run `SMA` indicator. + + * Inputs: `close` + * Parameters: `timeperiod` + * Outputs: `real` + + Pass a list of parameter names as `hide_params` to hide their column levels. + Set `hide_default` to False to show the column levels of the parameters with a default value. + + Other keyword arguments are passed to `vectorbt.indicators.factory.run_pipeline`. + ``` + """ + import talib + from talib import abstract + + func_name = func_name.upper() + talib_func = getattr(talib, func_name) + info = abstract.Function(func_name).info + input_names = [] + for in_names in info['input_names'].values(): + if isinstance(in_names, (list, tuple)): + input_names.extend(list(in_names)) + else: + input_names.append(in_names) + class_name = info['name'] + class_docstring = "{}, {}".format(info['display_name'], info['group']) + param_names = list(info['parameters'].keys()) + output_names = info['output_names'] + output_flags = info['output_flags'] + + def apply_func(input_list: tp.List[tp.AnyArray], + in_output_tuple: tp.Tuple[tp.AnyArray, ...], + param_tuple: tp.Tuple[tp.Param, ...], + **kwargs) -> tp.Union[tp.Array2d, tp.List[tp.Array2d]]: + + # TA-Lib functions can only process 1-dim arrays + n_input_cols = input_list[0].shape[1] + outputs = [] + + # 增加 timescale 支持 + # 保证 是 timescale的倍数 + n = input_list[0].shape[0] + start = n % timescale + if start: + n = n - start + timescale + + for col in range(n_input_cols): + + s = np.full(n, fill_value=np.nan).reshape(-1, timescale) + for i in range(timescale): + s[1 if start else 0:,i] = talib_func( + *map(lambda x: x[start+i::timescale, col], input_list), + *param_tuple, + **kwargs + ) + + if start: + # 有start 会跳过第一个不完整的stride + output = s.reshape(n,)[timescale-start:] + else: + # 没有则不跳 + output = s.reshape(n, ) + + outputs.append(output) + if isinstance(outputs[0], tuple): # multiple outputs + outputs = list(zip(*outputs)) + return list(map(np.column_stack, outputs)) + return np.column_stack(outputs) + + TALibIndicator = cls( + **merge_dicts( + dict( + class_name=class_name, + class_docstring=class_docstring, + input_names=input_names, + param_names=param_names, # 这里增加 timescale, 就得在apply_func 里剔除 + output_names=output_names, + output_flags=output_flags + ), + init_kwargs + ) + ).from_apply_func( + apply_func, + pass_packed=True, + **info['parameters'], # 输入强制要求 kv 形式 + **kwargs + ) + return TALibIndicator + + @classmethod def parse_pandas_ta_config(cls, func: tp.Callable,