Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,19 @@ object SparkUtils extends Logging {
val fields = new mutable.ListBuffer[Column]()
val stringFields = new mutable.ListBuffer[String]()
val usedNames = new mutable.HashSet[String]()
val isUseArrayGet = df.sparkSession.version.split('.').head.toInt >= 4

def getArrayIndexExpr(path: String, index: Int): String = {
// Since Spark 4 the behavior of index operator has changed and requires it to be in-bounds and throws exception
// otherwise: [INVALID_ARRAY_INDEX] The index 0 is out of bounds.
// So 'get()' Spark SQL function is used instead which is introduced in Spark 4.
// Older Spark versions need to use the indexing via square brackets.
if (isUseArrayGet) {
s"get($path, $index)"
} else {
s"$path[$index]"
}
}

def getNewFieldName(desiredName: String): String = {
var name = desiredName
Expand All @@ -102,21 +115,22 @@ object SparkUtils extends Logging {
*/
def flattenStructArray(path: String, fieldNamePrefix: String, structField: StructField, arrayType: ArrayType): Unit = {
val maxInd = getMaxArraySize(s"$path${structField.name}")
val fieldName = s"$path`${structField.name}`"
var i = 0
while (i < maxInd) {
arrayType.elementType match {
case st: StructType =>
val newFieldNamePrefix = s"${fieldNamePrefix}${i}_"
flattenGroup(s"$path`${structField.name}`[$i].", newFieldNamePrefix, st)
flattenGroup(s"${getArrayIndexExpr(fieldName, i)}.", newFieldNamePrefix, st)
case ar: ArrayType =>
val newFieldNamePrefix = s"${fieldNamePrefix}${i}_"
flattenArray(s"$path`${structField.name}`[$i].", newFieldNamePrefix, structField, ar)
flattenArray(s"${getArrayIndexExpr(fieldName, i)}.", newFieldNamePrefix, structField, ar)
// AtomicType is protected on package 'sql' level so have to enumerate all subtypes :(
case _ =>
val newFieldNamePrefix = s"${fieldNamePrefix}${i}"
val newFieldName = getNewFieldName(s"$newFieldNamePrefix")
fields += expr(s"$path`${structField.name}`[$i]").as(newFieldName, structField.metadata)
stringFields += s"""expr("$path`${structField.name}`[$i] AS `$newFieldName`")"""
fields += expr(s"${getArrayIndexExpr(fieldName, i)}").as(newFieldName, structField.metadata)
stringFields += s"""expr("${getArrayIndexExpr(fieldName, i)} AS `$newFieldName`")"""
Comment thread
coderabbitai[bot] marked this conversation as resolved.
}
i += 1
}
Expand All @@ -128,17 +142,17 @@ object SparkUtils extends Logging {
while (i < maxInd) {
arrayType.elementType match {
case st: StructType =>
val newFieldNamePrefix = s"${fieldNamePrefix}${i}_"
flattenGroup(s"$path[$i]", newFieldNamePrefix, st)
val newFieldNamePrefix = s"$fieldNamePrefix${i}_"
flattenGroup(s"${getArrayIndexExpr(path, i)}", newFieldNamePrefix, st)
case ar: ArrayType =>
val newFieldNamePrefix = s"${fieldNamePrefix}${i}_"
flattenNestedArrays(s"$path[$i]", newFieldNamePrefix, ar, metadata)
val newFieldNamePrefix = s"$fieldNamePrefix${i}_"
flattenNestedArrays(s"${getArrayIndexExpr(path, i)}", newFieldNamePrefix, ar, metadata)
// AtomicType is protected on package 'sql' level so have to enumerate all subtypes :(
case _ =>
val newFieldNamePrefix = s"${fieldNamePrefix}${i}"
val newFieldNamePrefix = s"$fieldNamePrefix${i}"
val newFieldName = getNewFieldName(s"$newFieldNamePrefix")
fields += expr(s"$path[$i]").as(newFieldName, metadata)
stringFields += s"""expr("$path`[$i] AS `$newFieldName`")"""
fields += expr(s"${getArrayIndexExpr(path, i)}").as(newFieldName, metadata)
stringFields += s"""expr("${getArrayIndexExpr(path, i)} AS `$newFieldName`")"""
}
i += 1
}
Expand Down Expand Up @@ -183,7 +197,7 @@ object SparkUtils extends Logging {
case _ =>
val newFieldName = getNewFieldName(s"$fieldNamePrefix${field.name}")
fields += expr(s"$path`${field.name}`").as(newFieldName, field.metadata)
if (path.contains('['))
if (path.contains('[') || path.contains('('))
stringFields += s"""expr("$path`${field.name}` AS `$newFieldName`")"""
else
stringFields += s"""col("$path`${field.name}`").as("$newFieldName")"""
Expand Down
Loading