
Commit ad30121

[Improvement] Improve GraphAr Spark writer performance and implement a custom writer builder to bypass Spark's write behavior (#92)
1 parent d26d3b8 · commit ad30121

16 files changed · +906 −205 lines

spark/src/main/java/com/alibaba/graphar/GeneralParams.java

Lines changed: 2 additions & 0 deletions
@@ -26,5 +26,7 @@ public class GeneralParams {
   public static final String vertexChunkIndexCol = "_graphArVertexChunkIndex";
   public static final String edgeIndexCol = "_graphArEdgeIndex";
   public static final String regularSeperator = "_";
+  public static final String offsetStartChunkIndexKey = "_graphar_offset_start_chunk_index";
+  public static final String aggNumListOfEdgeChunkKey = "_graphar_agg_num_list_of_edge_chunk";
 }
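
The two new keys carry chunk-placement information from the writer into the commit protocol's options map (see GarCommitProtocol below). The following is a hypothetical sketch of how such options might be supplied from a Spark job; the data, path, and option values are made up, and the GraphAr writer's actual call sites live in other files of this commit that are not shown here.

import org.apache.spark.sql.SparkSession
import com.alibaba.graphar.GeneralParams

object WriteOptionsSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("write-options-sketch").getOrCreate()
    import spark.implicits._

    // Made-up edge data; in GraphAr the DataFrame would already be partitioned into chunks.
    val edges = Seq((0L, 1L), (0L, 2L), (1L, 2L)).toDF("src", "dst")

    edges.write
      .format("gar")
      // Illustrative JSON list for aggNumListOfEdgeChunkKey; offsetStartChunkIndexKey
      // would be set the same way for offset chunks. Any further options the GraphAr
      // writer requires are omitted here.
      .option(GeneralParams.aggNumListOfEdgeChunkKey, "[0, 2, 3]")
      .save("/tmp/graphar/person_knows_person") // hypothetical output path

    spark.stop()
  }
}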

spark/src/main/scala/com/alibaba/graphar/datasources/GarCommitProtocol.scala

Lines changed: 74 additions & 0 deletions
@@ -0,0 +1,74 @@
+/** Copyright 2022 Alibaba Group Holding Limited.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.alibaba.graphar.datasources
+
+import com.alibaba.graphar.GeneralParams
+
+import org.json4s._
+import org.json4s.jackson.JsonMethods._
+
+import org.apache.spark.internal.io.FileCommitProtocol
+import org.apache.spark.sql.execution.datasources.SQLHadoopMapReduceCommitProtocol
+import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl
+import org.apache.hadoop.mapreduce._
+import org.apache.spark.internal.Logging
+
+object GarCommitProtocol {
+  private def binarySearchPair(aggNums: Array[Int], key: Int): (Int, Int) = {
+    var low = 0
+    var high = aggNums.length - 1
+    var mid = 0
+    while (low <= high) {
+      mid = (high + low) / 2
+      if (aggNums(mid) <= key && (mid == aggNums.length - 1 || aggNums(mid + 1) > key)) {
+        return (mid, key - aggNums(mid))
+      } else if (aggNums(mid) > key) {
+        high = mid - 1
+      } else {
+        low = mid + 1
+      }
+    }
+    return (low, key - aggNums(low))
+  }
+}
+
+class GarCommitProtocol(jobId: String,
+                        path: String,
+                        options: Map[String, String],
+                        dynamicPartitionOverwrite: Boolean = false)
+  extends SQLHadoopMapReduceCommitProtocol(jobId, path, dynamicPartitionOverwrite) with Serializable with Logging {
+
+  override def getFilename(taskContext: TaskAttemptContext, ext: String): String = {
+    val partitionId = taskContext.getTaskAttemptID.getTaskID.getId
+    if (options.contains(GeneralParams.offsetStartChunkIndexKey)) {
+      // offset chunk file name, looks like chunk0
+      val chunk_index = options.get(GeneralParams.offsetStartChunkIndexKey).get.toInt + partitionId
+      return f"chunk$chunk_index"
+    }
+    if (options.contains(GeneralParams.aggNumListOfEdgeChunkKey)) {
+      // edge chunk file name, looks like part0/chunk0
+      val jValue = parse(options.get(GeneralParams.aggNumListOfEdgeChunkKey).get)
+      implicit val formats = DefaultFormats // initialize a default formats for json4s
+      val aggNums: Array[Int] = Extraction.extract[Array[Int]](jValue)
+      val chunkPair: (Int, Int) = GarCommitProtocol.binarySearchPair(aggNums, partitionId)
+      val vertex_chunk_index: Int = chunkPair._1
+      val edge_chunk_index: Int = chunkPair._2
+      return f"part$vertex_chunk_index/chunk$edge_chunk_index"
+    }
+    // vertex chunk file name, looks like chunk0
+    return f"chunk$partitionId"
+  }
+}
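
To make the edge-chunk naming concrete, here is a standalone re-implementation of the same lookup, for illustration only (it is not part of the commit; binarySearchPair itself is private). It assumes aggNums(i) holds the total number of edge chunks before vertex chunk i, which is how getFilename uses the list: a Spark partition id then maps to the file name part{i}/chunk{j}.

object ChunkNameSketch {
  // Same search as GarCommitProtocol.binarySearchPair: find the last index i with
  // aggNums(i) <= partitionId; the remainder is the chunk index inside part i.
  def locate(aggNums: Array[Int], partitionId: Int): String = {
    var low = 0
    var high = aggNums.length - 1
    while (low <= high) {
      val mid = (low + high) / 2
      if (aggNums(mid) <= partitionId && (mid == aggNums.length - 1 || aggNums(mid + 1) > partitionId)) {
        return s"part$mid/chunk${partitionId - aggNums(mid)}"
      } else if (aggNums(mid) > partitionId) {
        high = mid - 1
      } else {
        low = mid + 1
      }
    }
    s"part$low/chunk${partitionId - aggNums(low)}"
  }

  def main(args: Array[String]): Unit = {
    // Made-up counts: vertex chunk 0 owns edge chunks 0-2, chunk 1 owns 3-4, chunk 2 owns the rest.
    val aggNums = Array(0, 3, 5)
    println(locate(aggNums, 0)) // part0/chunk0
    println(locate(aggNums, 4)) // part1/chunk1
    println(locate(aggNums, 6)) // part2/chunk1
  }
}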

spark/src/main/scala/com/alibaba/graphar/datasources/GarDataSource.scala

Lines changed: 80 additions & 13 deletions
@@ -1,10 +1,11 @@
-/** Copyright 2022 Alibaba Group Holding Limited.
+/* Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
  *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
+ *    http://www.apache.org/licenses/LICENSE-2.0
  *
  * Unless required by applicable law or agreed to in writing, software
  * distributed under the License is distributed on an "AS IS" BASIS,
@@ -15,38 +16,104 @@
 
 package com.alibaba.graphar.datasources
 
-import org.apache.spark.sql.connector.catalog.Table
+import scala.collection.JavaConverters._
+import java.util
+
+import com.fasterxml.jackson.databind.ObjectMapper
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.sql.connector.catalog.{Table, TableProvider}
 import org.apache.spark.sql.execution.datasources._
+import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat
 import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat
 import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
 import org.apache.spark.sql.execution.datasources.v2._
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
+import org.apache.spark.sql.sources.DataSourceRegister
+import org.apache.spark.sql.connector.expressions.Transform
 
-/** GarDataSource is a class to provide gar files as the data source for spark. */
-class GarDataSource extends FileDataSourceV2 {
+import com.alibaba.graphar.utils.Utils
 
+object GarUtils
+/** GarDataSource is a class to provide gar files as the data source for spark. */
+class GarDataSource extends TableProvider with DataSourceRegister {
   /** The default fallback file format is Parquet. */
-  override def fallbackFileFormat: Class[_ <: FileFormat] = classOf[ParquetFileFormat]
+  def fallbackFileFormat: Class[_ <: FileFormat] = classOf[ParquetFileFormat]
+
+  lazy val sparkSession = SparkSession.active
 
   /** The string that represents the format name. */
   override def shortName(): String = "gar"
 
+  protected def getPaths(map: CaseInsensitiveStringMap): Seq[String] = {
+    val objectMapper = new ObjectMapper()
+    val paths = Option(map.get("paths")).map { pathStr =>
+      objectMapper.readValue(pathStr, classOf[Array[String]]).toSeq
+    }.getOrElse(Seq.empty)
+    paths ++ Option(map.get("path")).toSeq
+  }
+
+  protected def getOptionsWithoutPaths(map: CaseInsensitiveStringMap): CaseInsensitiveStringMap = {
+    val withoutPath = map.asCaseSensitiveMap().asScala.filterKeys { k =>
+      !k.equalsIgnoreCase("path") && !k.equalsIgnoreCase("paths")
+    }
+    new CaseInsensitiveStringMap(withoutPath.toMap.asJava)
+  }
+
+  protected def getTableName(map: CaseInsensitiveStringMap, paths: Seq[String]): String = {
+    val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(
+      map.asCaseSensitiveMap().asScala.toMap)
+    val name = shortName() + " " + paths.map(qualifiedPathName(_, hadoopConf)).mkString(",")
+    Utils.redact(sparkSession.sessionState.conf.stringRedactionPattern, name)
+  }
+
+  private def qualifiedPathName(path: String, hadoopConf: Configuration): String = {
+    val hdfsPath = new Path(path)
+    val fs = hdfsPath.getFileSystem(hadoopConf)
+    hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory).toString
+  }
+
   /** Provide a table from the data source. */
-  override def getTable(options: CaseInsensitiveStringMap): Table = {
+  def getTable(options: CaseInsensitiveStringMap): Table = {
     val paths = getPaths(options)
     val tableName = getTableName(options, paths)
    val optionsWithoutPaths = getOptionsWithoutPaths(options)
     GarTable(tableName, sparkSession, optionsWithoutPaths, paths, None, getFallbackFileFormat(options))
   }
 
   /** Provide a table from the data source with specific schema. */
-  override def getTable(options: CaseInsensitiveStringMap, schema: StructType): Table = {
+  def getTable(options: CaseInsensitiveStringMap, schema: StructType): Table = {
     val paths = getPaths(options)
     val tableName = getTableName(options, paths)
     val optionsWithoutPaths = getOptionsWithoutPaths(options)
-    GarTable(tableName, sparkSession, optionsWithoutPaths, paths, Some(schema), getFallbackFileFormat(options))
+    GarTable(tableName, sparkSession, optionsWithoutPaths, paths, Some(schema), getFallbackFileFormat(options))
+  }
+
+  override def supportsExternalMetadata(): Boolean = true
+
+  private var t: Table = null
+
+  override def inferSchema(options: CaseInsensitiveStringMap): StructType = {
+    if (t == null) t = getTable(options)
+    t.schema()
+  }
+
+  override def inferPartitioning(options: CaseInsensitiveStringMap): Array[Transform] = {
+    Array.empty
+  }
+
+  override def getTable(schema: StructType,
+                        partitioning: Array[Transform],
+                        properties: util.Map[String, String]): Table = {
+    // If the table is already loaded during schema inference, return it directly.
+    if (t != null) {
+      t
+    } else {
+      getTable(new CaseInsensitiveStringMap(properties), schema)
+    }
   }
 
   // Get the actual fall back file format.
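
With GarDataSource now implementing TableProvider and DataSourceRegister directly instead of extending FileDataSourceV2, it remains reachable through the regular DataFrame reader. A minimal usage sketch follows; the path is made up, it assumes the data source is registered on the classpath (otherwise the full class name com.alibaba.graphar.datasources.GarDataSource can be used), and any format-selection options GarTable expects are not visible in this excerpt and are omitted.

import org.apache.spark.sql.SparkSession

object GarReadSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("gar-read-sketch").getOrCreate()

    // "gar" is the DataSourceRegister short name declared above.
    val df = spark.read
      .format("gar")
      .load("/tmp/graphar/vertex/person") // hypothetical chunk directory

    df.printSchema()
    spark.stop()
  }
}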

spark/src/main/scala/com/alibaba/graphar/datasources/GarTable.scala

Lines changed: 19 additions & 24 deletions
@@ -20,27 +20,28 @@ import scala.collection.JavaConverters._
 import org.apache.hadoop.fs.FileStatus
 
 import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.connector.write.{LogicalWriteInfo, Write, WriteBuilder}
+import org.apache.spark.sql.connector.write.{LogicalWriteInfo, WriteBuilder}
 import org.apache.spark.sql.catalyst.csv.CSVOptions
 import org.apache.spark.sql.execution.datasources.FileFormat
 import org.apache.spark.sql.execution.datasources.csv.CSVDataSource
 import org.apache.spark.sql.execution.datasources.orc.OrcUtils
 import org.apache.spark.sql.execution.datasources.parquet.ParquetUtils
 import org.apache.spark.sql.execution.datasources.v2.FileTable
-import org.apache.spark.sql.execution.datasources.v2.csv.CSVWrite
-import org.apache.spark.sql.execution.datasources.v2.orc.OrcWrite
-import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetWrite
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
 
+import com.alibaba.graphar.datasources.csv.CSVWriteBuilder
+import com.alibaba.graphar.datasources.parquet.ParquetWriteBuilder
+import com.alibaba.graphar.datasources.orc.OrcWriteBuilder
+
+
 /** GarTable is a class to represent the graph data in GraphAr as a table. */
-case class GarTable(
-    name: String,
-    sparkSession: SparkSession,
-    options: CaseInsensitiveStringMap,
-    paths: Seq[String],
-    userSpecifiedSchema: Option[StructType],
-    fallbackFileFormat: Class[_ <: FileFormat])
+case class GarTable(name: String,
+                    sparkSession: SparkSession,
+                    options: CaseInsensitiveStringMap,
+                    paths: Seq[String],
+                    userSpecifiedSchema: Option[StructType],
+                    fallbackFileFormat: Class[_ <: FileFormat])
   extends FileTable(sparkSession, options, paths, userSpecifiedSchema) {
 
   /** Construct a new scan builder. */
@@ -51,28 +52,22 @@ case class GarTable(
   override def inferSchema(files: Seq[FileStatus]): Option[StructType] = formatName match {
     case "csv" => {
       val parsedOptions = new CSVOptions(
-        options.asScala.toMap,
-        columnPruning = sparkSession.sessionState.conf.csvColumnPruning,
-        sparkSession.sessionState.conf.sessionLocalTimeZone)
+          options.asScala.toMap,
+          columnPruning = sparkSession.sessionState.conf.csvColumnPruning,
+          sparkSession.sessionState.conf.sessionLocalTimeZone)
 
       CSVDataSource(parsedOptions).inferSchema(sparkSession, files, parsedOptions)
     }
     case "orc" => OrcUtils.inferSchema(sparkSession, files, options.asScala.toMap)
     case "parquet" => ParquetUtils.inferSchema(sparkSession, options.asScala.toMap, files)
     case _ => throw new IllegalArgumentException
   }
-
+
   /** Construct a new write builder according to the actual file format. */
   override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = formatName match {
-    case "csv" => new WriteBuilder {
-      override def build(): Write = CSVWrite(paths, formatName, supportsDataType, info)
-    }
-    case "orc" => new WriteBuilder {
-      override def build(): Write = OrcWrite(paths, formatName, supportsDataType, info)
-    }
-    case "parquet" => new WriteBuilder {
-      override def build(): Write = ParquetWrite(paths, formatName, supportsDataType, info)
-    }
+    case "csv" => new CSVWriteBuilder(paths, formatName, supportsDataType, info)
+    case "orc" => new OrcWriteBuilder(paths, formatName, supportsDataType, info)
+    case "parquet" => new ParquetWriteBuilder(paths, formatName, supportsDataType, info)
     case _ => throw new IllegalArgumentException
   }
 
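
The anonymous WriteBuilder blocks that delegated to Spark's built-in CSVWrite/OrcWrite/ParquetWrite are replaced by dedicated builders. Their sources (CSVWriteBuilder, OrcWriteBuilder, ParquetWriteBuilder) are in files of this commit not shown in this excerpt; the sketch below only illustrates the WriteBuilder contract implied by the call sites above, assuming Spark 3.x DataSource V2 interfaces, and is not the actual implementation.

import org.apache.spark.sql.connector.write.{LogicalWriteInfo, Write, WriteBuilder}
import org.apache.spark.sql.types.DataType

// Hypothetical skeleton matching the constructor arguments used in newWriteBuilder.
class ExampleGarWriteBuilder(paths: Seq[String],
                             formatName: String,
                             supportsDataType: DataType => Boolean,
                             info: LogicalWriteInfo) extends WriteBuilder {
  // A real builder would assemble the file write here, presumably wiring in
  // GarCommitProtocol so output files follow GraphAr's chunk naming rather than
  // Spark's default part-* names.
  override def build(): Write =
    throw new UnsupportedOperationException("illustrative skeleton only")
}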
