Skip to content

Commit a544afc

Browse files
committed
Updated kernel to correctly use NaNs in window function.
1 parent a3058c2 commit a544afc

File tree

3 files changed

+233
-67
lines changed

3 files changed

+233
-67
lines changed

bigwig_loader/intervals_to_values.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
11
import logging
22
import math
3+
from math import isnan
34
from pathlib import Path
45

56
import cupy as cp
@@ -120,6 +121,7 @@ def intervals_to_values(
120121
array_start = cp.ascontiguousarray(array_start)
121122
array_end = cp.ascontiguousarray(array_end)
122123
array_value = cp.ascontiguousarray(array_value)
124+
default_value_isnan = isnan(default_value)
123125

124126
cuda_kernel(
125127
(grid_size,),
@@ -137,6 +139,8 @@ def intervals_to_values(
137139
sequence_length,
138140
max_number_intervals,
139141
window_size,
142+
default_value,
143+
default_value_isnan,
140144
out,
141145
),
142146
)

cuda_kernels/intervals_to_values.cu

Lines changed: 26 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,5 @@
1+
#include <math_constants.h>
2+
13
extern "C" __global__
24
void intervals_to_values(
35
const unsigned int* query_starts,
@@ -12,6 +14,8 @@ void intervals_to_values(
1214
const int sequence_length,
1315
const int max_number_intervals,
1416
const int window_size,
17+
const float default_value,
18+
const bool default_value_isnan,
1519
float* out
1620
) {
1721

@@ -49,7 +53,7 @@ void intervals_to_values(
4953
}
5054
} else {
5155

52-
int track_index = i / batch_size;
56+
// int track_index = i / batch_size;
5357

5458
int found_start_index = found_starts[i];
5559
int found_end_index = found_ends[i];
@@ -59,6 +63,8 @@ void intervals_to_values(
5963
int cursor = found_start_index;
6064
int window_index = 0;
6165
float summation = 0.0f;
66+
int valid_count = 0;
67+
6268

6369
int reduced_dim = sequence_length / window_size;
6470

@@ -73,19 +79,34 @@ void intervals_to_values(
7379
int end_index = min(interval_end, query_end) - query_start;
7480

7581
if (start_index >= window_end) {
76-
out[i * reduced_dim + window_index] = summation / window_size;
82+
if (default_value_isnan) {
83+
out[i * reduced_dim + window_index] = valid_count > 0 ? summation / valid_count : CUDART_NAN_F;
84+
} else {
85+
summation = summation + (window_size - valid_count) * default_value;
86+
out[i * reduced_dim + window_index] = summation / window_size;
87+
}
7788
summation = 0.0f;
89+
valid_count = 0;
7890
window_index += 1;
7991
continue;
8092
}
8193

8294
int number = min(window_end, end_index) - max(window_start, start_index);
8395

84-
summation += number * track_values[cursor];
96+
if (number > 0) {
97+
summation += number * track_values[cursor];
98+
valid_count += number;
99+
}
85100

86101
if (end_index >= window_end || cursor + 1 >= found_end_index) {
87-
out[i * reduced_dim + window_index] = summation / window_size;
88-
summation = 0.0f;
102+
if (default_value_isnan) {
103+
out[i * reduced_dim + window_index] = valid_count > 0 ? summation / valid_count : CUDART_NAN_F;
104+
} else {
105+
summation = summation + (window_size - valid_count) * default_value;
106+
out[i * reduced_dim + window_index] = summation / window_size;
107+
}
108+
summation = 0.0f;
109+
valid_count = 0;
89110
window_index += 1;
90111
}
91112

0 commit comments

Comments
 (0)