diff --git a/src/main/java/com/databricks/jdbc/api/impl/ExecutionResultFactory.java b/src/main/java/com/databricks/jdbc/api/impl/ExecutionResultFactory.java
index 4c719731d..ba6d7acb7 100644
--- a/src/main/java/com/databricks/jdbc/api/impl/ExecutionResultFactory.java
+++ b/src/main/java/com/databricks/jdbc/api/impl/ExecutionResultFactory.java
@@ -1,6 +1,7 @@
 package com.databricks.jdbc.api.impl;
 
 import com.databricks.jdbc.api.impl.arrow.ArrowStreamResult;
+import com.databricks.jdbc.api.impl.arrow.LazyThriftInlineArrowResult;
 import com.databricks.jdbc.api.impl.volume.VolumeOperationResult;
 import com.databricks.jdbc.api.internal.IDatabricksSession;
 import com.databricks.jdbc.api.internal.IDatabricksStatementInternal;
@@ -96,9 +97,9 @@ private static IExecutionResult getResultHandler(
       case COLUMN_BASED_SET:
         return new LazyThriftResult(resultsResp, parentStatement, session);
       case ARROW_BASED_SET:
-        return new ArrowStreamResult(resultsResp, true, parentStatement, session);
+        return new LazyThriftInlineArrowResult(resultsResp, parentStatement, session);
       case URL_BASED_SET:
-        return new ArrowStreamResult(resultsResp, false, parentStatement, session);
+        return new ArrowStreamResult(resultsResp, parentStatement, session);
       case ROW_BASED_SET:
         throw new DatabricksSQLFeatureNotSupportedException(
             "Invalid state - row based set cannot be received");
diff --git a/src/main/java/com/databricks/jdbc/api/impl/arrow/ArrowStreamResult.java b/src/main/java/com/databricks/jdbc/api/impl/arrow/ArrowStreamResult.java
index 29a88fd6b..4e011301e 100644
--- a/src/main/java/com/databricks/jdbc/api/impl/arrow/ArrowStreamResult.java
+++ b/src/main/java/com/databricks/jdbc/api/impl/arrow/ArrowStreamResult.java
@@ -85,13 +85,11 @@ public ArrowStreamResult(
 
   public ArrowStreamResult(
       TFetchResultsResp resultsResp,
-      boolean isInlineArrow,
       IDatabricksStatementInternal parentStatementId,
       IDatabricksSession session)
       throws DatabricksSQLException {
     this(
         resultsResp,
-        isInlineArrow,
         parentStatementId,
         session,
         DatabricksHttpClientFactory.getInstance().getClient(session.getConnectionContext()));
@@ -100,27 +98,22 @@ public ArrowStreamResult(
   @VisibleForTesting
   ArrowStreamResult(
       TFetchResultsResp resultsResp,
-      boolean isInlineArrow,
       IDatabricksStatementInternal parentStatement,
       IDatabricksSession session,
       IDatabricksHttpClient httpClient)
       throws DatabricksSQLException {
     this.session = session;
     setColumnInfo(resultsResp.getResultSetMetadata());
-    if (isInlineArrow) {
-      this.chunkProvider = new InlineChunkProvider(resultsResp, parentStatement, session);
-    } else {
-      CompressionCodec compressionCodec =
-          CompressionCodec.getCompressionMapping(resultsResp.getResultSetMetadata());
-      this.chunkProvider =
-          new RemoteChunkProvider(
-              parentStatement,
-              resultsResp,
-              session,
-              httpClient,
-              session.getConnectionContext().getCloudFetchThreadPoolSize(),
-              compressionCodec);
-    }
+    CompressionCodec compressionCodec =
+        CompressionCodec.getCompressionMapping(resultsResp.getResultSetMetadata());
+    this.chunkProvider =
+        new RemoteChunkProvider(
+            parentStatement,
+            resultsResp,
+            session,
+            httpClient,
+            session.getConnectionContext().getCloudFetchThreadPoolSize(),
+            compressionCodec);
   }
 
   public List getArrowMetadata() throws DatabricksSQLException {
@@ -133,30 +126,15 @@ public List getArrowMetadata() throws DatabricksSQLException {
   /** {@inheritDoc} */
   @Override
   public Object getObject(int columnIndex) throws DatabricksSQLException {
-    ColumnInfoTypeName requiredType = columnInfos.get(columnIndex).getTypeName();
+    ColumnInfo
columnInfo = columnInfos.get(columnIndex); + ColumnInfoTypeName requiredType = columnInfo.getTypeName(); String arrowMetadata = chunkIterator.getType(columnIndex); if (arrowMetadata == null) { - arrowMetadata = columnInfos.get(columnIndex).getTypeText(); - } - - // Handle complex type conversion when complex datatype support is disabled - boolean isComplexDatatypeSupportEnabled = - this.session.getConnectionContext().isComplexDatatypeSupportEnabled(); - if (!isComplexDatatypeSupportEnabled && isComplexType(requiredType)) { - LOGGER.debug("Complex datatype support is disabled, converting complex type to STRING"); - - Object result = - chunkIterator.getColumnObjectAtCurrentRow( - columnIndex, ColumnInfoTypeName.STRING, "STRING", columnInfos.get(columnIndex)); - if (result == null) { - return null; - } - ComplexDataTypeParser parser = new ComplexDataTypeParser(); - return parser.formatComplexTypeString(result.toString(), requiredType.name(), arrowMetadata); + arrowMetadata = columnInfo.getTypeText(); } - return chunkIterator.getColumnObjectAtCurrentRow( - columnIndex, requiredType, arrowMetadata, columnInfos.get(columnIndex)); + return getObjectWithComplexTypeHandling( + session, chunkIterator, columnIndex, requiredType, arrowMetadata, columnInfo); } /** @@ -237,4 +215,44 @@ private void setColumnInfo(TGetResultSetMetadataResp resultManifest) { columnInfos.add(getColumnInfoFromTColumnDesc(tColumnDesc)); } } + + /** + * Helper method to handle complex type conversion when complex datatype support is disabled. + * + * @param session The databricks session + * @param chunkIterator The chunk iterator + * @param columnIndex The column index + * @param requiredType The required column type + * @param arrowMetadata The arrow metadata + * @param columnInfo The column info + * @return The object value (converted if complex type and support disabled) + * @throws DatabricksSQLException if an error occurs + */ + protected static Object getObjectWithComplexTypeHandling( + IDatabricksSession session, + ArrowResultChunkIterator chunkIterator, + int columnIndex, + ColumnInfoTypeName requiredType, + String arrowMetadata, + ColumnInfo columnInfo) + throws DatabricksSQLException { + boolean isComplexDatatypeSupportEnabled = + session.getConnectionContext().isComplexDatatypeSupportEnabled(); + + if (!isComplexDatatypeSupportEnabled && isComplexType(requiredType)) { + LOGGER.debug("Complex datatype support is disabled, converting complex type to STRING"); + Object result = + chunkIterator.getColumnObjectAtCurrentRow( + columnIndex, ColumnInfoTypeName.STRING, "STRING", columnInfo); + if (result == null) { + return null; + } + ComplexDataTypeParser parser = new ComplexDataTypeParser(); + + return parser.formatComplexTypeString(result.toString(), requiredType.name(), arrowMetadata); + } + + return chunkIterator.getColumnObjectAtCurrentRow( + columnIndex, requiredType, arrowMetadata, columnInfo); + } } diff --git a/src/main/java/com/databricks/jdbc/api/impl/arrow/InlineChunkProvider.java b/src/main/java/com/databricks/jdbc/api/impl/arrow/InlineChunkProvider.java index e22d974a4..32f5e1b80 100644 --- a/src/main/java/com/databricks/jdbc/api/impl/arrow/InlineChunkProvider.java +++ b/src/main/java/com/databricks/jdbc/api/impl/arrow/InlineChunkProvider.java @@ -1,31 +1,17 @@ package com.databricks.jdbc.api.impl.arrow; -import static com.databricks.jdbc.common.util.DatabricksTypeUtil.*; import static com.databricks.jdbc.common.util.DecompressionUtil.decompress; -import com.databricks.jdbc.api.internal.IDatabricksSession; 
-import com.databricks.jdbc.api.internal.IDatabricksStatementInternal; import com.databricks.jdbc.common.CompressionCodec; import com.databricks.jdbc.exception.DatabricksParsingException; import com.databricks.jdbc.exception.DatabricksSQLException; import com.databricks.jdbc.log.JdbcLogger; import com.databricks.jdbc.log.JdbcLoggerFactory; -import com.databricks.jdbc.model.client.thrift.generated.*; import com.databricks.jdbc.model.core.ResultData; import com.databricks.jdbc.model.core.ResultManifest; import com.databricks.jdbc.model.telemetry.enums.DatabricksDriverErrorCode; import com.google.common.annotations.VisibleForTesting; import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.sql.SQLException; -import java.util.ArrayList; -import java.util.List; -import org.apache.arrow.vector.types.pojo.ArrowType; -import org.apache.arrow.vector.types.pojo.Field; -import org.apache.arrow.vector.types.pojo.FieldType; -import org.apache.arrow.vector.types.pojo.Schema; -import org.apache.arrow.vector.util.SchemaUtility; /** Class to manage inline Arrow chunks */ public class InlineChunkProvider implements ChunkProvider { @@ -37,23 +23,6 @@ public class InlineChunkProvider implements ChunkProvider { private final ArrowResultChunk arrowResultChunk; // There is only one packet of data in case of inline arrow - InlineChunkProvider( - TFetchResultsResp resultsResp, - IDatabricksStatementInternal parentStatement, - IDatabricksSession session) - throws DatabricksParsingException { - this.currentChunkIndex = -1; - this.totalRows = 0; - ByteArrayInputStream byteStream = initializeByteStream(resultsResp, session, parentStatement); - ArrowResultChunk.Builder builder = - ArrowResultChunk.builder().withInputStream(byteStream, totalRows); - - if (parentStatement != null) { - builder.withStatementId(parentStatement.getStatementId()); - } - arrowResultChunk = builder.build(); - } - /** * Constructor for inline arrow chunk provider from {@link ResultData} and {@link ResultManifest}. 
* @@ -123,97 +92,6 @@ public boolean isClosed() { return isClosed; } - private ByteArrayInputStream initializeByteStream( - TFetchResultsResp resultsResp, - IDatabricksSession session, - IDatabricksStatementInternal parentStatement) - throws DatabricksParsingException { - ByteArrayOutputStream baos = new ByteArrayOutputStream(); - CompressionCodec compressionType = - CompressionCodec.getCompressionMapping(resultsResp.getResultSetMetadata()); - try { - byte[] serializedSchema = getSerializedSchema(resultsResp.getResultSetMetadata()); - if (serializedSchema != null) { - baos.write(serializedSchema); - } - writeToByteOutputStream( - compressionType, parentStatement, resultsResp.getResults().getArrowBatches(), baos); - while (resultsResp.hasMoreRows) { - resultsResp = session.getDatabricksClient().getMoreResults(parentStatement); - writeToByteOutputStream( - compressionType, parentStatement, resultsResp.getResults().getArrowBatches(), baos); - } - return new ByteArrayInputStream(baos.toByteArray()); - } catch (DatabricksSQLException | IOException e) { - handleError(e); - } - return null; - } - - private void writeToByteOutputStream( - CompressionCodec compressionCodec, - IDatabricksStatementInternal parentStatement, - List arrowBatchList, - ByteArrayOutputStream baos) - throws DatabricksSQLException, IOException { - for (TSparkArrowBatch arrowBatch : arrowBatchList) { - byte[] decompressedBytes = - decompress( - arrowBatch.getBatch(), - compressionCodec, - String.format( - "Data fetch for inline arrow batch [%d] and statement [%s] with decompression algorithm : [%s]", - arrowBatch.getRowCount(), parentStatement, compressionCodec)); - totalRows += arrowBatch.getRowCount(); - baos.write(decompressedBytes); - } - } - - private byte[] getSerializedSchema(TGetResultSetMetadataResp metadata) - throws DatabricksSQLException { - if (metadata.getArrowSchema() != null) { - return metadata.getArrowSchema(); - } - Schema arrowSchema = hiveSchemaToArrowSchema(metadata.getSchema()); - try { - return SchemaUtility.serialize(arrowSchema); - } catch (IOException e) { - handleError(e); - } - // should never reach here; - return null; - } - - private Schema hiveSchemaToArrowSchema(TTableSchema hiveSchema) - throws DatabricksParsingException { - List fields = new ArrayList<>(); - if (hiveSchema == null) { - return new Schema(fields); - } - try { - hiveSchema - .getColumns() - .forEach( - columnDesc -> { - try { - fields.add(getArrowField(columnDesc)); - } catch (SQLException e) { - throw new RuntimeException(e); - } - }); - } catch (RuntimeException e) { - handleError(e); - } - return new Schema(fields); - } - - private Field getArrowField(TColumnDesc columnDesc) throws SQLException { - TPrimitiveTypeEntry primitiveTypeEntry = getTPrimitiveTypeOrDefault(columnDesc.getTypeDesc()); - ArrowType arrowType = mapThriftToArrowType(primitiveTypeEntry.getType()); - FieldType fieldType = new FieldType(true, arrowType, null); - return new Field(columnDesc.getColumnName(), fieldType, null); - } - @VisibleForTesting void handleError(Exception e) throws DatabricksParsingException { String errorMessage = diff --git a/src/main/java/com/databricks/jdbc/api/impl/arrow/LazyThriftInlineArrowResult.java b/src/main/java/com/databricks/jdbc/api/impl/arrow/LazyThriftInlineArrowResult.java new file mode 100644 index 000000000..08950339c --- /dev/null +++ b/src/main/java/com/databricks/jdbc/api/impl/arrow/LazyThriftInlineArrowResult.java @@ -0,0 +1,425 @@ +package com.databricks.jdbc.api.impl.arrow; + +import static 
com.databricks.jdbc.common.EnvironmentVariables.DEFAULT_RESULT_ROW_LIMIT; +import static com.databricks.jdbc.common.util.DatabricksTypeUtil.*; +import static com.databricks.jdbc.common.util.DecompressionUtil.decompress; + +import com.databricks.jdbc.api.impl.IExecutionResult; +import com.databricks.jdbc.api.internal.IDatabricksSession; +import com.databricks.jdbc.api.internal.IDatabricksStatementInternal; +import com.databricks.jdbc.common.CompressionCodec; +import com.databricks.jdbc.exception.DatabricksParsingException; +import com.databricks.jdbc.exception.DatabricksSQLException; +import com.databricks.jdbc.log.JdbcLogger; +import com.databricks.jdbc.log.JdbcLoggerFactory; +import com.databricks.jdbc.model.client.thrift.generated.*; +import com.databricks.jdbc.model.core.ColumnInfo; +import com.databricks.jdbc.model.core.ColumnInfoTypeName; +import com.databricks.jdbc.model.telemetry.enums.DatabricksDriverErrorCode; +import com.google.common.annotations.VisibleForTesting; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.List; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.FieldType; +import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.arrow.vector.util.SchemaUtility; + +/** + * Lazy implementation for thrift-based inline Arrow results that fetches arrow batches on demand. + * Similar to LazyThriftResult but processes Arrow data instead of columnar thrift data. + */ +public class LazyThriftInlineArrowResult implements IExecutionResult { + private static final JdbcLogger LOGGER = + JdbcLoggerFactory.getLogger(LazyThriftInlineArrowResult.class); + + private TFetchResultsResp currentResponse; + private ArrowResultChunk currentChunk; + private ArrowResultChunkIterator currentChunkIterator; + private long globalRowIndex; + private final IDatabricksSession session; + private final IDatabricksStatementInternal statement; + private final int maxRows; + private boolean hasReachedEnd; + private boolean isClosed; + private long totalRowsFetched; + private List columnInfos; + + /** + * Creates a new LazyThriftInlineArrowResult that lazily fetches arrow data on demand. + * + * @param initialResponse the initial response from the server + * @param statement the statement that generated this result + * @param session the session to use for fetching additional data + * @throws DatabricksSQLException if the initial response cannot be processed + */ + public LazyThriftInlineArrowResult( + TFetchResultsResp initialResponse, + IDatabricksStatementInternal statement, + IDatabricksSession session) + throws DatabricksSQLException { + this.currentResponse = initialResponse; + this.statement = statement; + this.session = session; + this.maxRows = statement != null ? statement.getMaxRows() : DEFAULT_RESULT_ROW_LIMIT; + this.globalRowIndex = -1; + this.hasReachedEnd = false; + this.isClosed = false; + this.totalRowsFetched = 0; + + // Initialize column info from metadata + setColumnInfo(initialResponse.getResultSetMetadata()); + + // Load initial chunk + loadCurrentChunk(); + LOGGER.debug( + "LazyThriftInlineArrowResult initialized with {} rows in first chunk, hasMoreRows: {}", + currentChunk.numRows, + currentResponse.hasMoreRows); + } + + /** + * Gets the value at the specified column index for the current row. 
+ * + * @param columnIndex the zero-based column index + * @return the value at the specified column + * @throws DatabricksSQLException if the result is closed, cursor is invalid, or column index is + * out of bounds + */ + @Override + public Object getObject(int columnIndex) throws DatabricksSQLException { + if (isClosed) { + throw new DatabricksSQLException( + "Result is already closed", DatabricksDriverErrorCode.STATEMENT_CLOSED); + } + if (globalRowIndex == -1) { + throw new DatabricksSQLException( + "Cursor is before first row", DatabricksDriverErrorCode.INVALID_STATE); + } + if (currentChunkIterator == null) { + throw new DatabricksSQLException( + "No current chunk available", DatabricksDriverErrorCode.INVALID_STATE); + } + if (columnIndex < 0 || columnIndex >= columnInfos.size()) { + throw new DatabricksSQLException( + "Column index out of bounds " + columnIndex, DatabricksDriverErrorCode.INVALID_STATE); + } + + ColumnInfo columnInfo = columnInfos.get(columnIndex); + ColumnInfoTypeName requiredType = columnInfo.getTypeName(); + String arrowMetadata = currentChunkIterator.getType(columnIndex); + if (arrowMetadata == null) { + arrowMetadata = columnInfo.getTypeText(); + } + + return ArrowStreamResult.getObjectWithComplexTypeHandling( + session, currentChunkIterator, columnIndex, requiredType, arrowMetadata, columnInfo); + } + + /** + * Gets the current row index (0-based). Returns -1 if before the first row. + * + * @return the current row index + */ + @Override + public long getCurrentRow() { + return globalRowIndex; + } + + /** + * Moves the cursor to the next row. Fetches additional data from server if needed. + * + * @return true if there is a next row, false if at the end + * @throws DatabricksSQLException if an error occurs while fetching data + */ + @Override + public boolean next() throws DatabricksSQLException { + if (isClosed || hasReachedEnd) { + return false; + } + + if (!hasNext()) { + return false; + } + + // Check if we've reached the maxRows limit + boolean hasRowLimit = maxRows > 0; + if (hasRowLimit && globalRowIndex + 1 >= maxRows) { + hasReachedEnd = true; + return false; + } + + // Try to advance in current chunk + if (currentChunkIterator != null && currentChunkIterator.hasNextRow()) { + boolean advanced = currentChunkIterator.nextRow(); + if (advanced) { + globalRowIndex++; + return true; + } + } + + // Need to fetch next chunk + while (currentResponse.hasMoreRows) { + fetchNextChunk(); + + // If we got a chunk with data, advance to first row + if (currentChunkIterator != null && currentChunkIterator.hasNextRow()) { + boolean advanced = currentChunkIterator.nextRow(); + if (advanced) { + globalRowIndex++; + return true; + } + } + } + + // No more data available + hasReachedEnd = true; + return false; + } + + /** + * Checks if there are more rows available without advancing the cursor. + * + * @return true if there are more rows, false otherwise + */ + @Override + public boolean hasNext() { + if (isClosed || hasReachedEnd) { + return false; + } + + // Check maxRows limit + boolean hasRowLimit = maxRows > 0; + if (hasRowLimit && globalRowIndex + 1 >= maxRows) { + return false; + } + + // Check if there are more rows in current chunk + if (currentChunkIterator != null && currentChunkIterator.hasNextRow()) { + return true; + } + + // Check if there are more chunks to fetch + return currentResponse.hasMoreRows; + } + + /** Closes this result and releases associated resources. 
*/ + @Override + public void close() { + this.isClosed = true; + if (currentChunk != null) { + currentChunk.releaseChunk(); + } + this.currentChunk = null; + this.currentChunkIterator = null; + this.currentResponse = null; + LOGGER.debug( + "LazyThriftInlineArrowResult closed after fetching {} total rows", totalRowsFetched); + } + + /** + * Gets the number of rows in the current chunk. + * + * @return the number of rows in the current chunk + */ + @Override + public long getRowCount() { + return currentChunk != null ? currentChunk.numRows : 0; + } + + /** + * Gets the chunk count. Always returns 0 for lazy thrift inline arrow results. + * + * @return 0 (lazy results don't use chunks in the same sense as buffered results) + */ + @Override + public long getChunkCount() { + return 0; + } + + private void loadCurrentChunk() throws DatabricksSQLException { + try { + ByteArrayInputStream byteStream = createArrowByteStream(currentResponse); + long rowCount = getTotalRowsInResponse(currentResponse); + + ArrowResultChunk.Builder builder = + ArrowResultChunk.builder().withInputStream(byteStream, rowCount); + + if (statement != null) { + builder.withStatementId(statement.getStatementId()); + } + + currentChunk = builder.build(); + currentChunkIterator = currentChunk.getChunkIterator(); + totalRowsFetched += rowCount; + + LOGGER.debug( + "Loaded arrow chunk with {} rows, total fetched: {}", rowCount, totalRowsFetched); + } catch (DatabricksParsingException e) { + LOGGER.error("Failed to load current chunk: {}", e.getMessage()); + hasReachedEnd = true; + throw new DatabricksSQLException( + "Failed to process arrow data", DatabricksDriverErrorCode.INLINE_CHUNK_PARSING_ERROR); + } + } + + /** + * Fetches the next chunk of data from the server and creates arrow chunks. 
+ * + * @throws DatabricksSQLException if the fetch operation fails + */ + private void fetchNextChunk() throws DatabricksSQLException { + try { + LOGGER.debug("Fetching next arrow chunk, current total rows fetched: {}", totalRowsFetched); + currentResponse = session.getDatabricksClient().getMoreResults(statement); + + // Release previous chunk to free memory + if (currentChunk != null) { + currentChunk.releaseChunk(); + } + + loadCurrentChunk(); + + LOGGER.debug( + "Fetched arrow chunk with {} rows, hasMoreRows: {}", + currentChunk.numRows, + currentResponse.hasMoreRows); + } catch (DatabricksSQLException e) { + LOGGER.error("Failed to fetch next arrow chunk: {}", e.getMessage()); + hasReachedEnd = true; + throw e; + } + } + + private ByteArrayInputStream createArrowByteStream(TFetchResultsResp resultsResp) + throws DatabricksParsingException { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + CompressionCodec compressionType = + CompressionCodec.getCompressionMapping(resultsResp.getResultSetMetadata()); + try { + byte[] serializedSchema = getSerializedSchema(resultsResp.getResultSetMetadata()); + if (serializedSchema != null) { + baos.write(serializedSchema); + } + writeArrowBatchesToStream(compressionType, resultsResp.getResults().getArrowBatches(), baos); + return new ByteArrayInputStream(baos.toByteArray()); + } catch (DatabricksSQLException | IOException e) { + handleError(e); + } + return null; + } + + private void writeArrowBatchesToStream( + CompressionCodec compressionCodec, + List arrowBatchList, + ByteArrayOutputStream baos) + throws DatabricksSQLException, IOException { + for (TSparkArrowBatch arrowBatch : arrowBatchList) { + byte[] decompressedBytes = + decompress( + arrowBatch.getBatch(), + compressionCodec, + String.format( + "Data fetch for lazy inline arrow batch [%d] and statement [%s] with decompression algorithm : [%s]", + arrowBatch.getRowCount(), statement, compressionCodec)); + baos.write(decompressedBytes); + } + } + + private long getTotalRowsInResponse(TFetchResultsResp resultsResp) { + long totalRows = 0; + if (resultsResp.getResults() != null && resultsResp.getResults().getArrowBatches() != null) { + for (TSparkArrowBatch arrowBatch : resultsResp.getResults().getArrowBatches()) { + totalRows += arrowBatch.getRowCount(); + } + } + return totalRows; + } + + private byte[] getSerializedSchema(TGetResultSetMetadataResp metadata) + throws DatabricksSQLException { + if (metadata.getArrowSchema() != null) { + return metadata.getArrowSchema(); + } + Schema arrowSchema = hiveSchemaToArrowSchema(metadata.getSchema()); + try { + return SchemaUtility.serialize(arrowSchema); + } catch (IOException e) { + handleError(e); + } + return null; + } + + private Schema hiveSchemaToArrowSchema(TTableSchema hiveSchema) + throws DatabricksParsingException { + List fields = new ArrayList<>(); + if (hiveSchema == null) { + return new Schema(fields); + } + try { + hiveSchema + .getColumns() + .forEach( + columnDesc -> { + try { + fields.add(getArrowField(columnDesc)); + } catch (SQLException e) { + throw new RuntimeException(e); + } + }); + } catch (RuntimeException e) { + handleError(e); + } + return new Schema(fields); + } + + private Field getArrowField(TColumnDesc columnDesc) throws SQLException { + TPrimitiveTypeEntry primitiveTypeEntry = getTPrimitiveTypeOrDefault(columnDesc.getTypeDesc()); + ArrowType arrowType = mapThriftToArrowType(primitiveTypeEntry.getType()); + FieldType fieldType = new FieldType(true, arrowType, null); + return new 
Field(columnDesc.getColumnName(), fieldType, null); + } + + private void setColumnInfo(TGetResultSetMetadataResp resultManifest) { + columnInfos = new ArrayList<>(); + if (resultManifest.getSchema() == null) { + return; + } + for (TColumnDesc tColumnDesc : resultManifest.getSchema().getColumns()) { + columnInfos.add( + com.databricks.jdbc.common.util.DatabricksThriftUtil.getColumnInfoFromTColumnDesc( + tColumnDesc)); + } + } + + @VisibleForTesting + void handleError(Exception e) throws DatabricksParsingException { + String errorMessage = + String.format("Cannot process lazy thrift inline arrow format. Error: %s", e.getMessage()); + LOGGER.error(errorMessage); + throw new DatabricksParsingException( + errorMessage, e, DatabricksDriverErrorCode.INLINE_CHUNK_PARSING_ERROR); + } + + /** + * Gets the total number of rows fetched from the server so far. + * + * @return the total number of rows fetched from the server + */ + public long getTotalRowsFetched() { + return totalRowsFetched; + } + + /** + * Checks if all data has been fetched from the server. + * + * @return true if all data has been fetched (either reached end or maxRows limit) + */ + public boolean isCompletelyFetched() { + return hasReachedEnd || !currentResponse.hasMoreRows; + } +} diff --git a/src/test/java/com/databricks/jdbc/api/impl/ExecutionResultFactoryTest.java b/src/test/java/com/databricks/jdbc/api/impl/ExecutionResultFactoryTest.java index 2efb1e33a..1e1461592 100644 --- a/src/test/java/com/databricks/jdbc/api/impl/ExecutionResultFactoryTest.java +++ b/src/test/java/com/databricks/jdbc/api/impl/ExecutionResultFactoryTest.java @@ -1,10 +1,10 @@ package com.databricks.jdbc.api.impl; -import static com.databricks.jdbc.TestConstants.ARROW_BATCH_LIST; import static org.junit.jupiter.api.Assertions.*; import static org.mockito.Mockito.when; import com.databricks.jdbc.api.impl.arrow.ArrowStreamResult; +import com.databricks.jdbc.api.impl.arrow.LazyThriftInlineArrowResult; import com.databricks.jdbc.api.impl.volume.VolumeOperationResult; import com.databricks.jdbc.api.internal.IDatabricksConnectionContext; import com.databricks.jdbc.api.internal.IDatabricksStatementInternal; @@ -128,14 +128,11 @@ public void testGetResultSet_thriftURL() throws SQLException { @Test public void testGetResultSet_thriftInlineArrow() throws SQLException { - when(connectionContext.getConnectionUuid()).thenReturn("sample-uuid"); when(resultSetMetadataResp.getResultFormat()).thenReturn(TSparkRowSetType.ARROW_BASED_SET); when(fetchResultsResp.getResultSetMetadata()).thenReturn(resultSetMetadataResp); when(fetchResultsResp.getResults()).thenReturn(tRowSet); - when(session.getConnectionContext()).thenReturn(connectionContext); - when(tRowSet.getArrowBatches()).thenReturn(ARROW_BATCH_LIST); IExecutionResult result = ExecutionResultFactory.getResultSet(fetchResultsResp, session, parentStatement); - assertInstanceOf(ArrowStreamResult.class, result); + assertInstanceOf(LazyThriftInlineArrowResult.class, result); } } diff --git a/src/test/java/com/databricks/jdbc/api/impl/arrow/ArrowStreamResultTest.java b/src/test/java/com/databricks/jdbc/api/impl/arrow/ArrowStreamResultTest.java index 5f42fbdf1..9f2eb213a 100644 --- a/src/test/java/com/databricks/jdbc/api/impl/arrow/ArrowStreamResultTest.java +++ b/src/test/java/com/databricks/jdbc/api/impl/arrow/ArrowStreamResultTest.java @@ -133,25 +133,6 @@ public void testIteration() throws Exception { assertFalse(result.next()); } - @Test - public void testInlineArrow() throws DatabricksSQLException { - 
IDatabricksConnectionContext connectionContext = - DatabricksConnectionContextFactory.create(JDBC_URL, new Properties()); - when(session.getConnectionContext()).thenReturn(connectionContext); - when(metadataResp.getSchema()).thenReturn(TEST_TABLE_SCHEMA); - when(fetchResultsResp.getResults()).thenReturn(resultData); - when(fetchResultsResp.getResultSetMetadata()).thenReturn(metadataResp); - ArrowStreamResult result = - new ArrowStreamResult(fetchResultsResp, true, parentStatement, session); - assertEquals(-1, result.getCurrentRow()); - assertTrue(result.hasNext()); - assertFalse(result.next()); - assertEquals(0, result.getCurrentRow()); - assertFalse(result.hasNext()); - assertDoesNotThrow(result::close); - assertFalse(result.hasNext()); - } - @Test public void testCloudFetchArrow() throws Exception { IDatabricksConnectionContext connectionContext = @@ -164,7 +145,7 @@ public void testCloudFetchArrow() throws Exception { when(fetchResultsResp.getResultSetMetadata()).thenReturn(metadataResp); when(parentStatement.getStatementId()).thenReturn(STATEMENT_ID); ArrowStreamResult result = - new ArrowStreamResult(fetchResultsResp, false, parentStatement, session, mockHttpClient); + new ArrowStreamResult(fetchResultsResp, parentStatement, session, mockHttpClient); assertEquals(-1, result.getCurrentRow()); assertTrue(result.hasNext()); assertDoesNotThrow(result::close); diff --git a/src/test/java/com/databricks/jdbc/api/impl/arrow/InlineChunkProviderTest.java b/src/test/java/com/databricks/jdbc/api/impl/arrow/InlineChunkProviderTest.java index 86be512d4..8392daf68 100644 --- a/src/test/java/com/databricks/jdbc/api/impl/arrow/InlineChunkProviderTest.java +++ b/src/test/java/com/databricks/jdbc/api/impl/arrow/InlineChunkProviderTest.java @@ -1,27 +1,17 @@ package com.databricks.jdbc.api.impl.arrow; -import static com.databricks.jdbc.TestConstants.ARROW_BATCH_LIST; -import static com.databricks.jdbc.TestConstants.TEST_TABLE_SCHEMA; import static org.junit.jupiter.api.Assertions.*; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; -import com.databricks.jdbc.api.internal.IDatabricksSession; -import com.databricks.jdbc.api.internal.IDatabricksStatementInternal; import com.databricks.jdbc.common.CompressionCodec; -import com.databricks.jdbc.exception.DatabricksParsingException; import com.databricks.jdbc.exception.DatabricksSQLException; -import com.databricks.jdbc.model.client.thrift.generated.TFetchResultsResp; -import com.databricks.jdbc.model.client.thrift.generated.TGetResultSetMetadataResp; -import com.databricks.jdbc.model.client.thrift.generated.TRowSet; -import com.databricks.jdbc.model.client.thrift.generated.TSparkArrowBatch; import com.databricks.jdbc.model.core.ColumnInfo; import com.databricks.jdbc.model.core.ColumnInfoTypeName; import com.databricks.jdbc.model.core.ResultData; import com.databricks.jdbc.model.core.ResultManifest; import java.io.ByteArrayOutputStream; import java.io.IOException; -import java.util.Collections; import net.jpountz.lz4.LZ4FrameOutputStream; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; @@ -37,41 +27,9 @@ public class InlineChunkProviderTest { private static final long TOTAL_ROWS = 2L; - @Mock TGetResultSetMetadataResp metadata; - @Mock TFetchResultsResp fetchResultsResp; - @Mock IDatabricksStatementInternal parentStatement; - @Mock IDatabricksSession session; @Mock private ResultData mockResultData; @Mock private ResultManifest mockResultManifest; - @Test - void testInitialisation() 
throws DatabricksParsingException { - when(fetchResultsResp.getResultSetMetadata()).thenReturn(metadata); - when(metadata.getArrowSchema()).thenReturn(null); - when(metadata.getSchema()).thenReturn(TEST_TABLE_SCHEMA); - when(fetchResultsResp.getResults()).thenReturn(new TRowSet().setArrowBatches(ARROW_BATCH_LIST)); - when(metadata.isSetLz4Compressed()).thenReturn(false); - InlineChunkProvider inlineChunkProvider = - new InlineChunkProvider(fetchResultsResp, parentStatement, session); - assertTrue(inlineChunkProvider.hasNextChunk()); - assertTrue(inlineChunkProvider.next()); - assertFalse(inlineChunkProvider.next()); - } - - @Test - void handleErrorTest() throws DatabricksParsingException { - TSparkArrowBatch arrowBatch = - new TSparkArrowBatch().setRowCount(0).setBatch(new byte[] {65, 66, 67}); - when(fetchResultsResp.getResultSetMetadata()).thenReturn(metadata); - when(fetchResultsResp.getResults()) - .thenReturn(new TRowSet().setArrowBatches(Collections.singletonList(arrowBatch))); - InlineChunkProvider inlineChunkProvider = - new InlineChunkProvider(fetchResultsResp, parentStatement, session); - assertThrows( - DatabricksParsingException.class, - () -> inlineChunkProvider.handleError(new RuntimeException())); - } - @Test void testConstructorSuccessfulCreation() throws DatabricksSQLException, IOException { // Create valid Arrow data with two rows and one column: [1, 2] diff --git a/src/test/java/com/databricks/jdbc/api/impl/arrow/LazyThriftInlineArrowResultTest.java b/src/test/java/com/databricks/jdbc/api/impl/arrow/LazyThriftInlineArrowResultTest.java new file mode 100644 index 000000000..9c43d3e78 --- /dev/null +++ b/src/test/java/com/databricks/jdbc/api/impl/arrow/LazyThriftInlineArrowResultTest.java @@ -0,0 +1,285 @@ +package com.databricks.jdbc.api.impl.arrow; + +import static com.databricks.jdbc.TestConstants.TEST_TABLE_SCHEMA; +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.Mockito.*; + +import com.databricks.jdbc.api.internal.IDatabricksSession; +import com.databricks.jdbc.api.internal.IDatabricksStatementInternal; +import com.databricks.jdbc.dbclient.impl.common.StatementId; +import com.databricks.jdbc.exception.DatabricksParsingException; +import com.databricks.jdbc.exception.DatabricksSQLException; +import com.databricks.jdbc.model.client.thrift.generated.*; +import com.databricks.jdbc.model.telemetry.enums.DatabricksDriverErrorCode; +import java.io.IOException; +import java.util.Collections; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; + +@ExtendWith(MockitoExtension.class) +public class LazyThriftInlineArrowResultTest { + + @Mock private IDatabricksSession session; + @Mock private IDatabricksStatementInternal statement; + private static final StatementId STATEMENT_ID = new StatementId("test_statement_id"); + private static final byte[] DUMMY_ARROW_BYTES = new byte[] {65, 66, 67}; + + private TFetchResultsResp createFetchResultsResp(int rowCount, boolean hasMoreRows) { + TSparkArrowBatch arrowBatch = + new TSparkArrowBatch().setRowCount(rowCount).setBatch(DUMMY_ARROW_BYTES); + TRowSet rowSet = new TRowSet().setArrowBatches(Collections.singletonList(arrowBatch)); + + TGetResultSetMetadataResp metadata = + new TGetResultSetMetadataResp().setSchema(TEST_TABLE_SCHEMA); + + TFetchResultsResp response = + new TFetchResultsResp().setResultSetMetadata(metadata).setResults(rowSet); + response.hasMoreRows = hasMoreRows; + + return 
response; + } + + @Test + void testConstructorInitializesCorrectly() throws DatabricksSQLException { + TFetchResultsResp initialResponse = createFetchResultsResp(0, false); + + when(statement.getMaxRows()).thenReturn(0); + when(statement.getStatementId()).thenReturn(STATEMENT_ID); + + LazyThriftInlineArrowResult result = + new LazyThriftInlineArrowResult(initialResponse, statement, session); + + assertEquals(-1, result.getCurrentRow()); + assertEquals(0, result.getRowCount()); + assertEquals(0, result.getTotalRowsFetched()); + assertFalse(result.hasNext()); + assertTrue(result.isCompletelyFetched()); + } + + @Test + void testGetObjectThrowsWhenClosed() throws DatabricksSQLException { + TFetchResultsResp initialResponse = createFetchResultsResp(0, false); + + when(statement.getMaxRows()).thenReturn(0); + when(statement.getStatementId()).thenReturn(STATEMENT_ID); + + LazyThriftInlineArrowResult result = + new LazyThriftInlineArrowResult(initialResponse, statement, session); + result.close(); + + DatabricksSQLException exception = + assertThrows(DatabricksSQLException.class, () -> result.getObject(0)); + assertEquals("Result is already closed", exception.getMessage()); + assertEquals(DatabricksDriverErrorCode.STATEMENT_CLOSED.name(), exception.getSQLState()); + } + + @Test + void testGetObjectThrowsWhenBeforeFirstRow() throws DatabricksSQLException { + TFetchResultsResp initialResponse = createFetchResultsResp(0, false); + + when(statement.getMaxRows()).thenReturn(0); + when(statement.getStatementId()).thenReturn(STATEMENT_ID); + + LazyThriftInlineArrowResult result = + new LazyThriftInlineArrowResult(initialResponse, statement, session); + + DatabricksSQLException exception = + assertThrows(DatabricksSQLException.class, () -> result.getObject(0)); + assertEquals("Cursor is before first row", exception.getMessage()); + assertEquals(DatabricksDriverErrorCode.INVALID_STATE.name(), exception.getSQLState()); + } + + @Test + void testCloseReleasesResources() throws DatabricksSQLException { + TFetchResultsResp initialResponse = createFetchResultsResp(0, false); + + when(statement.getMaxRows()).thenReturn(0); + when(statement.getStatementId()).thenReturn(STATEMENT_ID); + + LazyThriftInlineArrowResult result = + new LazyThriftInlineArrowResult(initialResponse, statement, session); + + result.close(); + + assertFalse(result.hasNext()); + assertFalse(result.next()); + } + + @Test + void testIsCompletelyFetchedWhenNoMoreRows() throws DatabricksSQLException { + TFetchResultsResp initialResponse = createFetchResultsResp(0, false); + + when(statement.getMaxRows()).thenReturn(0); + when(statement.getStatementId()).thenReturn(STATEMENT_ID); + + LazyThriftInlineArrowResult result = + new LazyThriftInlineArrowResult(initialResponse, statement, session); + + assertTrue(result.isCompletelyFetched()); + } + + @Test + void testIsCompletelyFetchedWithMoreRows() throws DatabricksSQLException { + TFetchResultsResp initialResponse = createFetchResultsResp(0, true); + + when(statement.getMaxRows()).thenReturn(0); + when(statement.getStatementId()).thenReturn(STATEMENT_ID); + + LazyThriftInlineArrowResult result = + new LazyThriftInlineArrowResult(initialResponse, statement, session); + + assertFalse(result.isCompletelyFetched()); + } + + @Test + void testGetChunkCount() throws DatabricksSQLException { + TFetchResultsResp initialResponse = createFetchResultsResp(0, false); + + when(statement.getMaxRows()).thenReturn(0); + when(statement.getStatementId()).thenReturn(STATEMENT_ID); + + LazyThriftInlineArrowResult result = 
+ new LazyThriftInlineArrowResult(initialResponse, statement, session); + + assertEquals(0, result.getChunkCount()); + } + + @Test + void testHandleErrorThrowsParsingException() throws DatabricksSQLException { + TFetchResultsResp initialResponse = createFetchResultsResp(0, false); + + when(statement.getMaxRows()).thenReturn(0); + when(statement.getStatementId()).thenReturn(STATEMENT_ID); + + LazyThriftInlineArrowResult result = + new LazyThriftInlineArrowResult(initialResponse, statement, session); + + Exception testException = new IOException("Test error"); + DatabricksParsingException exception = + assertThrows(DatabricksParsingException.class, () -> result.handleError(testException)); + assertTrue(exception.getMessage().contains("Cannot process lazy thrift inline arrow format")); + assertEquals( + DatabricksDriverErrorCode.INLINE_CHUNK_PARSING_ERROR.name(), exception.getSQLState()); + } + + @Test + void testEmptyResultSet() throws DatabricksSQLException { + TFetchResultsResp initialResponse = createFetchResultsResp(0, false); + + when(statement.getMaxRows()).thenReturn(0); + when(statement.getStatementId()).thenReturn(STATEMENT_ID); + + LazyThriftInlineArrowResult result = + new LazyThriftInlineArrowResult(initialResponse, statement, session); + + assertEquals(-1, result.getCurrentRow()); + assertFalse(result.hasNext()); + assertFalse(result.next()); + assertEquals(0, result.getRowCount()); + assertTrue(result.isCompletelyFetched()); + } + + @Test + void testNullStatement() throws DatabricksSQLException { + TFetchResultsResp initialResponse = createFetchResultsResp(0, false); + + LazyThriftInlineArrowResult result = + new LazyThriftInlineArrowResult(initialResponse, null, session); + + assertEquals(-1, result.getCurrentRow()); + assertEquals(0, result.getRowCount()); + } + + @Test + void testGetCurrentRowBeforeNext() throws DatabricksSQLException { + TFetchResultsResp initialResponse = createFetchResultsResp(0, false); + + when(statement.getMaxRows()).thenReturn(0); + when(statement.getStatementId()).thenReturn(STATEMENT_ID); + + LazyThriftInlineArrowResult result = + new LazyThriftInlineArrowResult(initialResponse, statement, session); + + assertEquals(-1, result.getCurrentRow()); + } + + @Test + void testGetTotalRowsFetched() throws DatabricksSQLException { + TFetchResultsResp initialResponse = createFetchResultsResp(0, false); + + when(statement.getMaxRows()).thenReturn(0); + when(statement.getStatementId()).thenReturn(STATEMENT_ID); + + LazyThriftInlineArrowResult result = + new LazyThriftInlineArrowResult(initialResponse, statement, session); + + assertEquals(0, result.getTotalRowsFetched()); + } + + @Test + void testNextReturnsFalseOnEmptyResultSet() throws DatabricksSQLException { + TFetchResultsResp initialResponse = createFetchResultsResp(0, false); + + when(statement.getMaxRows()).thenReturn(0); + when(statement.getStatementId()).thenReturn(STATEMENT_ID); + + LazyThriftInlineArrowResult result = + new LazyThriftInlineArrowResult(initialResponse, statement, session); + + assertFalse(result.next()); + } + + @Test + void testHasNextReturnsFalseOnEmptyResultSet() throws DatabricksSQLException { + TFetchResultsResp initialResponse = createFetchResultsResp(0, false); + + when(statement.getMaxRows()).thenReturn(0); + when(statement.getStatementId()).thenReturn(STATEMENT_ID); + + LazyThriftInlineArrowResult result = + new LazyThriftInlineArrowResult(initialResponse, statement, session); + + assertFalse(result.hasNext()); + } + + @Test + void testNextReturnsFalseAfterClose() throws 
DatabricksSQLException { + TFetchResultsResp initialResponse = createFetchResultsResp(0, false); + + when(statement.getMaxRows()).thenReturn(0); + when(statement.getStatementId()).thenReturn(STATEMENT_ID); + + LazyThriftInlineArrowResult result = + new LazyThriftInlineArrowResult(initialResponse, statement, session); + result.close(); + + assertFalse(result.next()); + } + + @Test + void testHasNextReturnsFalseAfterClose() throws DatabricksSQLException { + TFetchResultsResp initialResponse = createFetchResultsResp(0, false); + + when(statement.getMaxRows()).thenReturn(0); + when(statement.getStatementId()).thenReturn(STATEMENT_ID); + + LazyThriftInlineArrowResult result = + new LazyThriftInlineArrowResult(initialResponse, statement, session); + result.close(); + + assertFalse(result.hasNext()); + } + + @Test + void testConstructorWithNullStatementUsesDefaultMaxRows() throws DatabricksSQLException { + TFetchResultsResp initialResponse = createFetchResultsResp(0, false); + + LazyThriftInlineArrowResult result = + new LazyThriftInlineArrowResult(initialResponse, null, session); + + assertNotNull(result); + assertEquals(-1, result.getCurrentRow()); + } +}
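
For reference (not part of the patch), below is a minimal caller-side sketch of the lazy inline-Arrow path introduced above. The sketch class, its method name, and the way fetchResultsResp, session, and parentStatement are obtained are assumptions made purely for illustration; only ExecutionResultFactory.getResultSet and the IExecutionResult methods it exercises come from the diff itself.

package com.databricks.jdbc.api.impl;

import com.databricks.jdbc.api.internal.IDatabricksSession;
import com.databricks.jdbc.api.internal.IDatabricksStatementInternal;
import com.databricks.jdbc.model.client.thrift.generated.TFetchResultsResp;
import java.sql.SQLException;

/** Illustrative sketch only; not shipped with this change. */
class LazyInlineArrowUsageSketch {

  /**
   * Drains an ARROW_BASED_SET response through the new lazy path. The fetchResultsResp,
   * session, and parentStatement arguments are assumed to come from an earlier Thrift
   * execute/fetch call; they are not constructed here.
   */
  static Object lastValueOfFirstColumn(
      TFetchResultsResp fetchResultsResp,
      IDatabricksSession session,
      IDatabricksStatementInternal parentStatement)
      throws SQLException {
    // For ARROW_BASED_SET the factory now returns a LazyThriftInlineArrowResult, which
    // decodes only the inline Arrow batches of the current response and asks the server
    // for more (getMoreResults) only when next() runs past the rows already decoded.
    IExecutionResult result =
        ExecutionResultFactory.getResultSet(fetchResultsResp, session, parentStatement);
    Object last = null;
    try {
      while (result.next()) { // may lazily fetch and decode the next batch
        last = result.getObject(0); // zero-based column index
      }
    } finally {
      result.close(); // releases the currently buffered Arrow chunk
    }
    return last;
  }
}

The sketch makes the intent of the change visible: rows become available as soon as the first inline batch is decoded, whereas the removed InlineChunkProvider constructor looped over hasMoreRows and buffered every batch into a single byte stream before the first row could be read.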