|
1354 | 1354 | ), |
1355 | 1355 | ), |
1356 | 1356 |
|
1357 | | - 'pow': dict( |
| 1357 | + 'pow_scalar_base_float_exp': dict( |
1358 | 1358 | name=['pow'], |
1359 | 1359 | interface=['torch'], |
1360 | 1360 | is_inplace=True, |
|
1376 | 1376 | ), |
1377 | 1377 | ), |
1378 | 1378 |
|
1379 | | - 'pow_int': dict( |
| 1379 | + 'pow_scalar_base_int_exp': dict( |
1380 | 1380 | name=['pow'], |
1381 | 1381 | interface=['torch'], |
1382 | 1382 | is_inplace=True, |
|
1391 | 1391 | (2, 128, 3072), (2, 512, 38, 38), |
1392 | 1392 | (0,), (0, 8), (7, 0, 9)), |
1393 | 1393 | "dtype": [np.int16, np.int32, np.int64, |
1394 | | - np.int8, np.uint8], |
1395 | | - "gen_fn": dict(fn='Genfunc.randint', low=-4, high=4), |
| 1394 | + np.int8, np.uint8, np.bool_], |
| 1395 | + "gen_fn": dict(fn='Genfunc.uniform', low=-4, high=4), |
1396 | 1396 | } |
1397 | 1397 | ], |
1398 | 1398 | ), |
1399 | 1399 | ), |
1400 | 1400 |
|
1401 | | - 'pow_bool': dict( |
| 1401 | + # attention: Integers to negative integer powers are not allowed. |
| 1402 | + # may cause overflow if both base and exponet are uint8. |
| 1403 | + # int zero to negative int exp powers are not defined. |
| 1404 | + 'pow_tensor_base_positive_exp': dict( |
1402 | 1405 | name=['pow'], |
1403 | 1406 | interface=['torch'], |
1404 | 1407 | is_inplace=True, |
1405 | | - para=dict( |
1406 | | - exponent=[0, -1.2, 2, 0.6, 1.2, 0.], |
1407 | | - ), |
| 1408 | + dtype=[np.float16, np.float32, np.float64, |
| 1409 | + np.int16, np.int32, np.int64, |
| 1410 | + np.int8], |
1408 | 1411 | tensor_para=dict( |
1409 | 1412 | args=[ |
1410 | 1413 | { |
1411 | 1414 | "ins": ['input'], |
1412 | | - "shape": ((), (20267, 80), |
| 1415 | + "shape": ((), (1, ), (20267, 80), |
1413 | 1416 | (2, 128, 3072), |
1414 | 1417 | (2, 512, 38, 38), |
1415 | | - (0,), (0, 8)), |
1416 | | - "dtype": [np.bool_], |
1417 | | - "gen_fn": 'Genfunc.mask', |
1418 | | - } |
| 1418 | + (0,), (0, 4), (9, 0, 3)), |
| 1419 | + "gen_fn": dict(fn='Genfunc.uniform', low=-4, high=4), |
| 1420 | + }, |
| 1421 | + { |
| 1422 | + "ins": ['exponent'], |
| 1423 | + "shape": ((), (1, ), (20267, 80), |
| 1424 | + (2, 128, 3072), |
| 1425 | + (2, 512, 38, 38), |
| 1426 | + (0,), (0, 4), (9, 0, 3)), |
| 1427 | + "gen_fn": dict(fn='Genfunc.uniform', low=1, high=4), |
| 1428 | + }, |
1419 | 1429 | ], |
1420 | 1430 | ), |
1421 | 1431 | ), |
1422 | 1432 |
|
1423 | | - 'pow_tensor': dict( |
| 1433 | + # attention: Integers to negative integer powers are not allowed. |
| 1434 | + # int zero to negative int exp powers are not defined. |
| 1435 | + 'pow_tensor_base_negative_exp': dict( |
1424 | 1436 | name=['pow'], |
1425 | 1437 | interface=['torch'], |
1426 | 1438 | is_inplace=True, |
1427 | | - dtype=[np.float16, np.float32, np.float64, |
1428 | | - np.int16, np.int32, np.int64, |
1429 | | - np.int8, np.uint8], |
| 1439 | + dtype=[np.float16, np.float32, np.float64], |
1430 | 1440 | tensor_para=dict( |
1431 | | - gen_fn=dict(fn='Genfunc.randn_int', low=-4, high=4), |
1432 | 1441 | args=[ |
1433 | 1442 | { |
1434 | 1443 | "ins": ['input'], |
1435 | 1444 | "shape": ((), (1, ), (20267, 80), |
1436 | 1445 | (2, 128, 3072), |
1437 | 1446 | (2, 512, 38, 38), |
1438 | 1447 | (0,), (0, 4), (9, 0, 3)), |
| 1448 | + "gen_fn": dict(fn='Genfunc.uniform', low=-4, high=4), |
1439 | 1449 | }, |
1440 | 1450 | { |
1441 | 1451 | "ins": ['exponent'], |
1442 | 1452 | "shape": ((), (1, ), (20267, 80), |
1443 | 1453 | (2, 128, 3072), |
1444 | 1454 | (2, 512, 38, 38), |
1445 | 1455 | (0,), (0, 4), (9, 0, 3)), |
| 1456 | + "gen_fn": dict(fn='Genfunc.uniform', low=-4, high=-1), |
1446 | 1457 | }, |
1447 | 1458 | ], |
1448 | 1459 | ), |
1449 | 1460 | ), |
1450 | 1461 |
|
| 1462 | + # int zero to negative int exp powers are not defined. |
1451 | 1463 | 'pow_tensor_only_0_1': dict( |
1452 | 1464 | name=['pow'], |
1453 | 1465 | interface=['torch'], |
1454 | 1466 | is_inplace=True, |
1455 | 1467 | dtype=[np.int16, np.int32, np.int64, |
1456 | 1468 | np.int8, np.uint8], |
1457 | 1469 | tensor_para=dict( |
1458 | | - gen_fn='Genfunc.randn', |
| 1470 | + gen_fn=dict(fn='Genfunc.uniform', low=0, high=2), |
1459 | 1471 | args=[ |
1460 | 1472 | { |
1461 | 1473 | "ins": ['input'], |
|
1520 | 1532 | ), |
1521 | 1533 | ), |
1522 | 1534 |
|
| 1535 | + # attention: Integers to negative integer powers are not allowed. |
| 1536 | + # may cause overflow if both base and exponet are uint8 |
| 1537 | + # int zero to negative int exp powers are not defined. |
1523 | 1538 | 'pow_diff_dtype_cast': dict( |
1524 | 1539 | name=['pow'], |
1525 | 1540 | interface=['torch'], |
1526 | 1541 | tensor_para=dict( |
1527 | | - gen_fn=dict(fn='Genfunc.randn_int', low=-4, high=4), |
1528 | 1542 | args=[ |
1529 | 1543 | { |
1530 | 1544 | "ins": ['input'], |
1531 | 1545 | "shape": ((1024, ),), |
1532 | 1546 | "dtype": [np.int64, np.int32, np.int16, |
1533 | 1547 | np.bool_, np.bool_, np.bool_, np.bool_], |
| 1548 | + "gen_fn": dict(fn='Genfunc.uniform', low=-4, high=4), |
1534 | 1549 | }, |
1535 | 1550 | { |
1536 | 1551 | "ins": ['exponent'], |
1537 | 1552 | "shape": ((1024, ),), |
1538 | 1553 | "dtype": [np.float32, np.float64, np.float16, |
1539 | 1554 | np.int32, np.float32, np.int8, np.uint8], |
| 1555 | + "gen_fn": dict(fn='Genfunc.uniform', low=1, high=4), |
1540 | 1556 | }, |
1541 | 1557 | ], |
1542 | 1558 | ), |
1543 | 1559 | ), |
1544 | 1560 |
|
1545 | | - # FIXME pow的input与exponent输入uint8和int8,结果不一致 |
| 1561 | + # attention: Integers to negative integer powers are not allowed. |
| 1562 | + # may cause overflow if both base and exponet are uint8 |
| 1563 | + # int zero to negative int exp powers are not defined. |
1546 | 1564 | 'pow_diff_dtype': dict( |
1547 | 1565 | name=['pow'], |
1548 | 1566 | interface=['torch'], |
1549 | 1567 | is_inplace=True, |
1550 | 1568 | tensor_para=dict( |
1551 | | - gen_fn=dict(fn='Genfunc.randn_int', low=-4, high=4), |
1552 | 1569 | args=[ |
1553 | 1570 | { |
1554 | 1571 | "ins": ['input'], |
1555 | 1572 | "shape": ((1024, ),), |
1556 | | - # "dtype":[np.float64, np.float32, np.float16, |
1557 | | - # np.int32, np.float64, np.float64, |
1558 | | - # np.int8, np.float32, np.uint8], |
1559 | 1573 | "dtype": [np.float64, np.float32, np.float16, |
1560 | 1574 | np.int32, np.float64, np.float32, |
1561 | 1575 | np.float32, np.int16, np.int64], |
| 1576 | + "gen_fn": dict(fn='Genfunc.uniform', low=-4, high=4), |
1562 | 1577 | }, |
1563 | 1578 | { |
1564 | 1579 | "ins": ['exponent'], |
1565 | 1580 | "shape": ((1024, ),), |
1566 | | - # "dtype":[np.int32, np.uint8, np.bool_, |
1567 | | - # np.int64, np.float16, np.float32, |
1568 | | - # np.uint8, np.bool_, np.int8], |
1569 | 1581 | "dtype": [np.int32, np.uint8, np.bool_, |
1570 | 1582 | np.int64, np.float16, np.float64, |
1571 | 1583 | np.bool_, np.uint8, np.bool_], |
| 1584 | + "gen_fn": dict(fn='Genfunc.uniform', low=1, high=4), |
1572 | 1585 | }, |
1573 | 1586 | ], |
1574 | 1587 | ), |
1575 | 1588 | ), |
1576 | 1589 |
|
1577 | | - 'pow_input_scalar': dict( |
| 1590 | + # attention: Integers to negative integer powers are not allowed. |
| 1591 | + # may cause overflow if exponet are uint8 |
| 1592 | + # int zero to negative int exp powers are not defined. |
| 1593 | + 'pow_input_scalar_positive_exp': dict( |
1578 | 1594 | name=['pow'], |
1579 | 1595 | interface=['torch'], |
1580 | 1596 | para=dict( |
|
1589 | 1605 | (0,), (0, 4), (9, 0, 6)), |
1590 | 1606 | "dtype": [np.float16, np.float32, np.float64, |
1591 | 1607 | np.int16, np.int32, np.int64, |
1592 | | - np.int8, np.uint8, np.bool_], |
1593 | | - "gen_fn": dict(fn='Genfunc.randn_int', low=-4, high=4), |
| 1608 | + np.int8, np.bool_], |
| 1609 | + "gen_fn": dict(fn='Genfunc.uniform', low=1, high=4), |
1594 | 1610 | } |
1595 | 1611 | ], |
1596 | 1612 | ), |
1597 | 1613 | ), |
1598 | 1614 |
|
| 1615 | + 'pow_input_scalar_negative_exp': dict( |
| 1616 | + name=['pow'], |
| 1617 | + interface=['torch'], |
| 1618 | + para=dict( |
| 1619 | + self=[-2, -0.5, 0, 0.6, 2, 3, 4., 1.], |
| 1620 | + ), |
| 1621 | + tensor_para=dict( |
| 1622 | + args=[ |
| 1623 | + { |
| 1624 | + "ins": ['exponent'], |
| 1625 | + "shape": ((), (8,), (125, 1), |
| 1626 | + (70, 1, 2), (4, 256, 16, 16), |
| 1627 | + (0,), (0, 4), (9, 0, 6)), |
| 1628 | + "dtype": [np.float16, np.float32, np.float64], |
| 1629 | + "gen_fn": dict(fn='Genfunc.uniform', low=-4, high=-1), |
| 1630 | + } |
| 1631 | + ], |
| 1632 | + ), |
| 1633 | + ), |
| 1634 | + |
| 1635 | + # attention: Integers to negative integer powers are not allowed. |
1599 | 1636 | 'pow_input_scalar_bool': dict( |
1600 | 1637 | name=['pow'], |
1601 | 1638 | interface=['torch'], |
|
1610 | 1647 | "dtype": [np.float16, np.float32, np.float64, |
1611 | 1648 | np.int16, np.int32, np.int64, |
1612 | 1649 | np.int8, np.uint8], |
1613 | | - "gen_fn": dict(fn='Genfunc.randn_int', low=-4, high=4), |
| 1650 | + "gen_fn": dict(fn='Genfunc.uniform', low=1, high=4), |
1614 | 1651 | } |
1615 | 1652 | ], |
1616 | 1653 | ), |
|
0 commit comments