Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions vectorbt/indicators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ def talib(*args, **kwargs) -> tp.Type[IndicatorBase]:
"""Shortcut for `vectorbt.indicators.factory.IndicatorFactory.from_talib`."""
return IndicatorFactory.from_talib(*args, **kwargs)

def mtalib(*args, **kwargs) -> tp.Type[IndicatorBase]:
"""Shortcut for `vectorbt.indicators.factory.IndicatorFactory.from_talib`."""
return IndicatorFactory.from_mtalib(*args, **kwargs)

def pandas_ta(*args, **kwargs) -> tp.Type[IndicatorBase]:
"""Shortcut for `vectorbt.indicators.factory.IndicatorFactory.from_pandas_ta`."""
Expand All @@ -38,6 +41,7 @@ def ta(*args, **kwargs) -> tp.Type[IndicatorBase]:
__all__ = [
'IndicatorFactory',
'talib',
'mtalib',
'pandas_ta',
'ta',
'MA',
Expand All @@ -50,6 +54,7 @@ def ta(*args, **kwargs) -> tp.Type[IndicatorBase]:
'OBV'
]
__whitelist__ = [
'mtalib',
'talib',
'pandas_ta',
'ta'
Expand Down
130 changes: 130 additions & 0 deletions vectorbt/indicators/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -3490,6 +3490,136 @@ def apply_func(input_list: tp.List[tp.AnyArray],
)
return TALibIndicator


@classmethod
def from_mtalib(cls, func_name: str, timescale:int, init_kwargs: tp.KwargsLike = None, **kwargs) -> tp.Type[IndicatorBase]:
"""Build an indicator class around a TA-Lib function.

Requires [TA-Lib](https://github.com/mrjbq7/ta-lib) installed.

For input, parameter and output names, see [docs](https://github.com/mrjbq7/ta-lib/blob/master/docs/index.md).

Args:
func_name (str): Function name.
timescale:
init_kwargs (dict): Keyword arguments passed to `IndicatorFactory`.
**kwargs: Keyword arguments passed to `IndicatorFactory.from_custom_func`.

Returns:
Indicator

Usage:
```pycon
>>> SMA = vbt.IndicatorFactory.from_talib('SMA')

>>> sma = SMA.run(price, timeperiod=[2, 3])
>>> sma.real
sma_timeperiod 2 3
a b a b
2020-01-01 NaN NaN NaN NaN
2020-01-02 1.5 4.5 NaN NaN
2020-01-03 2.5 3.5 2.0 4.0
2020-01-04 3.5 2.5 3.0 3.0
2020-01-05 4.5 1.5 4.0 2.0
```

* To get help on running the indicator, use the `help` command:

```pycon
>>> help(SMA.run)
Help on method run:

run(close, timeperiod=30, short_name='sma', hide_params=None, hide_default=True, **kwargs) method of builtins.type instance
Run `SMA` indicator.

* Inputs: `close`
* Parameters: `timeperiod`
* Outputs: `real`

Pass a list of parameter names as `hide_params` to hide their column levels.
Set `hide_default` to False to show the column levels of the parameters with a default value.

Other keyword arguments are passed to `vectorbt.indicators.factory.run_pipeline`.
```
"""
import talib
from talib import abstract

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why are we importing talib here? They should already be available


func_name = func_name.upper()
talib_func = getattr(talib, func_name)
info = abstract.Function(func_name).info
input_names = []
for in_names in info['input_names'].values():
if isinstance(in_names, (list, tuple)):
input_names.extend(list(in_names))
else:
input_names.append(in_names)
class_name = info['name']
class_docstring = "{}, {}".format(info['display_name'], info['group'])
param_names = list(info['parameters'].keys())
output_names = info['output_names']
output_flags = info['output_flags']

def apply_func(input_list: tp.List[tp.AnyArray],
in_output_tuple: tp.Tuple[tp.AnyArray, ...],
param_tuple: tp.Tuple[tp.Param, ...],
**kwargs) -> tp.Union[tp.Array2d, tp.List[tp.Array2d]]:

# TA-Lib functions can only process 1-dim arrays
n_input_cols = input_list[0].shape[1]
outputs = []

# 增加 timescale 支持
# 保证 是 timescale的倍数
n = input_list[0].shape[0]
start = n % timescale
if start:
n = n - start + timescale

for col in range(n_input_cols):

s = np.full(n, fill_value=np.nan).reshape(-1, timescale)
for i in range(timescale):
s[1 if start else 0:,i] = talib_func(
*map(lambda x: x[start+i::timescale, col], input_list),
*param_tuple,
**kwargs
)

if start:
# 有start 会跳过第一个不完整的stride
output = s.reshape(n,)[timescale-start:]
else:
# 没有则不跳
output = s.reshape(n, )

outputs.append(output)
if isinstance(outputs[0], tuple): # multiple outputs
outputs = list(zip(*outputs))
return list(map(np.column_stack, outputs))
return np.column_stack(outputs)

TALibIndicator = cls(
**merge_dicts(
dict(
class_name=class_name,
class_docstring=class_docstring,
input_names=input_names,
param_names=param_names, # 这里增加 timescale, 就得在apply_func 里剔除
output_names=output_names,
output_flags=output_flags
),
init_kwargs
)
).from_apply_func(
apply_func,
pass_packed=True,
**info['parameters'], # 输入强制要求 kv 形式
**kwargs
)
return TALibIndicator


@classmethod
def parse_pandas_ta_config(cls,
func: tp.Callable,
Expand Down