Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions tensorflow/lite/micro/kernels/xtensa/lstm_eval.cc
Original file line number Diff line number Diff line change
Expand Up @@ -473,7 +473,8 @@ void LstmStepManager::UpdateBatch() {
// Multi-batch for time_major input, or for batch-first input with a single time step
RuntimeShape LstmStepManager::InputShape() const {
int batch_size = 1;
if (size_info_.time_major) {
if (size_info_.time_major ||
(size_info_.batch_size > 1 && size_info_.time_steps == 1)) {
batch_size = size_info_.batch_size;
}
const int dims[2] = {batch_size, size_info_.input_dimension};
Expand All @@ -485,7 +486,8 @@ RuntimeShape LstmStepManager::InputShape() const {
// Multi-batch for time_major input, or for batch-first input with a single time step
RuntimeShape LstmStepManager::StateShape() const {
int batch_size = 1;
if (size_info_.time_major) {
if (size_info_.time_major ||
(size_info_.batch_size > 1 && size_info_.time_steps == 1)) {
batch_size = size_info_.batch_size;
}
const int dims[2] = {batch_size, size_info_.state_dimension};
Expand Down
17 changes: 12 additions & 5 deletions tensorflow/lite/micro/kernels/xtensa/lstm_eval.h
Original file line number Diff line number Diff line change
Expand Up @@ -661,10 +661,14 @@ void LstmStep(const LstmStepManager& step_info, const OpDataLSTM& op_data,
kernel_content.GetInternalTensor(tflite::kLstmInputTensor);
TfLiteEvalTensor* recurrent = kernel_content.HiddenStateTensor();

int time_major = step_info.time_major();
int num_batches = time_major == 0 ? 1 : step_info.batch_size();
int input_dimension = step_info.input_dimension();
int state_dimension = step_info.state_dimension();
const auto& size_info = op_data.size_info;
const int time_major = step_info.time_major();
const int batch_size = size_info.batch_size;
const int time_steps = size_info.time_steps;
const int num_batches = time_major == 0 ? (time_steps == 1 ? batch_size : 1)
: step_info.batch_size();
const int input_dimension = step_info.input_dimension();
const int state_dimension = step_info.state_dimension();

// Check offset validity to avoid memory overflow
TFLITE_DCHECK_LE(step_info.InputOffset() + num_batches * input_dimension,
Expand Down Expand Up @@ -803,8 +807,11 @@ TfLiteStatus EvalLstm(const OpDataLSTM& op_data,
// prepare for the next time step
step_info.UpdateTime();
}
} else if (size_info.batch_size > 1 && size_info.time_steps == 1) {
// Batch-first input with more than one batch and a single time step: handle it with one LstmStep call
lstm_internal::LstmStep<ActivationType, WeightType, CellType, BiasType>(
step_info, op_data, kernel_content, buffers);
} else {
// Batch-first input: unable to size the input data for multi-batch processing, so run single-batch inference for each batch and time step
for (int b = 0; b < size_info.batch_size; b++) {
for (int t = 0; t < size_info.time_steps; t++) {
lstm_internal::LstmStep<ActivationType, WeightType, CellType, BiasType>(
Expand Down
Loading