diff --git a/.github/workflows/scala.yml b/.github/workflows/scala.yml
new file mode 100644
index 00000000..dc412fd9
--- /dev/null
+++ b/.github/workflows/scala.yml
@@ -0,0 +1,102 @@
+# CI workflow: gate on scalafmt, then compile, package and upload the JAR.
+name: Build Spark sql perf
+
+on:
+  push:
+    branches:
+      - master
+  pull_request:
+    branches:
+      - master
+
+jobs:
+  # Fails fast on unformatted sources before the heavier build job runs.
+  scalafmt-check:
+    name: Scalafmt Check
+    runs-on: ubuntu-22.04
+
+    steps:
+      - name: Checkout code
+        uses: actions/checkout@v4
+
+      - name: Set up JDK 11
+        uses: actions/setup-java@v4
+        with:
+          java-version: '11'
+          # 'adopt' is a legacy distribution that no longer receives updates;
+          # Temurin is its successor.
+          distribution: 'temurin'
+
+      - name: Cache sbt
+        uses: actions/cache@v4
+        with:
+          # sbt 1.x resolves through Coursier, which caches under
+          # ~/.cache/coursier on Linux.
+          path: |
+            ~/.ivy2/cache
+            ~/.sbt
+            ~/.coursier
+            ~/.cache/coursier
+          # Every *.sbt file and the pinned sbt version influence resolution,
+          # so all of them feed the cache key.
+          key: ${{ runner.os }}-sbt-${{ hashFiles('**/*.sbt', 'project/build.properties') }}
+          restore-keys: |
+            ${{ runner.os }}-sbt-
+
+      - name: Check formatting with scalafmt
+        run: sbt scalafmtCheck
+
+  # Runs only after the formatting gate passes (see `needs`).
+  build:
+    name: Build & Package
+    runs-on: ubuntu-22.04
+    needs: scalafmt-check
+
+    steps:
+      - name: Checkout code
+        uses: actions/checkout@v4
+
+      - name: Set up JDK 11
+        uses: actions/setup-java@v4
+        with:
+          java-version: '11'
+          distribution: 'temurin'
+
+      - name: Cache sbt
+        uses: actions/cache@v4
+        with:
+          path: |
+            ~/.ivy2/cache
+            ~/.sbt
+            ~/.coursier
+            ~/.cache/coursier
+          key: ${{ runner.os }}-sbt-${{ hashFiles('**/*.sbt', 'project/build.properties') }}
+          restore-keys: |
+            ${{ runner.os }}-sbt-
+
+      - name: Compile
+        run: sbt compile
+
+      - name: Package
+        run: sbt package
+
+      - name: Extract version
+        id: extract_version
+        run: |
+          # Accept both the legacy `version in ThisBuild :=` and the sbt 1.x
+          # slash-syntax `ThisBuild / version :=` spellings.
+          version=$(grep -E 'version in ThisBuild :=|ThisBuild / version :=' version.sbt | awk -F'"' '{print $2}')
+          # Fail here rather than upload an artifact with an empty version.
+          if [ -z "$version" ]; then
+            echo "Could not read the project version from version.sbt" >&2
+            exit 1
+          fi
+          echo "version=$version" >> "$GITHUB_ENV"
+
+      - name: Upload JAR artifact
+        uses: actions/upload-artifact@v4
+        with:
+          name: spark-sql-perf_2.12-${{ env.version }}.jar
+          path: target/scala-2.12/*.jar
+          # A build that produced no JAR should fail, not merely warn.
+          if-no-files-found: error
\ No newline at end of file
diff --git a/.gitignore b/.gitignore
index 1bcb62a0..fec77466 100644
--- a/.gitignore
+++ b/.gitignore
@@ -19,3 +19,5 @@ src_managed/
project/boot/
project/plugins/project/
performance/
+/.bloop/
+/build/*.zip
diff --git a/.scalafmt.conf b/.scalafmt.conf
new file mode 100644
index 00000000..8a313d91
--- /dev/null
+++ b/.scalafmt.conf
@@ -0,0 +1,18 @@
+version = "3.8.1"
+runner.dialect = scala212 // matches the 2.12.x scalaVersion set in build.sbt
+
+maxColumn = 100 // maximum line width
+
+indent.main = 2
+indent.significant = 2
+
+align.preset = more // vertically align adjacent tokens, e.g. consecutive val definitions
+align.tokens."+" = [ // additionally align `->` in infix applications
+ { code = "->", owner = "Term.ApplyInfix" }
+]
+
+newlines.alwaysBeforeElseAfterCurlyIf = false
+newlines.beforeCurlyLambdaParams = multilineWithCaseOnly
+
+rewrite.rules = [RedundantBraces, SortImports] // drop redundant braces; sort import selectors
+rewrite.redundantBraces.stringInterpolation = true // also rewrite "${x}" to "$x"
diff --git a/bin/run b/bin/run
index 7d28227c..f8923ffc 100755
--- a/bin/run
+++ b/bin/run
@@ -3,4 +3,4 @@
# runs spark-sql-perf from the current directory
ARGS="runBenchmark $@"
-build/sbt "$ARGS"
\ No newline at end of file
+sbt "$ARGS"
\ No newline at end of file
diff --git a/build.sbt b/build.sbt
index 2303e62d..d44fca83 100644
--- a/build.sbt
+++ b/build.sbt
@@ -5,65 +5,64 @@ name := "spark-sql-perf"
organization := "com.databricks"
-scalaVersion := "2.12.10"
+scalaVersion := "2.12.18"
-crossScalaVersions := Seq("2.12.10")
+crossScalaVersions := Seq("2.12.18")
-sparkPackageName := "databricks/spark-sql-perf"
+// Publishing configuration is disabled for now; the focus is on getting the build to compile.
+// sparkPackageName := "databricks/spark-sql-perf"
// All Spark Packages need a license
licenses := Seq("Apache-2.0" -> url("http://opensource.org/licenses/Apache-2.0"))
-sparkVersion := "3.0.0"
+// Spark version - define it manually since we removed the spark-packages plugin
+val sparkVersion = "3.5.1"
-sparkComponents ++= Seq("sql", "hive", "mllib")
+// Add Spark dependencies manually
+libraryDependencies ++= Seq(
+ "org.apache.spark" %% "spark-core" % sparkVersion % "provided",
+ "org.apache.spark" %% "spark-sql" % sparkVersion % "provided",
+ "org.apache.spark" %% "spark-hive" % sparkVersion % "provided",
+ "org.apache.spark" %% "spark-mllib" % sparkVersion % "provided"
+)
-initialCommands in console :=
+console / initialCommands := // slash syntax is <task> / <key>: scope initialCommands to the console task
"""
|import org.apache.spark.sql._
|import org.apache.spark.sql.functions._
|import org.apache.spark.sql.types._
- |import org.apache.spark.sql.hive.test.TestHive
- |import TestHive.implicits
- |import TestHive.sql
+ |import org.apache.spark.sql.SparkSession
|
- |val sqlContext = TestHive
+ |val spark = SparkSession.builder().appName("spark-sql-perf").getOrCreate()
+ |val sqlContext = spark.sqlContext
|import sqlContext.implicits._
""".stripMargin
-libraryDependencies += "com.github.scopt" %% "scopt" % "3.7.1"
+libraryDependencies += "com.github.scopt" %% "scopt" % "4.1.0"
-libraryDependencies += "com.twitter" %% "util-jvm" % "6.45.0" % "provided"
+libraryDependencies += "com.twitter" %% "util-jvm" % "24.2.0" % "provided"
-libraryDependencies += "org.scalatest" %% "scalatest" % "3.0.5" % "test"
+libraryDependencies += "org.scalatest" %% "scalatest" % "3.2.19" % "test"
-libraryDependencies += "org.yaml" % "snakeyaml" % "1.23"
+libraryDependencies += "org.yaml" % "snakeyaml" % "2.5"
fork := true
-// Your username to login to Databricks Cloud
-dbcUsername := sys.env.getOrElse("DBC_USERNAME", "")
-
-// Your password (Can be set as an environment variable)
-dbcPassword := sys.env.getOrElse("DBC_PASSWORD", "")
-
-// The URL to the Databricks Cloud DB Api. Don't forget to set the port number to 34563!
-dbcApiUrl := sys.env.getOrElse ("DBC_URL", sys.error("Please set DBC_URL"))
-
-// Add any clusters that you would like to deploy your work to. e.g. "My Cluster"
-// or run dbcExecuteCommand
-dbcClusters += sys.env.getOrElse("DBC_USERNAME", "")
-
-dbcLibraryPath := s"/Users/${sys.env.getOrElse("DBC_USERNAME", "")}/lib"
+// Databricks Cloud configuration is disabled for now
+// dbcUsername := sys.env.getOrElse("DBC_USERNAME", "")
+// dbcPassword := sys.env.getOrElse("DBC_PASSWORD", "")
+// dbcApiUrl := sys.env.getOrElse ("DBC_URL", sys.error("Please set DBC_URL"))
+// dbcClusters += sys.env.getOrElse("DBC_USERNAME", "")
+// dbcLibraryPath := s"/Users/${sys.env.getOrElse("DBC_USERNAME", "")}/lib"
val runBenchmark = inputKey[Unit]("runs a benchmark")
runBenchmark := {
import complete.DefaultParsers._
val args = spaceDelimited("[args]").parsed
- val scalaRun = (runner in run).value
- val classpath = (fullClasspath in Compile).value
+ val scalaRun = (Compile / run / runner).value
+ val classpath = (Compile / fullClasspath).value
scalaRun.run("com.databricks.spark.sql.perf.RunBenchmark", classpath.map(_.data), args,
streams.value.log)
}
@@ -74,13 +73,15 @@ val runMLBenchmark = inputKey[Unit]("runs an ML benchmark")
runMLBenchmark := {
import complete.DefaultParsers._
val args = spaceDelimited("[args]").parsed
- val scalaRun = (runner in run).value
- val classpath = (fullClasspath in Compile).value
+ val scalaRun = (Compile / run / runner).value
+ val classpath = (Compile / fullClasspath).value
scalaRun.run("com.databricks.spark.sql.perf.mllib.MLLib", classpath.map(_.data), args,
streams.value.log)
}
+// Release configuration is commented out for now
+/*
import ReleaseTransformations._
/** Push to the team directory instead of the user's homedir for releases. */
@@ -159,3 +160,9 @@ releaseProcess := Seq[ReleaseStep](
commitNextVersion,
pushChanges
)
+*/
+
+assembly / assemblyMergeStrategy := {
+ case PathList("META-INF", _*) => MergeStrategy.discard // drop every entry under META-INF
+ case _ => MergeStrategy.first // any other conflicting path: keep the first copy
+}
\ No newline at end of file
diff --git a/build/sbt b/build/sbt
index cc3203d7..7d26b548 100755
--- a/build/sbt
+++ b/build/sbt
@@ -153,4 +153,4 @@ trap onExit INT
run "$@"
exit_status=$?
-onExit
+onExit
\ No newline at end of file
diff --git a/build/sbt-launch-lib.bash b/build/sbt-launch-lib.bash
index 2a399365..707f70ef 100755
--- a/build/sbt-launch-lib.bash
+++ b/build/sbt-launch-lib.bash
@@ -45,9 +45,8 @@ dlog () {
acquire_sbt_jar () {
SBT_VERSION=`awk -F "=" '/sbt\.version/ {print $2}' ./project/build.properties`
- URL1=https://dl.bintray.com/typesafe/ivy-releases/org.scala-sbt/sbt-launch/${SBT_VERSION}/sbt-launch.jar
+ URL1=https://github.com/sbt/sbt/releases/download/v${SBT_VERSION}/sbt-${SBT_VERSION}.zip
JAR=build/sbt-launch-${SBT_VERSION}.jar
-
sbt_jar=$JAR
if [[ ! -f "$sbt_jar" ]]; then
@@ -55,13 +54,15 @@ acquire_sbt_jar () {
if [ ! -f "${JAR}" ]; then
# Download
printf "Attempting to fetch sbt\n"
- JAR_DL="${JAR}.part"
+ COMPLETE_SBT="build/sbt.zip"
if [ $(command -v curl) ]; then
- curl --fail --location --silent ${URL1} > "${JAR_DL}" &&\
- mv "${JAR_DL}" "${JAR}"
+ curl --fail --location --silent "${URL1}" > "${COMPLETE_SBT}" && # fetch the sbt distribution zip
+ unzip -q -o -j "${COMPLETE_SBT}" "sbt/bin/sbt-launch.jar" -d build && # extract only the launcher, flattened into build/, without prompting
+ mv "build/sbt-launch.jar" "${JAR}"
elif [ $(command -v wget) ]; then
- wget --quiet ${URL1} -O "${JAR_DL}" &&\
- mv "${JAR_DL}" "${JAR}"
+ wget --quiet "${URL1}" -O "${COMPLETE_SBT}" && # fetch the sbt distribution zip
+ unzip -q -o -j "${COMPLETE_SBT}" "sbt/bin/sbt-launch.jar" -d build && # extract only the launcher, flattened into build/, without prompting
+ mv "build/sbt-launch.jar" "${JAR}"
else
printf "You do not have curl or wget installed, please install sbt manually from http://www.scala-sbt.org/\n"
exit -1
@@ -195,4 +196,4 @@ run() {
-jar "$sbt_jar" \
"${sbt_commands[@]}" \
"${residual_args[@]}"
-}
+}
\ No newline at end of file
diff --git a/project/build.properties b/project/build.properties
index 5c4bcd91..e88a0d81 100644
--- a/project/build.properties
+++ b/project/build.properties
@@ -1,2 +1 @@
-// This file should only contain the version of sbt to use.
-sbt.version=0.13.18
+sbt.version=1.10.6
diff --git a/project/plugins.sbt b/project/plugins.sbt
index d2473b61..cd448d78 100644
--- a/project/plugins.sbt
+++ b/project/plugins.sbt
@@ -1,17 +1,23 @@
// You may use this file to add plugin dependencies for sbt.
-resolvers += "Spark Packages repo" at "https://repos.spark-packages.org/"
+resolvers ++= Seq(
+ "sonatype-releases" at "https://oss.sonatype.org/content/repositories/releases/",
+ "Spark Packages Repo" at "https://repos.spark-packages.org/"
+)
-resolvers += "sonatype-releases" at "https://oss.sonatype.org/content/repositories/releases/"
+// Remove incompatible plugins for now
+// addSbtPlugin("org.spark-packages" % "sbt-spark-package" % "0.2.3")
-addSbtPlugin("org.spark-packages" %% "sbt-spark-package" % "0.1.1")
+// addSbtPlugin("com.github.mpeltonen" % "sbt-idea" % "1.6.0")
-addSbtPlugin("com.github.mpeltonen" % "sbt-idea" % "1.6.0")
+// addSbtPlugin("com.github.sbt" % "sbt-release" % "1.0.15")
-addSbtPlugin("com.github.gseitz" % "sbt-release" % "1.0.0")
+// addSbtPlugin("com.databricks" %% "sbt-databricks" % "0.1.5")
-addSbtPlugin("com.databricks" %% "sbt-databricks" % "0.1.3")
+// addSbtPlugin("org.foundweekends" % "sbt-bintray" % "0.5.6")
-addSbtPlugin("me.lessis" % "bintray-sbt" % "0.3.0")
+// addSbtPlugin("com.github.sbt" % "sbt-pgp" % "2.1.2")
-addSbtPlugin("com.jsuereth" % "sbt-pgp" % "1.0.0")
+addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "2.3.1")
+
+addSbtPlugin("org.scalameta" % "sbt-scalafmt" % "2.5.2")
diff --git a/src/main/scala/com/databricks/spark/sql/perf/AggregationPerformance.scala b/src/main/scala/com/databricks/spark/sql/perf/AggregationPerformance.scala
index 0ba3930a..48880753 100644
--- a/src/main/scala/com/databricks/spark/sql/perf/AggregationPerformance.scala
+++ b/src/main/scala/com/databricks/spark/sql/perf/AggregationPerformance.scala
@@ -2,30 +2,30 @@ package com.databricks.spark.sql.perf
class AggregationPerformance extends Benchmark {
- import sqlContext.implicits._
+ import spark.implicits._
import ExecutionMode._
-
val sizes = (1 to 6).map(math.pow(10, _).toInt)
val x = Table(
"1milints", {
- val df = sqlContext.range(0, 1000000).repartition(1)
+ val df = spark.range(0, 1000000).repartition(1)
df.createTempView("1milints")
df
- })
+ }
+ )
val joinTables = Seq(
Table(
"100milints", {
- val df = sqlContext.range(0, 100000000).repartition(10)
+ val df = spark.range(0, 100000000).repartition(10)
df.createTempView("100milints")
df
- }),
-
+ }
+ ),
Table(
"1bilints", {
- val df = sqlContext.range(0, 1000000000).repartition(10)
+ val df = spark.range(0, 1000000000).repartition(10)
df.createTempView("1bilints")
df
}
@@ -33,28 +33,34 @@ class AggregationPerformance extends Benchmark {
)
val variousCardinality = sizes.map { size =>
- Table(s"ints$size", {
- val df = sparkContext.parallelize(1 to size).flatMap { group =>
- (1 to 10000).map(i => (group, i))
- }.toDF("a", "b")
- df.createTempView(s"ints$size")
- df
- })
+ Table(
+ s"ints$size", {
+ val df = spark.sparkContext
+ .parallelize(1 to size)
+ .flatMap { group =>
+ (1 to 10000).map(i => (group, i))
+ }
+ .toDF("a", "b")
+ df.createTempView(s"ints$size")
+ df
+ }
+ )
}
val lowCardinality = sizes.map { size =>
val fullSize = size * 10000L
Table(
s"twoGroups$fullSize", {
- val df = sqlContext.range(0, fullSize).select($"id" % 2 as 'a, $"id" as 'b)
+ val df = spark.range(0, fullSize).select($"id" % 2 as 'a, $"id" as 'b)
df.createTempView(s"twoGroups$fullSize")
df
- })
+ }
+ )
}
val newAggreation = Variation("aggregationType", Seq("new", "old")) {
- case "old" => sqlContext.setConf("spark.sql.useAggregate2", "false")
- case "new" => sqlContext.setConf("spark.sql.useAggregate2", "true")
+ case "old" => spark.conf.set("spark.sql.useAggregate2", "false")
+ case "new" => spark.conf.set("spark.sql.useAggregate2", "true")
}
val varyNumGroupsAvg: Seq[Benchmarkable] = variousCardinality.map(_.name).map { table =>
@@ -62,7 +68,8 @@ class AggregationPerformance extends Benchmark {
s"avg-$table",
s"SELECT AVG(b) FROM $table GROUP BY a",
"an average with a varying number of groups",
- executionMode = ForeachResults)
+ executionMode = ForeachResults
+ )
}
val twoGroupsAvg: Seq[Benchmarkable] = lowCardinality.map(_.name).map { table =>
@@ -70,7 +77,8 @@ class AggregationPerformance extends Benchmark {
s"avg-$table",
s"SELECT AVG(b) FROM $table GROUP BY a",
"an average on an int column with only two groups",
- executionMode = ForeachResults)
+ executionMode = ForeachResults
+ )
}
val complexInput: Seq[Benchmarkable] =
@@ -79,7 +87,8 @@ class AggregationPerformance extends Benchmark {
s"aggregation-complex-input-$table",
s"SELECT SUM(id + id + id + id + id + id + id + id + id + id) FROM $table",
"Sum of 9 columns added together",
- executionMode = CollectResults)
+ executionMode = CollectResults
+ )
}
val aggregates: Seq[Benchmarkable] =
@@ -89,7 +98,8 @@ class AggregationPerformance extends Benchmark {
s"single-aggregate-$agg-$table",
s"SELECT $agg(id) FROM $table",
"aggregation of a single column",
- executionMode = CollectResults)
+ executionMode = CollectResults
+ )
}
}
-}
\ No newline at end of file
+}
diff --git a/src/main/scala/com/databricks/spark/sql/perf/Benchmark.scala b/src/main/scala/com/databricks/spark/sql/perf/Benchmark.scala
index ebb49353..3cd4249b 100644
--- a/src/main/scala/com/databricks/spark/sql/perf/Benchmark.scala
+++ b/src/main/scala/com/databricks/spark/sql/perf/Benchmark.scala
@@ -21,97 +21,111 @@ import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent._
import scala.concurrent.duration._
import scala.language.implicitConversions
-import scala.util.{Success, Try, Failure => SFailure}
+import scala.util.{Failure => SFailure, Success, Try}
import scala.util.control.NonFatal
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{Dataset, DataFrame, SQLContext, SparkSession}
+import org.apache.spark.sql.{DataFrame, Dataset, SQLContext, SparkSession}
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.SparkContext
import com.databricks.spark.sql.perf.cpu._
-/**
- * A collection of queries that test a particular aspect of Spark SQL.
- *
- * @param sqlContext An existing SQLContext.
- */
-abstract class Benchmark(
- @transient val sqlContext: SQLContext)
- extends Serializable {
+/** A collection of queries that test a particular aspect of Spark SQL.
+ *
+ * @param sqlContext
+ * An existing SQLContext.
+ */
+abstract class Benchmark(@transient val sqlContext: SQLContext) extends Serializable {
+
+ @transient val spark = sqlContext.sparkSession
import Benchmark._
def this() = this(SparkSession.builder.getOrCreate().sqlContext)
val resultsLocation =
- sqlContext.getAllConfs.getOrElse(
- "spark.sql.perf.results",
- "/spark/sql/performance")
+ spark.conf.getAll.getOrElse("spark.sql.perf.results", "/spark/sql/performance")
protected def sparkContext = sqlContext.sparkContext
protected implicit def toOption[A](a: A): Option[A] = Option(a)
- val buildInfo = Try(getClass.getClassLoader.loadClass("org.apache.spark.BuildInfo")).map { cls =>
- cls.getMethods
- .filter(_.getReturnType == classOf[String])
+ val buildInfo = Try(getClass.getClassLoader.loadClass("org.apache.spark.BuildInfo"))
+ .map { cls =>
+ cls.getMethods
+ .filter(_.getReturnType == classOf[String])
.filterNot(_.getName == "toString")
.map(m => m.getName -> m.invoke(cls).asInstanceOf[String])
.toMap
- }.getOrElse(Map.empty)
+ }
+ .getOrElse(Map.empty)
def currentConfiguration = BenchmarkConfiguration(
- sqlConf = sqlContext.getAllConfs,
+ sqlConf = spark.conf.getAll,
sparkConf = sparkContext.getConf.getAll.toMap,
defaultParallelism = sparkContext.defaultParallelism,
- buildInfo = buildInfo)
-
+ buildInfo = buildInfo
+ )
val codegen = Variation("codegen", Seq("on", "off")) {
- case "off" => sqlContext.setConf("spark.sql.codegen", "false")
- case "on" => sqlContext.setConf("spark.sql.codegen", "true")
+ case "off" => spark.conf.set("spark.sql.codegen", "false")
+ case "on" => spark.conf.set("spark.sql.codegen", "true")
}
val unsafe = Variation("unsafe", Seq("on", "off")) {
- case "off" => sqlContext.setConf("spark.sql.unsafe.enabled", "false")
- case "on" => sqlContext.setConf("spark.sql.unsafe.enabled", "true")
+ case "off" => spark.conf.set("spark.sql.unsafe.enabled", "false")
+ case "on" => spark.conf.set("spark.sql.unsafe.enabled", "true")
}
val tungsten = Variation("tungsten", Seq("on", "off")) {
- case "off" => sqlContext.setConf("spark.sql.tungsten.enabled", "false")
- case "on" => sqlContext.setConf("spark.sql.tungsten.enabled", "true")
+ case "off" => spark.conf.set("spark.sql.tungsten.enabled", "false")
+ case "on" => spark.conf.set("spark.sql.tungsten.enabled", "true")
}
- /**
- * Starts an experiment run with a given set of executions to run.
- *
- * @param executionsToRun a list of executions to run.
- * @param includeBreakdown If it is true, breakdown results of an execution will be recorded.
- * Setting it to true may significantly increase the time used to
- * run an execution.
- * @param iterations The number of iterations to run of each execution.
- * @param variations [[Variation]]s used in this run. The cross product of all variations will be
- * run for each execution * iteration.
- * @param tags Tags of this run.
- * @param timeout wait at most timeout milliseconds for each query, 0 means wait forever
- * @return It returns a ExperimentStatus object that can be used to
- * track the progress of this experiment run.
- */
+ /** Starts an experiment run with a given set of executions to run.
+ *
+ * @param executionsToRun
+ * a list of executions to run.
+ * @param includeBreakdown
+ * If it is true, breakdown results of an execution will be recorded. Setting it to true may
+ * significantly increase the time used to run an execution.
+ * @param iterations
+ * The number of iterations to run of each execution.
+ * @param variations
+ * [[Variation]]s used in this run. The cross product of all variations will be run for each
+ * execution * iteration.
+ * @param tags
+ * Tags of this run.
+ * @param timeout
+ * wait at most timeout milliseconds for each query, 0 means wait forever
+ * @return
+ * It returns a ExperimentStatus object that can be used to track the progress of this
+ * experiment run.
+ */
def runExperiment(
executionsToRun: Seq[Benchmarkable],
includeBreakdown: Boolean = false,
iterations: Int = 3,
- variations: Seq[Variation[_]] = Seq(Variation("StandardRun", Seq("true")) { _ => {} }),
+ variations: Seq[Variation[_]] = Seq(Variation("StandardRun", Seq("true")) { _ => }),
tags: Map[String, String] = Map.empty,
timeout: Long = 0L,
resultLocation: String = resultsLocation,
- forkThread: Boolean = true) = {
-
- new ExperimentStatus(executionsToRun, includeBreakdown, iterations, variations, tags,
- timeout, resultLocation, sqlContext, allTables, currentConfiguration, forkThread = forkThread)
- }
-
+ forkThread: Boolean = true
+ ) =
+ new ExperimentStatus(
+ executionsToRun,
+ includeBreakdown,
+ iterations,
+ variations,
+ tags,
+ timeout,
+ resultLocation,
+ sqlContext,
+ allTables,
+ currentConfiguration,
+ forkThread = forkThread
+ )
import reflect.runtime._, universe._
import reflect.runtime._
@@ -135,7 +149,9 @@ abstract class Benchmark(
.filter(m => m.isMethod)
.map(_.asMethod)
.filter(_.asMethod.returnType =:= typeOf[Seq[Table]])
- .flatMap(method => runtimeMirror.reflect(this).reflectMethod(method).apply().asInstanceOf[Seq[Table]])
+ .flatMap(method =>
+ runtimeMirror.reflect(this).reflectMethod(method).apply().asInstanceOf[Seq[Table]]
+ )
@transient
lazy val allTables: Seq[Table] = (singleTables ++ groupedTables).toSeq
@@ -145,14 +161,18 @@ abstract class Benchmark(
.filter(m => m.isMethod)
.map(_.asMethod)
.filter(_.asMethod.returnType =:= typeOf[Benchmarkable])
- .map(method => runtimeMirror.reflect(this).reflectMethod(method).apply().asInstanceOf[Benchmarkable])
+ .map(method =>
+ runtimeMirror.reflect(this).reflectMethod(method).apply().asInstanceOf[Benchmarkable]
+ )
def groupedQueries =
myType.declarations
.filter(m => m.isMethod)
.map(_.asMethod)
.filter(_.asMethod.returnType =:= typeOf[Seq[Benchmarkable]])
- .flatMap(method => runtimeMirror.reflect(this).reflectMethod(method).apply().asInstanceOf[Seq[Benchmarkable]])
+ .flatMap(method =>
+ runtimeMirror.reflect(this).reflectMethod(method).apply().asInstanceOf[Seq[Benchmarkable]]
+ )
@transient
lazy val allQueries = (singleQueries ++ groupedQueries).toSeq
@@ -163,21 +183,25 @@ abstract class Benchmark(
.filter(m => m.isMethod)
.map(_.asMethod)
.filter(_.asMethod.returnType =:= typeOf[Query])
- .map(method => runtimeMirror.reflect(this).reflectMethod(method).apply().asInstanceOf[Query])
+ .map(method =>
+ runtimeMirror.reflect(this).reflectMethod(method).apply().asInstanceOf[Query]
+ )
.mkString(",")
val queries =
myType.declarations
- .filter(m => m.isMethod)
- .map(_.asMethod)
- .filter(_.asMethod.returnType =:= typeOf[Seq[Query]])
- .map { method =>
- val queries = runtimeMirror.reflect(this).reflectMethod(method).apply().asInstanceOf[Seq[Query]]
- val queryList = queries.map(_.name).mkString(", ")
- s"""
+ .filter(m => m.isMethod)
+ .map(_.asMethod)
+ .filter(_.asMethod.returnType =:= typeOf[Seq[Query]])
+ .map { method =>
+ val queries =
+ runtimeMirror.reflect(this).reflectMethod(method).apply().asInstanceOf[Seq[Query]]
+ val queryList = queries.map(_.name).mkString(", ")
+ s"""
|
${method.name}
|
""".stripMargin
- }.mkString("\n")
+ }
+ .mkString("\n")
s"""
|Spark SQL Performance Benchmarking
@@ -193,29 +217,17 @@ abstract class Benchmark(
name: String,
sqlText: String,
description: String,
- executionMode: ExecutionMode = ExecutionMode.ForeachResults): Query = {
- new Query(name, sqlContext.sql(sqlText), description, Some(sqlText), executionMode)
- }
+ executionMode: ExecutionMode = ExecutionMode.ForeachResults
+ ): Query =
+ new Query(name, spark.sql(sqlText), description, Some(sqlText), executionMode)
- def apply(
- name: String,
- dataFrameBuilder: => DataFrame,
- description: String): Query = {
+ def apply(name: String, dataFrameBuilder: => DataFrame, description: String): Query =
new Query(name, dataFrameBuilder, description, None, ExecutionMode.CollectResults)
- }
}
object RDDCount {
- def apply(
- name: String,
- rdd: RDD[_]) = {
- new SparkPerfExecution(
- name,
- Map.empty,
- () => Unit,
- () => rdd.count(),
- rdd.toDebugString)
- }
+ def apply(name: String, rdd: RDD[_]) =
+ new SparkPerfExecution(name, Map.empty, () => (), () => rdd.count(), rdd.toDebugString)
}
/** A class for benchmarking Spark perf results. */
@@ -224,8 +236,8 @@ abstract class Benchmark(
parameters: Map[String, String],
prepare: () => Unit,
run: () => Unit,
- description: String = "")
- extends Benchmarkable {
+ description: String = ""
+ ) extends Benchmarkable {
override def toString: String =
s"""
@@ -235,55 +247,52 @@ abstract class Benchmark(
protected override val executionMode: ExecutionMode = ExecutionMode.SparkPerfResults
- protected override def beforeBenchmark(): Unit = { prepare() }
+ protected override def beforeBenchmark(): Unit = prepare()
protected override def doBenchmark(
includeBreakdown: Boolean,
description: String = "",
- messages: ArrayBuffer[String]): BenchmarkResult = {
+ messages: ArrayBuffer[String],
+ iteration: Int = 1
+ ): BenchmarkResult =
try {
val timeMs = measureTimeMs(run())
BenchmarkResult(
name = name,
mode = executionMode.toString,
parameters = parameters,
- executionTime = Some(timeMs))
+ executionTime = Some(timeMs)
+ )
} catch {
case e: Exception =>
BenchmarkResult(
name = name,
mode = executionMode.toString,
parameters = parameters,
- failure = Some(Failure(e.getClass.getSimpleName, e.getMessage)))
+ failure = Some(Failure(e.getClass.getSimpleName, e.getMessage))
+ )
}
- }
}
}
-/**
- * A Variation represents a setting (e.g. the number of shuffle partitions or if tables
- * are cached in memory) that we want to change in a experiment run.
- * A Variation has three parts, `name`, `options`, and `setup`.
- * The `name` is the identifier of a Variation. `options` is a Seq of options that
- * will be used for a query. Basically, a query will be executed with every option
- * defined in the list of `options`. `setup` defines the needed action for every
- * option. For example, the following Variation is used to change the number of shuffle
- * partitions of a query. The name of the Variation is "shufflePartitions". There are
- * two options, 200 and 2000. The setup is used to set the value of property
- * "spark.sql.shuffle.partitions".
- *
- * {{{
- * Variation("shufflePartitions", Seq("200", "2000")) {
- * case num => sqlContext.setConf("spark.sql.shuffle.partitions", num)
- * }
- * }}}
- */
+/** A Variation represents a setting (e.g. the number of shuffle partitions or if tables are cached
+ * in memory) that we want to change in a experiment run. A Variation has three parts, `name`,
+ * `options`, and `setup`. The `name` is the identifier of a Variation. `options` is a Seq of
+ * options that will be used for a query. Basically, a query will be executed with every option
+ * defined in the list of `options`. `setup` defines the needed action for every option. For
+ * example, the following Variation is used to change the number of shuffle partitions of a query.
+ * The name of the Variation is "shufflePartitions". There are two options, 200 and 2000. The setup
+ * is used to set the value of property "spark.sql.shuffle.partitions".
+ *
+ * {{{
+ * Variation("shufflePartitions", Seq("200", "2000")) {
+ * case num => spark.conf.set("spark.sql.shuffle.partitions", num)
+ * }
+ * }}}
+ */
case class Variation[T](name: String, options: Seq[T])(val setup: T => Unit)
-case class Table(
- name: String,
- data: Dataset[_])
-
+case class Table(name: String, data: Dataset[_])
object Benchmark {
@@ -298,9 +307,11 @@ object Benchmark {
sqlContext: SQLContext,
allTables: Seq[Table],
currentConfiguration: BenchmarkConfiguration,
- forkThread: Boolean = true) {
- val currentResults = new collection.mutable.ArrayBuffer[BenchmarkResult]()
- val currentRuns = new collection.mutable.ArrayBuffer[ExperimentRun]()
+ forkThread: Boolean = true
+ ) {
+ val spark = sqlContext.sparkSession
+ val currentResults = new collection.mutable.ArrayBuffer[BenchmarkResult]()
+ val currentRuns = new collection.mutable.ArrayBuffer[ExperimentRun]()
val currentMessages = new collection.mutable.ArrayBuffer[String]()
def logMessage(msg: String) = {
@@ -310,22 +321,21 @@ object Benchmark {
// Stats for HTML status message.
@volatile var currentExecution = ""
- @volatile var currentPlan = "" // for queries only
- @volatile var currentConfig = ""
- @volatile var failures = 0
- @volatile var startTime = 0L
+ @volatile var currentPlan = "" // for queries only
+ @volatile var currentConfig = ""
+ @volatile var failures = 0
+ @volatile var startTime = 0L
/** An optional log collection task that will run after the experiment. */
@volatile var logCollection: () => Unit = () => {}
-
def cartesianProduct[T](xss: List[List[T]]): List[List[T]] = xss match {
- case Nil => List(Nil)
- case h :: t => for(xh <- h; xt <- cartesianProduct(t)) yield xh :: xt
+ case Nil => List(Nil)
+ case h :: t => for (xh <- h; xt <- cartesianProduct(t)) yield xh :: xt
}
- val timestamp = System.currentTimeMillis()
- val resultPath = s"$resultsLocation/timestamp=$timestamp"
+ val timestamp = System.currentTimeMillis()
+ val resultPath = s"$resultsLocation/timestamp=$timestamp"
val combinations = cartesianProduct(variations.map(l => (0 until l.options.size).toList).toList)
val resultsFuture = Future {
@@ -333,11 +343,11 @@ object Benchmark {
executionsToRun
.collect { case query: Query => query }
.flatMap { query =>
- try {
+ try
query.newDataFrame().queryExecution.logical.collect {
case r: UnresolvedRelation => r.tableName
}
- } catch {
+ catch {
// ignore the queries that can't be parsed
case e: Exception => Seq()
}
@@ -345,7 +355,7 @@ object Benchmark {
.distinct
.foreach { name =>
try {
- sqlContext.table(name)
+ spark.table(name)
logMessage(s"Table $name exists.")
} catch {
case ae: Exception =>
@@ -353,8 +363,7 @@ object Benchmark {
.find(_.name == name)
if (table.isDefined) {
logMessage(s"Creating table: $name")
- table.get.data
- .write
+ table.get.data.write
.mode("overwrite")
.saveAsTable(name)
} else {
@@ -372,18 +381,19 @@ object Benchmark {
v.setup(v.options(idx))
v.name -> v.options(idx).toString
}
- currentConfig = currentOptions.map { case (k,v) => s"$k: $v" }.mkString(", ")
+ currentConfig = currentOptions.map { case (k, v) => s"$k: $v" }.mkString(", ")
val res = executionsToRun.flatMap { q =>
- val setup = s"iteration: $i, ${currentOptions.map { case (k, v) => s"$k=$v"}.mkString(", ")}"
+ val setup =
+ s"iteration: $i, ${currentOptions.map { case (k, v) => s"$k=$v" }.mkString(", ")}"
logMessage(s"Running execution ${q.name} $setup")
currentExecution = q.name
currentPlan = q match {
case query: Query =>
- try {
+ try
query.newDataFrame().queryExecution.executedPlan.toString()
- } catch {
+ catch {
case e: Exception =>
s"failed to parse: $e"
}
@@ -392,8 +402,13 @@ object Benchmark {
startTime = System.currentTimeMillis()
val singleResultT = Try {
- q.benchmark(includeBreakdown, setup, currentMessages, timeout,
- forkThread=forkThread)
+ q.benchmark(
+ includeBreakdown,
+ setup,
+ currentMessages,
+ timeout,
+ forkThread = forkThread
+ )
}
singleResultT match {
@@ -409,7 +424,7 @@ object Benchmark {
singleResult :: Nil
case SFailure(e) =>
failures += 1
- logMessage(s"Execution '${q.name}' failed: ${e}")
+ logMessage(s"Execution '${q.name}' failed: $e")
Nil
}
}
@@ -419,7 +434,8 @@ object Benchmark {
iteration = i,
tags = currentOptions.toMap ++ tags,
configuration = currentConfiguration,
- res)
+ res
+ )
currentRuns += result
@@ -428,7 +444,7 @@ object Benchmark {
}
try {
- val resultsTable = sqlContext.createDataFrame(results)
+ val resultsTable = spark.createDataFrame(results)
logMessage(s"Results written to table: 'sqlPerformance' at $resultPath")
resultsTable
.coalesce(1)
@@ -444,7 +460,7 @@ object Benchmark {
logCollection()
}
- def scheduleCpuCollection(fs: FS) = {
+ def scheduleCpuCollection(fs: FS) =
logCollection = () => {
logMessage(s"Begining CPU log collection")
try {
@@ -456,40 +472,36 @@ object Benchmark {
throw e
}
}
- }
- def cpuProfile = new Profile(sqlContext, sqlContext.read.json(getCpuLocation(timestamp)))
+ def cpuProfile = new Profile(spark.sqlContext, spark.read.json(getCpuLocation(timestamp)))
- def cpuProfileHtml(fs: FS) = {
+ def cpuProfileHtml(fs: FS) =
s"""
|CPU Profile
|Permalink: sqlContext.read.json("${getCpuLocation(timestamp)}")
|${cpuProfile.buildGraph(fs)}
""".stripMargin
- }
/** Waits for the finish of the experiment. */
- def waitForFinish(timeoutInSeconds: Int) = {
+ def waitForFinish(timeoutInSeconds: Int) =
Await.result(resultsFuture, timeoutInSeconds.seconds)
- }
/** Returns results from an actively running experiment. */
def getCurrentResults() = {
- val tbl = sqlContext.createDataFrame(currentResults)
+ val tbl = spark.createDataFrame(currentResults)
tbl.createOrReplaceTempView("currentResults")
tbl
}
/** Returns full iterations from an actively running experiment. */
def getCurrentRuns() = {
- val tbl = sqlContext.createDataFrame(currentRuns)
+ val tbl = spark.createDataFrame(currentRuns)
tbl.createOrReplaceTempView("currentRuns")
tbl
}
- def tail(n: Int = 20) = {
+ def tail(n: Int = 20) =
currentMessages.takeRight(n).mkString("\n")
- }
def status =
if (resultsFuture.isCompleted) {
@@ -501,7 +513,6 @@ object Benchmark {
override def toString =
s"""Permalink: table("sqlPerformance").where('timestamp === ${timestamp}L)"""
-
def html: String = {
val maybeQueryPlan: String =
if (currentPlan.nonEmpty) {
@@ -516,7 +527,7 @@ object Benchmark {
}
s"""
|$status Experiment
- |Permalink: sqlContext.read.json("$resultPath")
+ |Permalink: spark.read.json("$resultPath")
|Iterations complete: ${currentRuns.size / combinations.size} / $iterations
|Failures: $failures
|Executions run: ${currentResults.size} / ${iterations * combinations.size * executionsToRun.size}
diff --git a/src/main/scala/com/databricks/spark/sql/perf/Benchmarkable.scala b/src/main/scala/com/databricks/spark/sql/perf/Benchmarkable.scala
index 24efef70..3aacd315 100644
--- a/src/main/scala/com/databricks/spark/sql/perf/Benchmarkable.scala
+++ b/src/main/scala/com/databricks/spark/sql/perf/Benchmarkable.scala
@@ -24,15 +24,14 @@ import scala.concurrent.duration._
import scala.collection.mutable.ArrayBuffer
import scala.util.control.NonFatal
-import org.apache.spark.sql.{SQLContext,SparkSession}
-import org.apache.spark.{SparkEnv, SparkContext}
-
+import org.apache.spark.sql.{SQLContext, SparkSession}
+import org.apache.spark.{SparkContext, SparkEnv}
/** A trait to describe things that can be benchmarked. */
trait Benchmarkable {
- @transient protected[this] val sqlSession = SparkSession.builder.getOrCreate()
- @transient protected[this] val sqlContext = sqlSession.sqlContext
- @transient protected[this] val sparkContext = sqlSession.sparkContext
+ @transient protected[this] val spark = SparkSession.builder.getOrCreate()
+ @transient protected[this] val sqlContext = spark.sqlContext
+ @transient protected[this] val sparkContext = spark.sparkContext
val name: String
protected val executionMode: ExecutionMode
@@ -43,40 +42,42 @@ trait Benchmarkable {
description: String = "",
messages: ArrayBuffer[String],
timeout: Long,
- forkThread: Boolean = true): BenchmarkResult = {
+ forkThread: Boolean = true,
+ iteration: Int = 1
+ ): BenchmarkResult = {
logger.info(s"$this: benchmark")
sparkContext.setJobDescription(s"Execution: $name, $description")
beforeBenchmark()
val result = if (forkThread) {
runBenchmarkForked(includeBreakdown, description, messages, timeout)
} else {
- doBenchmark(includeBreakdown, description, messages)
+ doBenchmark(includeBreakdown, description, messages, iteration)
}
afterBenchmark(sqlContext.sparkContext)
result
}
- protected def beforeBenchmark(): Unit = { }
+ protected def beforeBenchmark(): Unit = {}
- protected def afterBenchmark(sc: SparkContext): Unit = {
+ protected def afterBenchmark(sc: SparkContext): Unit =
System.gc()
- }
private def runBenchmarkForked(
includeBreakdown: Boolean,
description: String = "",
messages: ArrayBuffer[String],
- timeout: Long): BenchmarkResult = {
- val jobgroup = UUID.randomUUID().toString
- val that = this
+ timeout: Long
+ ): BenchmarkResult = {
+ val jobgroup = UUID.randomUUID().toString
+ val that = this
var result: BenchmarkResult = null
val thread = new Thread("benchmark runner") {
override def run(): Unit = {
logger.info(s"$that running $this")
sparkContext.setJobGroup(jobgroup, s"benchmark $name", true)
- try {
+ try
result = doBenchmark(includeBreakdown, description, messages)
- } catch {
+ catch {
case e: Throwable =>
logger.info(s"$that: failure in runBenchmark: $e")
println(s"$that: failure in runBenchmark: $e")
@@ -84,8 +85,13 @@ trait Benchmarkable {
name = name,
mode = executionMode.toString,
parameters = Map.empty,
- failure = Some(Failure(e.getClass.getSimpleName,
- e.getMessage + ":\n" + e.getStackTraceString)))
+ failure = Some(
+ Failure(
+ e.getClass.getSimpleName,
+ e.getMessage + ":\n" + e.getStackTrace.mkString("\n")
+ )
+ )
+ )
}
}
}
@@ -107,7 +113,9 @@ trait Benchmarkable {
protected def doBenchmark(
includeBreakdown: Boolean,
description: String = "",
- messages: ArrayBuffer[String]): BenchmarkResult
+ messages: ArrayBuffer[String],
+ iteration: Int = 1
+ ): BenchmarkResult
protected def measureTimeMs[A](f: => A): Double = {
val startTime = System.nanoTime()
@@ -118,8 +126,8 @@ trait Benchmarkable {
protected def measureTime[A](f: => A): (Duration, A) = {
val startTime = System.nanoTime()
- val res = f
- val endTime = System.nanoTime()
+ val res = f
+ val endTime = System.nanoTime()
(endTime - startTime).nanos -> res
}
}
diff --git a/src/main/scala/com/databricks/spark/sql/perf/CpuProfile.scala b/src/main/scala/com/databricks/spark/sql/perf/CpuProfile.scala
index 901563a2..bdf493f9 100644
--- a/src/main/scala/com/databricks/spark/sql/perf/CpuProfile.scala
+++ b/src/main/scala/com/databricks/spark/sql/perf/CpuProfile.scala
@@ -16,10 +16,10 @@
package com.databricks.spark.sql.perf
-import java.io.{FileOutputStream, File}
+import java.io.{File, FileOutputStream}
import org.apache.hadoop.conf.Configuration
-import org.apache.spark.sql.{DataFrame, SQLContext, Row}
+import org.apache.spark.sql.{DataFrame, Row, SQLContext, SparkSession}
import org.apache.spark.sql.functions._
import scala.language.reflectiveCalls
@@ -29,10 +29,9 @@ import org.apache.hadoop.fs.{FileSystem, Path}
import com.twitter.jvm.CpuProfile
-/**
- * A collection of utilities for parsing stacktraces that have been recorded in JSON and generating visualizations
- * on where time is being spent.
- */
+/** A collection of utilities for parsing stacktraces that have been recorded in JSON and generating
+ * visualizations on where time is being spent.
+ */
package object cpu {
// Placeholder for DBFS.
@@ -44,10 +43,7 @@ package object cpu {
private val resultsLocation = "/spark/sql/cpu"
lazy val pprof = {
- run(
- "sudo apt-get install -y graphviz",
- "cp /dbfs/home/michael/pprof ./",
- "chmod 755 pprof")
+ run("sudo apt-get install -y graphviz", "cp /dbfs/home/michael/pprof ./", "chmod 755 pprof")
"./pprof"
}
@@ -55,23 +51,27 @@ package object cpu {
def getCpuLocation(timestamp: Long) = s"$resultsLocation/timestamp=$timestamp"
def collectLogs(sqlContext: SQLContext, fs: FS, timestamp: Long): String = {
- import sqlContext.implicits._
+ val spark = sqlContext.sparkSession
+ import spark.implicits._
def sc = sqlContext.sparkContext
def copyLogFiles() = {
- val path = "pwd".!!.trim
+ val path = "pwd".!!.trim
val hostname = "hostname".!!.trim
val conf = new Configuration()
- val fs = FileSystem.get(conf)
- fs.copyFromLocalFile(new Path(s"$path/logs/cpu.json"), new Path(s"$resultsLocation/timestamp=$timestamp/$hostname"))
+ val fs = FileSystem.get(conf)
+ fs.copyFromLocalFile(
+ new Path(s"$path/logs/cpu.json"),
+ new Path(s"$resultsLocation/timestamp=$timestamp/$hostname")
+ )
}
fs.rm(getCpuLocation(timestamp), true)
copyLogFiles()
- sc.parallelize((1 to 100)).foreach { i => copyLogFiles() }
+ sc.parallelize((1 to 100)).foreach(i => copyLogFiles())
getCpuLocation(timestamp)
}
@@ -92,7 +92,8 @@ package object cpu {
}
class Profile(private val sqlContext: SQLContext, cpuLogs: DataFrame) {
- import sqlContext.implicits._
+ val spark = sqlContext.sparkSession
+ import spark.implicits._
def hosts = cpuLogs.select($"tags.hostName").distinct.collect().map(_.getString(0))
@@ -100,24 +101,39 @@ package object cpu {
val stackLine = """(.*)\.([^\(]+)\(([^:]+)(:{0,1}\d*)\)""".r
def toStackElement(s: String) = s match {
case stackLine(cls, method, file, "") => new StackTraceElement(cls, method, file, 0)
- case stackLine(cls, method, file, line) => new StackTraceElement(cls, method, file, line.stripPrefix(":").toInt)
+ case stackLine(cls, method, file, line) =>
+ new StackTraceElement(cls, method, file, line.stripPrefix(":").toInt)
}
- val counts = cpuLogs.groupBy($"stack").agg(count($"*")).collect().flatMap {
- case Row(stackLines: Array[String], count: Long) => stackLines.toSeq.map(toStackElement) -> count :: Nil
- case other => println(s"Failed to parse $other"); Nil
- }.toMap
- val profile = new com.twitter.jvm.CpuProfile(counts, com.twitter.util.Duration.fromSeconds(10), cpuLogs.count().toInt, 0)
+ val counts = cpuLogs
+ .groupBy($"stack")
+ .agg(count($"*"))
+ .collect()
+ .flatMap {
+ case Row(stackLines: Array[String], count: Long) =>
+ stackLines.toSeq.map(toStackElement) -> count :: Nil
+ case other => println(s"Failed to parse $other"); Nil
+ }
+ .toMap
+ val profile = new com.twitter.jvm.CpuProfile(
+ counts,
+ com.twitter.util.Duration.fromSeconds(10),
+ cpuLogs.count().toInt,
+ 0
+ )
val outfile = File.createTempFile("cpu", "profile")
val svgFile = File.createTempFile("cpu", "svg")
profile.writeGoogleProfile(new FileOutputStream(outfile))
- println(run(
- "cp /dbfs/home/michael/pprof ./",
- "chmod 755 pprof",
- s"$pprof --svg ${outfile.getCanonicalPath} > ${svgFile.getCanonicalPath}"))
+ println(
+ run(
+ "cp /dbfs/home/michael/pprof ./",
+ "chmod 755 pprof",
+ s"$pprof --svg ${outfile.getCanonicalPath} > ${svgFile.getCanonicalPath}"
+ )
+ )
val timestamp = System.currentTimeMillis()
fs.cp(s"file://$svgFile", s"/FileStore/cpu.profiles/$timestamp.svg", false)
diff --git a/src/main/scala/com/databricks/spark/sql/perf/DatasetPerformance.scala b/src/main/scala/com/databricks/spark/sql/perf/DatasetPerformance.scala
index 0aaa6296..c25c6df0 100644
--- a/src/main/scala/com/databricks/spark/sql/perf/DatasetPerformance.scala
+++ b/src/main/scala/com/databricks/spark/sql/perf/DatasetPerformance.scala
@@ -28,7 +28,7 @@ object TypedAverage extends Aggregator[Long, SumAndCount, Double] {
b
}
- override def bufferEncoder = Encoders.product
+ override def bufferEncoder = Encoders.product
override def outputEncoder = Encoders.scalaDouble
@@ -47,30 +47,22 @@ case class SumAndCount(var sum: Long, var count: Int)
class DatasetPerformance extends Benchmark {
- import sqlContext.implicits._
+ import spark.implicits._
val numLongs = 100000000
- val ds = sqlContext.range(1, numLongs)
- val rdd = sparkContext.range(1, numLongs)
+ val ds = spark.range(1, numLongs)
+ val rdd = spark.sparkContext.range(1, numLongs)
val smallNumLongs = 1000000
- val smallds = sqlContext.range(1, smallNumLongs).as[Long]
- val smallrdd = sparkContext.range(1, smallNumLongs)
+ val smallds = spark.range(1, smallNumLongs).as[Long]
+ val smallrdd = spark.sparkContext.range(1, smallNumLongs)
- def allBenchmarks = range ++ backToBackFilters ++ backToBackMaps ++ computeAverage
+ def allBenchmarks = range ++ backToBackFilters ++ backToBackMaps ++ computeAverage
val range = Seq(
- new Query(
- "DS: range",
- ds.as[Data].toDF(),
- executionMode = ExecutionMode.ForeachResults),
- new Query(
- "DF: range",
- ds.toDF(),
- executionMode = ExecutionMode.ForeachResults),
- RDDCount(
- "RDD: range",
- rdd.map(Data(_)))
+ new Query("DS: range", ds.as[Data].toDF(), executionMode = ExecutionMode.ForeachResults),
+ new Query("DF: range", ds.toDF(), executionMode = ExecutionMode.ForeachResults),
+ RDDCount("RDD: range", rdd.map(Data(_)))
)
val backToBackFilters = Seq(
@@ -80,21 +72,26 @@ class DatasetPerformance extends Benchmark {
.filter(_.id % 100 != 0)
.filter(_.id % 101 != 0)
.filter(_.id % 102 != 0)
- .filter(_.id % 103 != 0).toDF()),
+ .filter(_.id % 103 != 0)
+ .toDF()
+ ),
new Query(
"DF: back-to-back filters",
ds.toDF()
.filter("id % 100 != 0")
.filter("id % 101 != 0")
.filter("id % 102 != 0")
- .filter("id % 103 != 0")),
+ .filter("id % 103 != 0")
+ ),
RDDCount(
"RDD: back-to-back filters",
- rdd.map(Data(_))
+ rdd
+ .map(Data(_))
.filter(_.id % 100 != 0)
.filter(_.id % 101 != 0)
.filter(_.id % 102 != 0)
- .filter(_.id % 103 != 0))
+ .filter(_.id % 103 != 0)
+ )
)
val backToBackMaps = Seq(
@@ -104,40 +101,48 @@ class DatasetPerformance extends Benchmark {
.map(d => Data(d.id + 1L))
.map(d => Data(d.id + 1L))
.map(d => Data(d.id + 1L))
- .map(d => Data(d.id + 1L)).toDF()),
+ .map(d => Data(d.id + 1L))
+ .toDF()
+ ),
new Query(
"DF: back-to-back maps",
ds.toDF()
.select($"id" + 1 as 'id)
.select($"id" + 1 as 'id)
.select($"id" + 1 as 'id)
- .select($"id" + 1 as 'id)),
+ .select($"id" + 1 as 'id)
+ ),
RDDCount(
"RDD: back-to-back maps",
- rdd.map(Data)
+ rdd
+ .map(Data)
+ .map(d => Data(d.id + 1L))
.map(d => Data(d.id + 1L))
.map(d => Data(d.id + 1L))
.map(d => Data(d.id + 1L))
- .map(d => Data(d.id + 1L)))
+ )
)
val computeAverage = Seq(
new Query(
"DS: average",
smallds.select(TypedAverage.toColumn).toDF(),
- executionMode = ExecutionMode.CollectResults),
+ executionMode = ExecutionMode.CollectResults
+ ),
new Query(
"DF: average",
smallds.toDF().selectExpr("avg(id)"),
- executionMode = ExecutionMode.CollectResults),
+ executionMode = ExecutionMode.CollectResults
+ ),
new SparkPerfExecution(
"RDD: average",
Map.empty,
- prepare = () => Unit,
+ prepare = () => (),
run = () => {
val sumAndCount =
smallrdd.map(i => (i, 1)).reduce((a, b) => (a._1 + b._1, a._2 + b._2))
sumAndCount._1.toDouble / sumAndCount._2
- })
+ }
+ )
)
-}
\ No newline at end of file
+}
diff --git a/src/main/scala/com/databricks/spark/sql/perf/ExecutionMode.scala b/src/main/scala/com/databricks/spark/sql/perf/ExecutionMode.scala
index e44bd87c..dd7f8e5d 100644
--- a/src/main/scala/com/databricks/spark/sql/perf/ExecutionMode.scala
+++ b/src/main/scala/com/databricks/spark/sql/perf/ExecutionMode.scala
@@ -16,12 +16,12 @@
package com.databricks.spark.sql.perf
-/**
- * Describes how a given Spark benchmark should be run (i.e. should the results be collected to
- * the driver or just computed on the executors.
- */
+/** Describes how a given Spark benchmark should be run (i.e. should the results be collected to the
+ * driver or just computed on the executors).
+ */
trait ExecutionMode extends Serializable
case object ExecutionMode {
+
/** Benchmark run by collecting queries results (e.g. rdd.collect()) */
case object CollectResults extends ExecutionMode {
override def toString: String = "collect"
@@ -37,10 +37,9 @@ case object ExecutionMode {
override def toString: String = "saveToParquet"
}
- /**
- * Benchmark run by calculating the sum of the hash value of all rows. This is used to check
- * query results do not change.
- */
+ /** Benchmark run by calculating the sum of the hash value of all rows. This is used to check
+ * query results do not change.
+ */
case object HashResults extends ExecutionMode {
override def toString: String = "hash"
}
@@ -49,4 +48,4 @@ case object ExecutionMode {
case object SparkPerfResults extends ExecutionMode {
override def toString: String = "sparkPerf"
}
-}
\ No newline at end of file
+}
diff --git a/src/main/scala/com/databricks/spark/sql/perf/JoinPerformance.scala b/src/main/scala/com/databricks/spark/sql/perf/JoinPerformance.scala
index 8c587066..c27a0d3a 100644
--- a/src/main/scala/com/databricks/spark/sql/perf/JoinPerformance.scala
+++ b/src/main/scala/com/databricks/spark/sql/perf/JoinPerformance.scala
@@ -5,30 +5,30 @@ import org.apache.spark.sql.types._
class JoinPerformance extends Benchmark {
-
import ExecutionMode._
- import sqlContext.implicits._
+ import spark.implicits._
- private val table = sqlContext.table _
+ private val table = (s: String) => spark.table(s)
val x = Table(
- "1milints", { // 1.5 mb, 1 file
- val df = sqlContext.range(0, 1000000).repartition(1)
+ "1milints", { // 1.5 mb, 1 file
+ val df = spark.range(0, 1000000).repartition(1)
df.createTempView("1milints")
df
- })
+ }
+ )
val joinTables = Seq(
Table(
- "100milints", { // 143.542mb, 10 files
- val df = sqlContext.range(0, 100000000).repartition(10)
+ "100milints", { // 143.542mb, 10 files
+ val df = spark.range(0, 100000000).repartition(10)
df.createTempView("100milints")
df
- }),
-
+ }
+ ),
Table(
- "1bilints", { // 143.542mb, 10 files
- val df = sqlContext.range(0, 1000000000).repartition(10)
+ "1bilints", { // 143.542mb, 10 files
+ val df = spark.range(0, 1000000000).repartition(10)
df.createTempView("1bilints")
df
}
@@ -36,41 +36,46 @@ class JoinPerformance extends Benchmark {
)
val sortMergeJoin = Variation("sortMergeJoin", Seq("on", "off")) {
- case "off" => sqlContext.setConf("spark.sql.planner.sortMergeJoin", "false")
- case "on" => sqlContext.setConf("spark.sql.planner.sortMergeJoin", "true")
+ case "off" => spark.conf.set("spark.sql.planner.sortMergeJoin", "false")
+ case "on" => spark.conf.set("spark.sql.planner.sortMergeJoin", "true")
}
- val singleKeyJoins: Seq[Benchmarkable] = Seq("1milints", "100milints", "1bilints").flatMap { table1 =>
- Seq("1milints", "100milints", "1bilints").flatMap { table2 =>
- Seq("JOIN", "RIGHT JOIN", "LEFT JOIN", "FULL OUTER JOIN").map { join =>
- Query(
- s"singleKey-$join-$table1-$table2",
- s"SELECT COUNT(*) FROM $table1 a $join $table2 b ON a.id = b.id",
- "equi-inner join a small table with a big table using a single key.",
- executionMode = CollectResults)
+ val singleKeyJoins: Seq[Benchmarkable] = Seq("1milints", "100milints", "1bilints").flatMap {
+ table1 =>
+ Seq("1milints", "100milints", "1bilints").flatMap { table2 =>
+ Seq("JOIN", "RIGHT JOIN", "LEFT JOIN", "FULL OUTER JOIN").map { join =>
+ Query(
+ s"singleKey-$join-$table1-$table2",
+ s"SELECT COUNT(*) FROM $table1 a $join $table2 b ON a.id = b.id",
+ "equi-inner join a small table with a big table using a single key.",
+ executionMode = CollectResults
+ )
+ }
}
- }
}
val varyDataSize = Seq(1, 128, 256, 512, 1024).map { dataSize =>
val intsWithData = table("100milints").select($"id", lit("*" * dataSize).as(s"data$dataSize"))
new Query(
s"join - datasize: $dataSize",
- intsWithData.as("a").join(intsWithData.as("b"), $"a.id" === $"b.id"))
+ intsWithData.as("a").join(intsWithData.as("b"), $"a.id" === $"b.id")
+ )
}
val varyKeyType = Seq(StringType, IntegerType, LongType, DoubleType).map { keyType =>
val convertedInts = table("100milints").select($"id".cast(keyType).as("id"))
new Query(
s"join - keytype: $keyType",
- convertedInts.as("a").join(convertedInts.as("b"), $"a.id" === $"b.id"))
+ convertedInts.as("a").join(convertedInts.as("b"), $"a.id" === $"b.id")
+ )
}
val varyNumMatches = Seq(1, 2, 4, 8, 16).map { numCopies =>
- val ints = table("100milints")
+ val ints = table("100milints")
val copiedInts = Seq.fill(numCopies)(ints).reduce(_ union _)
new Query(
s"join - numMatches: $numCopies",
- copiedInts.as("a").join(ints.as("b"), $"a.id" === $"b.id"))
+ copiedInts.as("a").join(ints.as("b"), $"a.id" === $"b.id")
+ )
}
-}
\ No newline at end of file
+}
diff --git a/src/main/scala/com/databricks/spark/sql/perf/Query.scala b/src/main/scala/com/databricks/spark/sql/perf/Query.scala
index babc63f0..4339fb4e 100644
--- a/src/main/scala/com/databricks/spark/sql/perf/Query.scala
+++ b/src/main/scala/com/databricks/spark/sql/perf/Query.scala
@@ -24,25 +24,25 @@ import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.execution.SparkPlan
-
/** Holds one benchmark query and its metadata. */
class Query(
override val name: String,
buildDataFrame: => DataFrame,
val description: String = "",
val sqlText: Option[String] = None,
- override val executionMode: ExecutionMode = ExecutionMode.ForeachResults)
- extends Benchmarkable with Serializable {
+ override val executionMode: ExecutionMode = ExecutionMode.ForeachResults
+) extends Benchmarkable
+ with Serializable {
private implicit def toOption[A](a: A): Option[A] = Option(a)
- override def toString: String = {
- try {
+ override def toString: String =
+ try
s"""
|== Query: $name ==
|${buildDataFrame.queryExecution.analyzed}
""".stripMargin
- } catch {
+ catch {
case e: Exception =>
s"""
|== Query: $name ==
@@ -51,7 +51,6 @@ class Query(
| $description
""".stripMargin
}
- }
lazy val tablesInvolved = buildDataFrame.queryExecution.logical collect {
case r: UnresolvedRelation => r.tableName
@@ -62,9 +61,11 @@ class Query(
protected override def doBenchmark(
includeBreakdown: Boolean,
description: String = "",
- messages: ArrayBuffer[String]): BenchmarkResult = {
+ messages: ArrayBuffer[String],
+ iteration: Int = 1
+ ): BenchmarkResult =
try {
- val dataFrame = buildDataFrame
+ val dataFrame = buildDataFrame
val queryExecution = dataFrame.queryExecution
// We are not counting the time of ScalaReflection.convertRowToScala.
val parsingTime = measureTimeMs {
@@ -81,23 +82,23 @@ class Query(
}
val breakdownResults = if (includeBreakdown) {
- val depth = queryExecution.executedPlan.collect { case p: SparkPlan => p }.size
+ val depth = queryExecution.executedPlan.collect { case p: SparkPlan => p }.size
val physicalOperators = (0 until depth).map(i => (i, queryExecution.executedPlan.p(i)))
- val indexMap = physicalOperators.map { case (index, op) => (op, index) }.toMap
- val timeMap = new mutable.HashMap[Int, Double]
- val maxFields = 999 // Maximum number of fields that will be converted to strings
+ val indexMap = physicalOperators.map { case (index, op) => (op, index) }.toMap
+ val timeMap = new mutable.HashMap[Int, Double]
+ val maxFields = 999 // Maximum number of fields that will be converted to strings
physicalOperators.reverse.map {
case (index, node) =>
messages += s"Breakdown: ${node.simpleString(maxFields)}"
val newNode = buildDataFrame.queryExecution.executedPlan.p(index)
val executionTime = measureTimeMs {
- newNode.execute().foreach((row: Any) => Unit)
+ newNode.execute().foreach((row: Any) => ())
}
timeMap += ((index, executionTime))
val childIndexes = node.children.map(indexMap)
- val childTime = childIndexes.map(timeMap).sum
+ val childTime = childIndexes.map(timeMap).sum
messages += s"Breakdown time: $executionTime (+${executionTime - childTime})"
BreakdownResult(
@@ -106,7 +107,8 @@ class Query(
index,
childIndexes,
executionTime,
- executionTime - childTime)
+ executionTime - childTime
+ )
}
} else {
Seq.empty[BreakdownResult]
@@ -120,7 +122,7 @@ class Query(
val executionTime = measureTimeMs {
executionMode match {
case ExecutionMode.CollectResults => dataFrame.collect()
- case ExecutionMode.ForeachResults => dataFrame.foreach { _ => ():Unit }
+ case ExecutionMode.ForeachResults => dataFrame.foreach(_ => (): Unit)
case ExecutionMode.WriteParquet(location) =>
dataFrame.write.parquet(s"$location/$name.parquet")
case ExecutionMode.HashResults =>
@@ -149,18 +151,20 @@ class Query(
executionTime = executionTime,
result = result,
queryExecution = dataFrame.queryExecution.toString,
- breakDown = breakdownResults)
+ breakDown = breakdownResults
+ )
} catch {
case e: Exception =>
- BenchmarkResult(
- name = name,
- mode = executionMode.toString,
- failure = Failure(e.getClass.getName, e.getMessage))
+ BenchmarkResult(
+ name = name,
+ mode = executionMode.toString,
+ failure = Failure(e.getClass.getName, e.getMessage)
+ )
}
- }
- /** Change the ExecutionMode of this Query to HashResults, which is used to check the query result. */
- def checkResult: Query = {
+ /** Change the ExecutionMode of this Query to HashResults, which is used to check the query
+ * result.
+ */
+ def checkResult: Query =
new Query(name, buildDataFrame, description, sqlText, ExecutionMode.HashResults)
- }
}
diff --git a/src/main/scala/com/databricks/spark/sql/perf/RunBenchmark.scala b/src/main/scala/com/databricks/spark/sql/perf/RunBenchmark.scala
index ed367e7f..857f4df4 100644
--- a/src/main/scala/com/databricks/spark/sql/perf/RunBenchmark.scala
+++ b/src/main/scala/com/databricks/spark/sql/perf/RunBenchmark.scala
@@ -20,7 +20,7 @@ import java.net.InetAddress
import java.io.File
import org.apache.spark.sql.{SQLContext, SparkSession}
import org.apache.spark.sql.functions._
-import org.apache.spark.{SparkContext, SparkConf}
+import org.apache.spark.{SparkConf, SparkContext}
import scala.util.Try
case class RunConfig(
@@ -28,20 +28,20 @@ case class RunConfig(
benchmarkName: String = null,
filter: Option[String] = None,
iterations: Int = 3,
- baseline: Option[Long] = None)
+ baseline: Option[Long] = None
+)
-/**
- * Runs a benchmark locally and prints the results to the screen.
- */
+/** Runs a benchmark locally and prints the results to the screen.
+ */
object RunBenchmark {
def main(args: Array[String]): Unit = {
val parser = new scopt.OptionParser[RunConfig]("spark-sql-perf") {
head("spark-sql-perf", "0.2.0")
opt[String]('m', "master")
- .action { (x, c) => c.copy(master = x) }
+ .action((x, c) => c.copy(master = x))
.text("the Spark master to use, default to local[*]")
opt[String]('b', "benchmark")
- .action { (x, c) => c.copy(benchmarkName = x) }
+ .action((x, c) => c.copy(benchmarkName = x))
.text("the name of the benchmark to run")
.required()
opt[String]('f', "filter")
@@ -51,8 +51,8 @@ object RunBenchmark {
.action((x, c) => c.copy(iterations = x))
.text("the number of iterations to run")
opt[Long]('c', "compare")
- .action((x, c) => c.copy(baseline = Some(x)))
- .text("the timestamp of the baseline experiment to compare with")
+ .action((x, c) => c.copy(baseline = Some(x)))
+ .text("the timestamp of the baseline experiment to compare with")
help("help")
.text("prints this usage text")
}
@@ -71,21 +71,22 @@ object RunBenchmark {
.setAppName(getClass.getName)
val sparkSession = SparkSession.builder.config(conf).getOrCreate()
- val sc = sparkSession.sparkContext
- val sqlContext = sparkSession.sqlContext
- import sqlContext.implicits._
+ val sc = sparkSession.sparkContext
+ val sqlContext = sparkSession.sqlContext
+ import sparkSession.implicits._
- sqlContext.setConf("spark.sql.perf.results",
- new File("performance").toURI.toString)
+ sparkSession.conf.set("spark.sql.perf.results", new File("performance").toURI.toString)
val benchmark = Try {
- Class.forName(config.benchmarkName)
- .newInstance()
- .asInstanceOf[Benchmark]
+ Class
+ .forName(config.benchmarkName)
+ .newInstance()
+ .asInstanceOf[Benchmark]
} getOrElse {
- Class.forName("com.databricks.spark.sql.perf." + config.benchmarkName)
- .newInstance()
- .asInstanceOf[Benchmark]
+ Class
+ .forName("com.databricks.spark.sql.perf." + config.benchmarkName)
+ .newInstance()
+ .asInstanceOf[Benchmark]
}
val allQueries = config.filter.map { f =>
@@ -100,49 +101,55 @@ object RunBenchmark {
val experiment = benchmark.runExperiment(
executionsToRun = allQueries,
iterations = config.iterations,
- tags = Map(
- "runtype" -> "local",
- "host" -> InetAddress.getLocalHost().getHostName()))
+ tags = Map("runtype" -> "local", "host" -> InetAddress.getLocalHost().getHostName())
+ )
println("== STARTING EXPERIMENT ==")
experiment.waitForFinish(1000 * 60 * 30)
- sqlContext.setConf("spark.sql.shuffle.partitions", "1")
-
- val toShow = experiment.getCurrentRuns()
- .withColumn("result", explode($"results"))
- .select("result.*")
- .groupBy("name")
- .agg(
- min($"executionTime") as 'minTimeMs,
- max($"executionTime") as 'maxTimeMs,
- avg($"executionTime") as 'avgTimeMs,
- stddev($"executionTime") as 'stdDev,
- (stddev($"executionTime") / avg($"executionTime") * 100) as 'stdDevPercent)
- .orderBy("name")
-
+ sparkSession.conf.set("spark.sql.shuffle.partitions", "1")
+
+ val toShow = experiment
+ .getCurrentRuns()
+ .withColumn("result", explode($"results"))
+ .select("result.*")
+ .groupBy("name")
+ .agg(
+ min($"executionTime") as 'minTimeMs,
+ max($"executionTime") as 'maxTimeMs,
+ avg($"executionTime") as 'avgTimeMs,
+ stddev($"executionTime") as 'stdDev,
+ (stddev($"executionTime") / avg($"executionTime") * 100) as 'stdDevPercent
+ )
+ .orderBy("name")
+
println("Showing at most 100 query results now")
toShow.show(100)
-
+
println(s"""Results: sqlContext.read.json("${experiment.resultPath}")""")
config.baseline.foreach { baseTimestamp =>
val baselineTime = when($"timestamp" === baseTimestamp, $"executionTime").otherwise(null)
- val thisRunTime = when($"timestamp" === experiment.timestamp, $"executionTime").otherwise(null)
-
- val data = sqlContext.read.json(benchmark.resultsLocation)
- .coalesce(1)
- .where(s"timestamp IN ($baseTimestamp, ${experiment.timestamp})")
- .withColumn("result", explode($"results"))
- .select("timestamp", "result.*")
- .groupBy("name")
- .agg(
- avg(baselineTime) as 'baselineTimeMs,
- avg(thisRunTime) as 'thisRunTimeMs,
- stddev(baselineTime) as 'stddev)
- .withColumn(
- "percentChange", ($"baselineTimeMs" - $"thisRunTimeMs") / $"baselineTimeMs" * 100)
- .filter('thisRunTimeMs.isNotNull)
+ val thisRunTime =
+ when($"timestamp" === experiment.timestamp, $"executionTime").otherwise(null)
+
+ val data = sparkSession.read
+ .json(benchmark.resultsLocation)
+ .coalesce(1)
+ .where(s"timestamp IN ($baseTimestamp, ${experiment.timestamp})")
+ .withColumn("result", explode($"results"))
+ .select("timestamp", "result.*")
+ .groupBy("name")
+ .agg(
+ avg(baselineTime) as 'baselineTimeMs,
+ avg(thisRunTime) as 'thisRunTimeMs,
+ stddev(baselineTime) as 'stddev
+ )
+ .withColumn(
+ "percentChange",
+ ($"baselineTimeMs" - $"thisRunTimeMs") / $"baselineTimeMs" * 100
+ )
+ .filter('thisRunTimeMs.isNotNull)
data.show(truncate = false)
}
diff --git a/src/main/scala/com/databricks/spark/sql/perf/Tables.scala b/src/main/scala/com/databricks/spark/sql/perf/Tables.scala
index 177d38ce..e76ec372 100644
--- a/src/main/scala/com/databricks/spark/sql/perf/Tables.scala
+++ b/src/main/scala/com/databricks/spark/sql/perf/Tables.scala
@@ -29,21 +29,19 @@ import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Row, SQLContext, SaveMode}
-
-/**
- * Using ProcessBuilder.lineStream produces a stream, that uses
- * a LinkedBlockingQueue with a default capacity of Integer.MAX_VALUE.
- *
- * This causes OOM if the consumer cannot keep up with the producer.
- *
- * See scala.sys.process.ProcessBuilderImpl.lineStream
- */
+/** Using ProcessBuilder.lineStream produces a stream that uses a LinkedBlockingQueue with a
+ * default capacity of Integer.MAX_VALUE.
+ *
+ * This causes OOM if the consumer cannot keep up with the producer.
+ *
+ * See scala.sys.process.ProcessBuilderImpl.lineStream
+ */
object BlockingLineStream {
// See scala.sys.process.Streamed
private final class BlockingStreamed[T](
- val process: T => Unit,
- val done: Int => Unit,
- val stream: () => Stream[T]
+ val process: T => Unit,
+ val done: Int => Unit,
+ val stream: () => Stream[T]
)
// See scala.sys.process.Streamed
@@ -70,7 +68,7 @@ object BlockingLineStream {
private object Spawn {
def apply(f: => Unit): Thread = apply(f, daemon = false)
def apply(f: => Unit, daemon: Boolean): Thread = {
- val thread = new Thread() { override def run() = { f } }
+ val thread = new Thread() { override def run() = f }
thread.setDaemon(daemon)
thread.start()
thread
@@ -79,7 +77,7 @@ object BlockingLineStream {
def apply(command: Seq[String]): Stream[String] = {
val streamed = BlockingStreamed[String](true)
- val process = command.run(BasicIO(false, streamed.process, None))
+ val process = command.run(BasicIO(false, streamed.process, None))
Spawn(streamed.done(process.exitValue()))
streamed.stream()
}
@@ -87,16 +85,19 @@ object BlockingLineStream {
trait DataGenerator extends Serializable {
def generate(
- sparkContext: SparkContext,
- name: String,
- partitions: Int,
- scaleFactor: String): RDD[String]
+ sparkContext: SparkContext,
+ name: String,
+ partitions: Int,
+ scaleFactor: String
+ ): RDD[String]
}
-
-abstract class Tables(sqlContext: SQLContext, scaleFactor: String,
- useDoubleForDecimal: Boolean = false, useStringForDate: Boolean = false)
- extends Serializable {
+abstract class Tables(
+ sqlContext: SQLContext,
+ scaleFactor: String,
+ useDoubleForDecimal: Boolean = false,
+ useStringForDate: Boolean = false
+) extends Serializable {
def dataGenerator: DataGenerator
def tables: Seq[Table]
@@ -104,18 +105,17 @@ abstract class Tables(sqlContext: SQLContext, scaleFactor: String,
private val log = LoggerFactory.getLogger(getClass)
def sparkContext = sqlContext.sparkContext
+ val spark = sqlContext.sparkSession
case class Table(name: String, partitionColumns: Seq[String], fields: StructField*) {
val schema = StructType(fields)
- def nonPartitioned: Table = {
- Table(name, Nil, fields : _*)
- }
+ def nonPartitioned: Table =
+ Table(name, Nil, fields: _*)
- /**
- * If convertToSchema is true, the data from generator will be parsed into columns and
- * converted to `schema`. Otherwise, it just outputs the raw data (as a single STRING column).
- */
+ /** If convertToSchema is true, the data from generator will be parsed into columns and
+ * converted to `schema`. Otherwise, it just outputs the raw data (as a single STRING column).
+ */
def df(convertToSchema: Boolean, numPartition: Int) = {
val generatedData = dataGenerator.generate(sparkContext, name, numPartition, scaleFactor)
val rows = generatedData.mapPartitions { iter =>
@@ -138,9 +138,10 @@ abstract class Tables(sqlContext: SQLContext, scaleFactor: String,
if (convertToSchema) {
val stringData =
- sqlContext.createDataFrame(
+ spark.createDataFrame(
rows,
- StructType(schema.fields.map(f => StructField(f.name, StringType))))
+ StructType(schema.fields.map(f => StructField(f.name, StringType)))
+ )
val convertedData = {
val columns = schema.fields.map { f =>
@@ -151,7 +152,7 @@ abstract class Tables(sqlContext: SQLContext, scaleFactor: String,
convertedData
} else {
- sqlContext.createDataFrame(rows, StructType(Seq(StructField("value", StringType))))
+ spark.createDataFrame(rows, StructType(Seq(StructField("value", StringType))))
}
}
@@ -159,33 +160,36 @@ abstract class Tables(sqlContext: SQLContext, scaleFactor: String,
val newFields = fields.map { field =>
val newDataType = field.dataType match {
case decimal: DecimalType if useDoubleForDecimal => DoubleType
- case date: DateType if useStringForDate => StringType
- case other => other
+ case date: DateType if useStringForDate => StringType
+ case other => other
}
field.copy(dataType = newDataType)
}
- Table(name, partitionColumns, newFields:_*)
+ Table(name, partitionColumns, newFields: _*)
}
def genData(
- location: String,
- format: String,
- overwrite: Boolean,
- clusterByPartitionColumns: Boolean,
- filterOutNullPartitionValues: Boolean,
- numPartitions: Int): Unit = {
+ location: String,
+ format: String,
+ overwrite: Boolean,
+ clusterByPartitionColumns: Boolean,
+ filterOutNullPartitionValues: Boolean,
+ numPartitions: Int
+ ): Unit = {
val mode = if (overwrite) SaveMode.Overwrite else SaveMode.Ignore
- val data = df(format != "text", numPartitions)
+ val data = df(format != "text", numPartitions)
val tempTableName = s"${name}_text"
data.createOrReplaceTempView(tempTableName)
val writer = if (partitionColumns.nonEmpty) {
if (clusterByPartitionColumns) {
- val columnString = data.schema.fields.map { field =>
- field.name
- }.mkString(",")
+ val columnString = data.schema.fields
+ .map { field =>
+ field.name
+ }
+ .mkString(",")
val partitionColumnString = partitionColumns.mkString(",")
val predicates = if (filterOutNullPartitionValues) {
partitionColumns.map(col => s"$col IS NOT NULL").mkString("WHERE ", " AND ", "")
@@ -203,7 +207,7 @@ abstract class Tables(sqlContext: SQLContext, scaleFactor: String,
|DISTRIBUTE BY
| $partitionColumnString
""".stripMargin
- val grouped = sqlContext.sql(query)
+ val grouped = spark.sql(query)
println(s"Pre-clustering with partitioning columns with query $query.")
log.info(s"Pre-clustering with partitioning columns with query $query.")
grouped.write
@@ -216,13 +220,18 @@ abstract class Tables(sqlContext: SQLContext, scaleFactor: String,
// in case data has more than maxRecordsPerFile, split into multiple writers to improve datagen speed
// files will be truncated to maxRecordsPerFile value, so the final result will be the same
val numRows = data.count
- val maxRecordPerFile = util.Try(sqlContext.getConf("spark.sql.files.maxRecordsPerFile").toInt).getOrElse(0)
+ val maxRecordPerFile =
+ util.Try(spark.conf.get("spark.sql.files.maxRecordsPerFile").toInt).getOrElse(0)
- println(s"Data has $numRows rows clustered $clusterByPartitionColumns for $maxRecordPerFile")
- log.info(s"Data has $numRows rows clustered $clusterByPartitionColumns for $maxRecordPerFile")
+ println(
+ s"Data has $numRows rows clustered $clusterByPartitionColumns for $maxRecordPerFile"
+ )
+ log.info(
+ s"Data has $numRows rows clustered $clusterByPartitionColumns for $maxRecordPerFile"
+ )
if (maxRecordPerFile > 0 && numRows > maxRecordPerFile) {
- val numFiles = (numRows.toDouble/maxRecordPerFile).ceil.toInt
+ val numFiles = (numRows.toDouble / maxRecordPerFile).ceil.toInt
println(s"Coalescing into $numFiles files")
log.info(s"Coalescing into $numFiles files")
data.coalesce(numFiles).write
@@ -235,49 +244,96 @@ abstract class Tables(sqlContext: SQLContext, scaleFactor: String,
}
writer.format(format).mode(mode)
if (partitionColumns.nonEmpty) {
- writer.partitionBy(partitionColumns : _*)
+ writer.partitionBy(partitionColumns: _*)
}
println(s"Generating table $name in database to $location with save mode $mode.")
log.info(s"Generating table $name in database to $location with save mode $mode.")
writer.save(location)
- sqlContext.dropTempTable(tempTableName)
+ spark.catalog.dropTempView(tempTableName)
}
- def createExternalTable(location: String, format: String, databaseName: String,
- overwrite: Boolean, discoverPartitions: Boolean = true): Unit = {
-
- val qualifiedTableName = databaseName + "." + name
- val tableExists = sqlContext.tableNames(databaseName).contains(name)
+ def createExternalTable(
+ location: String,
+ format: String,
+ databaseName: String,
+ overwrite: Boolean,
+ discoverPartitions: Boolean = true,
+ isPartitioned: Boolean = false
+ ): Unit = {
+
+ val qualifiedTableName = s"`$databaseName`.`$name`"
+ val tableExists = spark.catalog.tableExists(databaseName, name)
if (overwrite) {
- sqlContext.sql(s"DROP TABLE IF EXISTS $databaseName.$name")
+ spark.sql(s"DROP TABLE IF EXISTS $qualifiedTableName")
}
if (!tableExists || overwrite) {
- println(s"Creating external table $name in database $databaseName using data stored in $location.")
- log.info(s"Creating external table $name in database $databaseName using data stored in $location.")
- sqlContext.createExternalTable(qualifiedTableName, location, format)
+ println(
+ s"Creating external table $name in database $databaseName using data stored in $location."
+ )
+ log.info(
+ s"Creating external table $name in database $databaseName using data stored in $location."
+ )
+
+ val ddlSchema = schema.toDDL
+
+ // Only add PARTITIONED BY when the caller explicitly signals that data is stored
+ // in Hive-style col=value/ directories. For flat files (e.g. JSON, Parquet without
+ // partition directories), keep isPartitioned=false (the default) to avoid 0-row tables.
+ val partitioningClause = if (isPartitioned && partitionColumns.nonEmpty) {
+ s"PARTITIONED BY (${partitionColumns.mkString("`", "`, `", "`")})"
+ } else {
+ ""
+ }
+
+ val ddl =
+ s"""CREATE EXTERNAL TABLE IF NOT EXISTS $qualifiedTableName ($ddlSchema)
+ |USING $format
+ |$partitioningClause
+ |LOCATION '$location'
+ """.stripMargin
+
+ spark.sql(ddl)
}
- if (partitionColumns.nonEmpty && discoverPartitions) {
- println(s"Discovering partitions for table $name.")
- log.info(s"Discovering partitions for table $name.")
- sqlContext.sql(s"ALTER TABLE $databaseName.$name RECOVER PARTITIONS")
+
+ val formatLower = format.toLowerCase
+ val skipRecover = Set("delta", "iceberg")
+ if (
+ isPartitioned && partitionColumns.nonEmpty && discoverPartitions && !skipRecover.contains(
+ formatLower
+ )
+ ) {
+ println(s"Attempting partition discovery for table $name.")
+ log.info(s"Attempting partition discovery for table $name.")
+ try {
+ spark.sql(s"MSCK REPAIR TABLE $qualifiedTableName")
+ println(s"Partition discovery succeeded for table $name.")
+ log.info(s"Partition discovery succeeded for table $name.")
+ } catch {
+ case e: Exception =>
+ println(
+ s"[INFO] Partition discovery skipped for table $name " +
+ s"(data may be in flat files, not Hive-style col=value/ directories)."
+ )
+ log.info(s"Partition discovery skipped for $name: ${e.getMessage}")
+ }
}
}
def createTemporaryTable(location: String, format: String): Unit = {
println(s"Creating temporary table $name using data stored in $location.")
log.info(s"Creating temporary table $name using data stored in $location.")
- sqlContext.read.format(format).load(location).createOrReplaceTempView(name)
+ spark.read.format(format).load(location).createOrReplaceTempView(name)
}
def analyzeTable(databaseName: String, analyzeColumns: Boolean = false): Unit = {
println(s"Analyzing table $name.")
log.info(s"Analyzing table $name.")
- sqlContext.sql(s"ANALYZE TABLE $databaseName.$name COMPUTE STATISTICS")
+ spark.sql(s"ANALYZE TABLE $databaseName.$name COMPUTE STATISTICS")
if (analyzeColumns) {
val allColumns = fields.map(_.name).mkString(", ")
println(s"Analyzing table $name columns $allColumns.")
log.info(s"Analyzing table $name columns $allColumns.")
- sqlContext.sql(s"ANALYZE TABLE $databaseName.$name COMPUTE STATISTICS FOR COLUMNS $allColumns")
+ spark.sql(s"ANALYZE TABLE $databaseName.$name COMPUTE STATISTICS FOR COLUMNS $allColumns")
}
}
}
@@ -290,7 +346,8 @@ abstract class Tables(sqlContext: SQLContext, scaleFactor: String,
clusterByPartitionColumns: Boolean,
filterOutNullPartitionValues: Boolean,
tableFilter: String = "",
- numPartitions: Int = 100): Unit = {
+ numPartitions: Int = 100
+ ): Unit = {
var tablesToBeGenerated = if (partitionTables) {
tables
} else {
@@ -306,13 +363,26 @@ abstract class Tables(sqlContext: SQLContext, scaleFactor: String,
tablesToBeGenerated.foreach { table =>
val tableLocation = s"$location/${table.name}"
- table.genData(tableLocation, format, overwrite, clusterByPartitionColumns,
- filterOutNullPartitionValues, numPartitions)
+ table.genData(
+ tableLocation,
+ format,
+ overwrite,
+ clusterByPartitionColumns,
+ filterOutNullPartitionValues,
+ numPartitions
+ )
}
}
- def createExternalTables(location: String, format: String, databaseName: String,
- overwrite: Boolean, discoverPartitions: Boolean, tableFilter: String = ""): Unit = {
+ def createExternalTables(
+ location: String,
+ format: String,
+ databaseName: String,
+ overwrite: Boolean,
+ discoverPartitions: Boolean,
+ tableFilter: String = "",
+ isPartitioned: Boolean = false
+ ): Unit = {
val filtered = if (tableFilter.isEmpty) {
tables
@@ -320,12 +390,19 @@ abstract class Tables(sqlContext: SQLContext, scaleFactor: String,
tables.filter(_.name == tableFilter)
}
- sqlContext.sql(s"CREATE DATABASE IF NOT EXISTS $databaseName")
+ spark.sql(s"CREATE DATABASE IF NOT EXISTS $databaseName")
filtered.foreach { table =>
val tableLocation = s"$location/${table.name}"
- table.createExternalTable(tableLocation, format, databaseName, overwrite, discoverPartitions)
+ table.createExternalTable(
+ tableLocation,
+ format,
+ databaseName,
+ overwrite,
+ discoverPartitions,
+ isPartitioned
+ )
}
- sqlContext.sql(s"USE $databaseName")
+ spark.sql(s"USE $databaseName")
println(s"The current database has been set to $databaseName.")
log.info(s"The current database has been set to $databaseName.")
}
@@ -342,7 +419,11 @@ abstract class Tables(sqlContext: SQLContext, scaleFactor: String,
}
}
- def analyzeTables(databaseName: String, analyzeColumns: Boolean = false, tableFilter: String = ""): Unit = {
+ def analyzeTables(
+ databaseName: String,
+ analyzeColumns: Boolean = false,
+ tableFilter: String = ""
+ ): Unit = {
val filtered = if (tableFilter.isEmpty) {
tables
} else {
@@ -353,5 +434,4 @@ abstract class Tables(sqlContext: SQLContext, scaleFactor: String,
}
}
-
}
diff --git a/src/main/scala/com/databricks/spark/sql/perf/bigdata/BigData.scala b/src/main/scala/com/databricks/spark/sql/perf/bigdata/BigData.scala
index e69de29b..454276f2 100644
--- a/src/main/scala/com/databricks/spark/sql/perf/bigdata/BigData.scala
+++ b/src/main/scala/com/databricks/spark/sql/perf/bigdata/BigData.scala
@@ -0,0 +1 @@
+package com.databricks.spark.sql.perf.bigdata
diff --git a/src/main/scala/com/databricks/spark/sql/perf/bigdata/Queries.scala b/src/main/scala/com/databricks/spark/sql/perf/bigdata/Queries.scala
index e1b1c69b..43bd70bd 100644
--- a/src/main/scala/com/databricks/spark/sql/perf/bigdata/Queries.scala
+++ b/src/main/scala/com/databricks/spark/sql/perf/bigdata/Queries.scala
@@ -16,7 +16,7 @@
package com.databricks.spark.sql.perf.bigdata
-import com.databricks.spark.sql.perf.{ExecutionMode, Benchmark}
+import com.databricks.spark.sql.perf.{Benchmark, ExecutionMode}
trait Queries extends Benchmark {
@@ -25,8 +25,7 @@ trait Queries extends Benchmark {
val queries1to3 = Seq(
Query(
name = "q1A",
- sqlText =
- """
+ sqlText = """
|SELECT
| pageURL,
| pageRank
@@ -35,12 +34,11 @@ trait Queries extends Benchmark {
| pageRank > 1000
""".stripMargin,
description = "",
- executionMode = ForeachResults),
-
+ executionMode = ForeachResults
+ ),
Query(
name = "q1B",
- sqlText =
- """
+ sqlText = """
|SELECT
| pageURL,
| pageRank
@@ -49,12 +47,11 @@ trait Queries extends Benchmark {
| pageRank > 100
""".stripMargin,
description = "",
- executionMode = ForeachResults),
-
+ executionMode = ForeachResults
+ ),
Query(
name = "q1C",
- sqlText =
- """
+ sqlText = """
|SELECT
| pageURL,
| pageRank
@@ -63,12 +60,11 @@ trait Queries extends Benchmark {
| pageRank > 10
""".stripMargin,
description = "",
- executionMode = ForeachResults),
-
+ executionMode = ForeachResults
+ ),
Query(
name = "q2A",
- sqlText =
- """
+ sqlText = """
|SELECT
| SUBSTR(sourceIP, 1, 8),
| SUM(adRevenue)
@@ -77,12 +73,11 @@ trait Queries extends Benchmark {
| SUBSTR(sourceIP, 1, 8)
""".stripMargin,
description = "",
- executionMode = ForeachResults),
-
+ executionMode = ForeachResults
+ ),
Query(
name = "q2B",
- sqlText =
- """
+ sqlText = """
|SELECT
| SUBSTR(sourceIP, 1, 10),
| SUM(adRevenue)
@@ -91,12 +86,11 @@ trait Queries extends Benchmark {
| SUBSTR(sourceIP, 1, 10)
""".stripMargin,
description = "",
- executionMode = ForeachResults),
-
+ executionMode = ForeachResults
+ ),
Query(
name = "q2C",
- sqlText =
- """
+ sqlText = """
|SELECT
| SUBSTR(sourceIP, 1, 12),
| SUM(adRevenue)
@@ -105,12 +99,11 @@ trait Queries extends Benchmark {
| SUBSTR(sourceIP, 1, 12)
""".stripMargin,
description = "",
- executionMode = ForeachResults),
-
+ executionMode = ForeachResults
+ ),
Query(
name = "q3A",
- sqlText =
- """
+ sqlText = """
|SELECT sourceIP, totalRevenue, avgPageRank
|FROM
| (SELECT sourceIP,
@@ -124,12 +117,11 @@ trait Queries extends Benchmark {
|ORDER BY totalRevenue DESC LIMIT 1
""".stripMargin,
description = "",
- executionMode = ForeachResults),
-
+ executionMode = ForeachResults
+ ),
Query(
name = "q3B",
- sqlText =
- """
+ sqlText = """
|SELECT sourceIP, totalRevenue, avgPageRank
|FROM
| (SELECT sourceIP,
@@ -143,8 +135,8 @@ trait Queries extends Benchmark {
|ORDER BY totalRevenue DESC LIMIT 1
""".stripMargin,
description = "",
- executionMode = ForeachResults),
-
+ executionMode = ForeachResults
+ ),
Query(
name = "q3C",
sqlText = """
@@ -161,6 +153,7 @@ trait Queries extends Benchmark {
|ORDER BY totalRevenue DESC LIMIT 1
""".stripMargin,
description = "",
- executionMode = ForeachResults)
+ executionMode = ForeachResults
+ )
)
}
diff --git a/src/main/scala/com/databricks/spark/sql/perf/bigdata/Tables.scala b/src/main/scala/com/databricks/spark/sql/perf/bigdata/Tables.scala
index e69de29b..454276f2 100644
--- a/src/main/scala/com/databricks/spark/sql/perf/bigdata/Tables.scala
+++ b/src/main/scala/com/databricks/spark/sql/perf/bigdata/Tables.scala
@@ -0,0 +1 @@
+package com.databricks.spark.sql.perf.bigdata
diff --git a/src/main/scala/com/databricks/spark/sql/perf/handleResults.scala b/src/main/scala/com/databricks/spark/sql/perf/handleResults.scala
index a1c07de7..a5311d01 100644
--- a/src/main/scala/com/databricks/spark/sql/perf/handleResults.scala
+++ b/src/main/scala/com/databricks/spark/sql/perf/handleResults.scala
@@ -19,7 +19,7 @@ package com.databricks.spark.sql.perf
import org.apache.spark.sql.SQLContext
case class Results(resultsLocation: String, @transient sqlContext: SQLContext) {
+ val spark = sqlContext.sparkSession
def allResults =
- sqlContext.read.json(
- sqlContext.sparkContext.textFile(s"$resultsLocation/*/"))
+ spark.read.json(spark.sparkContext.textFile(s"$resultsLocation/*/"))
}
diff --git a/src/main/scala/com/databricks/spark/sql/perf/mllib/BenchmarkAlgorithm.scala b/src/main/scala/com/databricks/spark/sql/perf/mllib/BenchmarkAlgorithm.scala
index 9e00a45e..dc7adf71 100644
--- a/src/main/scala/com/databricks/spark/sql/perf/mllib/BenchmarkAlgorithm.scala
+++ b/src/main/scala/com/databricks/spark/sql/perf/mllib/BenchmarkAlgorithm.scala
@@ -8,67 +8,58 @@ import org.apache.spark.sql.functions._
import com.databricks.spark.sql.perf._
-/**
- * The description of a benchmark for an ML algorithm. It follows a simple, standard proceduce:
- * - generate some test and training data
- * - generate a model against the training data
- * - score the model against the training data
- * - score the model against the test data
- *
- * You should not assume that your implementation can carry state around. If some state is needed,
- * consider adding it to the context.
- *
- * It is assumed that the implementation is going to be an object.
- */
+/** The description of a benchmark for an ML algorithm. It follows a simple, standard procedure:
+ * - generate some test and training data
+ * - generate a model against the training data
+ * - score the model against the training data
+ * - score the model against the test data
+ *
+ * You should not assume that your implementation can carry state around. If some state is needed,
+ * consider adding it to the context.
+ *
+ * It is assumed that the implementation is going to be an object.
+ */
trait BenchmarkAlgorithm {
def trainingDataSet(ctx: MLBenchContext): DataFrame
def testDataSet(ctx: MLBenchContext): DataFrame
- /**
- * Create an [[Estimator]] or [[Transformer]] with params set from the given [[MLBenchContext]].
- */
+ /** Create an [[Estimator]] or [[Transformer]] with params set from the given [[MLBenchContext]].
+ */
def getPipelineStage(ctx: MLBenchContext): PipelineStage
- /**
- * The unnormalized score of the training procedure on a dataset. The normalization is
- * performed by the caller.
- * This calls `count()` on the transformed data to attempt to materialize the result for
- * recording timing metrics.
- */
+ /** The unnormalized score of the training procedure on a dataset. The normalization is performed
+ * by the caller. This calls `count()` on the transformed data to attempt to materialize the
+ * result for recording timing metrics.
+ */
@throws[Exception]("if scoring fails")
- def score(
- ctx: MLBenchContext,
- testSet: DataFrame,
- model: Transformer): MLMetric = {
+ def score(ctx: MLBenchContext, testSet: DataFrame, model: Transformer): MLMetric = {
val output = model.transform(testSet)
// We create a useless UDF to make sure the entire DataFrame is instantiated.
- val fakeUDF = udf { (_: Any) => 0 }
+ val fakeUDF = udf((_: Any) => 0)
val columns = testSet.columns
- output.select(sum(fakeUDF(struct(columns.map(col) : _*)))).first()
+ output.select(sum(fakeUDF(struct(columns.map(col): _*)))).first()
MLMetric.Invalid
}
- def name: String = {
+ def name: String =
this.getClass.getCanonicalName.replace("$", "")
- }
- /**
- * Test additional methods for some algorithms.
- *
- * @param transformer The transformer which includes additional methods.
- * @return A map which key is the additional method name, and value is a function which runs
- * the corresponding method.
- */
- def testAdditionalMethods(
- ctx: MLBenchContext,
- transformer: Transformer): Map[String, () => _] = Map.empty[String, () => _]
+ /** Test additional methods for some algorithms.
+ *
+ * @param transformer
+ * The transformer which includes additional methods.
+ * @return
+ * A map whose key is the additional method name and whose value is a function which runs the
+ * corresponding method.
+ */
+ def testAdditionalMethods(ctx: MLBenchContext, transformer: Transformer): Map[String, () => _] =
+ Map.empty[String, () => _]
}
-/**
- * Uses an evaluator to perform the scoring.
- */
+/** Uses an evaluator to perform the scoring.
+ */
trait ScoringWithEvaluator {
self: BenchmarkAlgorithm =>
@@ -77,9 +68,10 @@ trait ScoringWithEvaluator {
final override def score(
ctx: MLBenchContext,
testSet: DataFrame,
- model: Transformer): MLMetric = {
+ model: Transformer
+ ): MLMetric = {
val results = model.transform(testSet)
- val eval = evaluator(ctx)
+ val eval = evaluator(ctx)
val metricName = if (eval.hasParam("metricName")) {
val param = eval.getParam("metricName")
eval.getOrDefault(param).toString
@@ -91,10 +83,9 @@ trait ScoringWithEvaluator {
}
}
-/**
- * Builds the training set for an initial dataset and an initial model. Useful for validating a
- * trained model against a given model.
- */
+/** Builds the training set for an initial dataset and an initial model. Useful for validating a
+ * trained model against a given model.
+ */
trait TrainingSetFromTransformer {
self: BenchmarkAlgorithm =>
@@ -104,8 +95,8 @@ trait TrainingSetFromTransformer {
final override def trainingDataSet(ctx: MLBenchContext): DataFrame = {
val initial = initialData(ctx)
- val model = trueModel(ctx)
- val fCol = col("features")
+ val model = trueModel(ctx)
+ val fCol = col("features")
// Special case for the trees: we need to set the number of labels.
// numClasses is set? We will add the number of classes to the final column.
val lCol = ctx.params.numClasses match {
@@ -124,9 +115,8 @@ trait TrainingSetFromTransformer {
}
}
-/**
- * The test data is the same as the training data.
- */
+/** The test data is the same as the training data.
+ */
trait TestFromTraining {
self: BenchmarkAlgorithm =>
@@ -145,4 +135,3 @@ trait TestFromTraining {
self.trainingDataSet(ctx2)
}
}
-
diff --git a/src/main/scala/com/databricks/spark/sql/perf/mllib/MLBenchContext.scala b/src/main/scala/com/databricks/spark/sql/perf/mllib/MLBenchContext.scala
index b8971fe4..22a405bd 100644
--- a/src/main/scala/com/databricks/spark/sql/perf/mllib/MLBenchContext.scala
+++ b/src/main/scala/com/databricks/spark/sql/perf/mllib/MLBenchContext.scala
@@ -2,38 +2,33 @@ package com.databricks.spark.sql.perf.mllib
import java.util.Random
-import com.databricks.spark.sql.perf.{MLParams}
-import org.apache.spark.sql.SQLContext
+import com.databricks.spark.sql.perf.MLParams
+import org.apache.spark.sql.{SQLContext, SparkSession}
+/** All the information required to run a test.
+ *
+ * @param params
+ * @param sqlContext
+ */
+case class MLBenchContext(params: MLParams, sqlContext: SQLContext) {
-/**
- * All the information required to run a test.
- *
- * @param params
- * @param sqlContext
- */
-case class MLBenchContext(
- params: MLParams,
- sqlContext: SQLContext) {
+ val spark = sqlContext.sparkSession
// Some seed fixed for the context.
- private val internalSeed: Long = {
+ private val internalSeed: Long =
params.randomSeed.map(_.toLong).getOrElse {
throw new Exception("You need te specify the random seed")
}
- }
- /**
- * A fixed seed for this class. This function will always return the same value.
- *
- * @return
- */
+ /** A fixed seed for this class. This function will always return the same value.
+ *
+ * @return
+ */
def seed(): Long = internalSeed
- /**
- * Creates a new generator. The generator will always start with the same state.
- *
- * @return
- */
+ /** Creates a new generator. The generator will always start with the same state.
+ *
+ * @return
+ */
def newGenerator(): Random = new Random(seed())
}
diff --git a/src/main/scala/com/databricks/spark/sql/perf/mllib/MLBenchmarks.scala b/src/main/scala/com/databricks/spark/sql/perf/mllib/MLBenchmarks.scala
index 13b5d143..87395675 100644
--- a/src/main/scala/com/databricks/spark/sql/perf/mllib/MLBenchmarks.scala
+++ b/src/main/scala/com/databricks/spark/sql/perf/mllib/MLBenchmarks.scala
@@ -2,34 +2,33 @@ package com.databricks.spark.sql.perf.mllib
import com.databricks.spark.sql.perf.mllib.classification.LogisticRegression
import org.apache.spark.SparkContext
-import org.apache.spark.sql.{SQLContext,SparkSession}
+import org.apache.spark.sql.{SQLContext, SparkSession}
-import com.databricks.spark.sql.perf.{MLParams}
+import com.databricks.spark.sql.perf.MLParams
import OptionImplicits._
-case class MLTest(
- benchmark: BenchmarkAlgorithm,
- params: MLParams)
+case class MLTest(benchmark: BenchmarkAlgorithm, params: MLParams)
// Example on how to create benchmarks using the API.
object MLBenchmarks {
// The list of standard benchmarks that we are going to run for ML.
val benchmarks: Seq[MLTest] = List(
- MLTest(
- LogisticRegression,
- new MLParams(
- numFeatures = 10,
- numExamples = 10,
- numTestExamples = 10,
- numPartitions = 3,
- regParam = 1,
- tol = 0.2)
+ MLTest(
+ LogisticRegression,
+ new MLParams(
+ numFeatures = 10,
+ numExamples = 10,
+ numTestExamples = 10,
+ numPartitions = 3,
+ regParam = 1,
+ tol = 0.2
)
+ )
)
- val sparkSession = SparkSession.builder.getOrCreate()
+ val sparkSession = SparkSession.builder.getOrCreate()
val sqlContext: SQLContext = sparkSession.sqlContext
- val context = sqlContext.sparkContext
+ val context = sqlContext.sparkContext
def benchmarkObjects: Seq[MLPipelineStageBenchmarkable] = benchmarks.map { mlb =>
new MLPipelineStageBenchmarkable(mlb.params, mlb.benchmark, sqlContext)
diff --git a/src/main/scala/com/databricks/spark/sql/perf/mllib/MLLib.scala b/src/main/scala/com/databricks/spark/sql/perf/mllib/MLLib.scala
index c0bf70e0..e139bc32 100644
--- a/src/main/scala/com/databricks/spark/sql/perf/mllib/MLLib.scala
+++ b/src/main/scala/com/databricks/spark/sql/perf/mllib/MLLib.scala
@@ -1,39 +1,35 @@
package com.databricks.spark.sql.perf.mllib
-
import scala.io.Source
import scala.language.implicitConversions
import org.slf4j.LoggerFactory
-import org.apache.spark.sql.{DataFrame, SQLContext}
+import org.apache.spark.sql.{DataFrame, SQLContext, SparkSession}
import org.apache.spark.{SparkConf, SparkContext}
import com.databricks.spark.sql.perf._
+class MLLib(sqlContext: SQLContext) extends Benchmark(sqlContext) with Serializable {
-class MLLib(sqlContext: SQLContext)
- extends Benchmark(sqlContext) with Serializable {
-
- def this() = this(SQLContext.getOrCreate(SparkContext.getOrCreate()))
+ def this() = this(SparkSession.builder.getOrCreate().sqlContext)
}
object MLLib {
- /**
- * Runs a set of preprogrammed experiments and blocks on completion.
- *
- * @param runConfig a configuration that is av
- * @return
- */
+ /** Runs a set of preprogrammed experiments and blocks on completion.
+ *
+ * @param runConfig
+ * a configuration that is av
+ * @return
+ */
lazy val logger = LoggerFactory.getLogger(this.getClass.getName)
def runDefault(runConfig: RunConfig): DataFrame = {
- val ml = new MLLib()
+ val ml = new MLLib()
val benchmarks = MLBenchmarks.benchmarkObjects
- val e = ml.runExperiment(
- executionsToRun = benchmarks)
+ val e = ml.runExperiment(executionsToRun = benchmarks)
e.waitForFinish(1000 * 60 * 30)
logger.info("Run finished")
e.getCurrentResults()
@@ -47,44 +43,43 @@ object MLLib {
val smallConfig: String = getConfig("config/mllib-small.yaml")
val largeConfig: String = getConfig("config/mllib-large.yaml")
- /**
- * Entry point for running ML tests. Expects a single command-line argument: the path to
- * a YAML config file specifying which ML tests to run and their parameters.
- * @param args command line args
- */
+ /** Entry point for running ML tests. Expects a single command-line argument: the path to a YAML
+ * config file specifying which ML tests to run and their parameters.
+ * @param args
+ * command line args
+ */
def main(args: Array[String]): Unit = {
val configFile = args(0)
run(yamlFile = configFile)
}
- private[mllib] def getConf(yamlFile: String = null, yamlConfig: String = null): YamlConfig = {
+ private[mllib] def getConf(yamlFile: String = null, yamlConfig: String = null): YamlConfig =
Option(yamlFile).map(YamlConfig.readFile).getOrElse {
require(yamlConfig != null)
YamlConfig.readString(yamlConfig)
}
- }
private[mllib] def getBenchmarks(conf: YamlConfig): Seq[MLPipelineStageBenchmarkable] = {
- val sqlContext = com.databricks.spark.sql.perf.mllib.MLBenchmarks.sqlContext
+ val sqlContext = com.databricks.spark.sql.perf.mllib.MLBenchmarks.sqlContext
val benchmarksDescriptions = conf.runnableBenchmarks
benchmarksDescriptions.map { mlb =>
new MLPipelineStageBenchmarkable(mlb.params, mlb.benchmark, sqlContext)
}
}
- /**
- * Runs all the experiments and blocks on completion
- *
- * @param yamlFile a file name
- * @return
- */
+ /** Runs all the experiments and blocks on completion
+ *
+ * @param yamlFile
+ * a file name
+ * @return
+ */
def run(yamlFile: String = null, yamlConfig: String = null): DataFrame = {
logger.info("Starting run")
- val conf = getConf(yamlFile, yamlConfig)
+ val conf = getConf(yamlFile, yamlConfig)
val sparkConf = new SparkConf().setAppName("MLlib QA").setMaster("local[2]")
- val sc = SparkContext.getOrCreate(sparkConf)
+ val sc = SparkContext.getOrCreate(sparkConf)
sc.setLogLevel("INFO")
- val b = new com.databricks.spark.sql.perf.mllib.MLLib()
+ val b = new com.databricks.spark.sql.perf.mllib.MLLib()
val benchmarks = getBenchmarks(conf)
println(s"${benchmarks.size} benchmarks identified:")
val str = benchmarks.map(_.prettyPrint).mkString("\n")
@@ -94,7 +89,8 @@ object MLLib {
executionsToRun = benchmarks,
iterations = 1, // If you want to increase the number of iterations, add more seeds
resultLocation = conf.output,
- forkThread = false)
+ forkThread = false
+ )
e.waitForFinish(conf.timeout.toSeconds.toInt)
logger.info("Run finished")
e.getCurrentResults()
diff --git a/src/main/scala/com/databricks/spark/sql/perf/mllib/MLPipelineStageBenchmarkable.scala b/src/main/scala/com/databricks/spark/sql/perf/mllib/MLPipelineStageBenchmarkable.scala
index 8296f46b..7bb7965e 100644
--- a/src/main/scala/com/databricks/spark/sql/perf/mllib/MLPipelineStageBenchmarkable.scala
+++ b/src/main/scala/com/databricks/spark/sql/perf/mllib/MLPipelineStageBenchmarkable.scala
@@ -12,15 +12,16 @@ import com.databricks.spark.sql.perf._
class MLPipelineStageBenchmarkable(
params: MLParams,
test: BenchmarkAlgorithm,
- sqlContext: SQLContext)
- extends Benchmarkable with Serializable {
+ sqlContext: SQLContext
+) extends Benchmarkable
+ with Serializable {
import MLPipelineStageBenchmarkable._
- private var testData: DataFrame = null
- private var trainingData: DataFrame = null
+ private var testData: DataFrame = null
+ private var trainingData: DataFrame = null
private var testDataCount: Option[Long] = None
- private val param = MLBenchContext(params, sqlContext)
+ private val param = MLBenchContext(params, sqlContext)
override val name = test.name
@@ -43,9 +44,11 @@ class MLPipelineStageBenchmarkable(
}
override protected def doBenchmark(
- includeBreakdown: Boolean,
- description: String,
- messages: ArrayBuffer[String]): BenchmarkResult = {
+ includeBreakdown: Boolean,
+ description: String,
+ messages: ArrayBuffer[String],
+ iteration: Int = 1
+ ): BenchmarkResult =
try {
val (trainingTime, model: Transformer) = measureTime {
logger.info(s"$this: train: trainingSet=${trainingData.schema}")
@@ -54,8 +57,11 @@ class MLPipelineStageBenchmarkable(
case transformer: Transformer =>
transformer.transform(trainingData)
transformer
- case other: Any => throw new UnsupportedOperationException("Algorithm to benchmark must" +
- s" be an estimator or transformer, found ${other.getClass} instead.")
+ case other: Any =>
+ throw new UnsupportedOperationException(
+ "Algorithm to benchmark must" +
+ s" be an estimator or transformer, found ${other.getClass} instead."
+ )
}
}
logger.info(s"model: $model")
@@ -63,30 +69,38 @@ class MLPipelineStageBenchmarkable(
test.score(param, trainingData, model)
}
val metricTrainingTime = MLMetric("training.time", trainingTime.toMillis, false)
- val metricTraining = MLMetric("training."+scoreTraining.metricName,
+ val metricTraining = MLMetric(
+ "training." + scoreTraining.metricName,
scoreTraining.metricValue,
- scoreTraining.isLargerBetter)
+ scoreTraining.isLargerBetter
+ )
val (scoreTestTime, scoreTest) = measureTime {
test.score(param, testData, model)
}
val metricTestTime = MLMetric("test.time", scoreTestTime.toMillis, false)
- val metricTest = MLMetric("test."+scoreTraining.metricName,
+ val metricTest = MLMetric(
+ "test." + scoreTest.metricName,
- scoreTraining.metricValue,
+ scoreTest.metricValue, // held-out score, not the training one
- scoreTraining.isLargerBetter)
-
- logger.info(s"$this doBenchmark: Trained model in ${trainingTime.toMillis / 1000.0}" +
- s" s, Scored training dataset in ${scoreTrainTime.toMillis / 1000.0} s," +
- s" test dataset in ${scoreTestTime.toMillis / 1000.0} s")
-
- val additionalTests = test.testAdditionalMethods(param, model).map {
- tuple =>
- val (additionalMethodTime, _) = measureTime { tuple._2() }
+ scoreTest.isLargerBetter
+ )
+
+ logger.info(
+ s"$this doBenchmark: Trained model in ${trainingTime.toMillis / 1000.0}" +
+ s" s, Scored training dataset in ${scoreTrainTime.toMillis / 1000.0} s," +
+ s" test dataset in ${scoreTestTime.toMillis / 1000.0} s"
+ )
+
+ val additionalTests = test
+ .testAdditionalMethods(param, model)
+ .map { tuple =>
+ val (additionalMethodTime, _) = measureTime(tuple._2())
MLMetric(tuple._1, additionalMethodTime.toMillis, false)
- }.toArray
+ }
+ .toArray
val mlMetrics = Array(metricTrainingTime, metricTraining, metricTestTime, metricTest) ++
additionalTests
- val paramsMap = params.toMap
+ val paramsMap = params.toMap
val benchmarkId = name.split('.').last + "_" + paramsMap.hashCode.abs
BenchmarkResult(
@@ -95,27 +109,27 @@ class MLPipelineStageBenchmarkable(
parameters = paramsMap,
executionTime = Some(trainingTime.toMillis),
mlResult = Some(mlMetrics),
- benchmarkId = Some(benchmarkId))
+ benchmarkId = Some(benchmarkId)
+ )
} catch {
case e: Exception =>
BenchmarkResult(
name = name,
mode = executionMode.toString,
parameters = params.toMap,
- failure = Some(Failure(e.getClass.getSimpleName,
- e.getMessage + ":\n" + e.getStackTraceString)))
+ failure =
+ Some(Failure(e.getClass.getSimpleName, e.getMessage + ":\n" + e.getStackTraceString))
+ )
} finally {
Option(testData).map(_.unpersist())
Option(trainingData).map(_.unpersist())
}
- }
def prettyPrint: String = {
val paramString = pprint(params).mkString("\n")
s"$test\n$paramString"
}
-
}
object MLPipelineStageBenchmarkable {
@@ -123,15 +137,14 @@ object MLPipelineStageBenchmarkable {
val m = getCCParams(p)
m.flatMap {
case (key, Some(value: Any)) => Some(s" $key=$value")
- case _ => None
- } .toSeq
+ case _ => None
+ }.toSeq
}
// From http://stackoverflow.com/questions/1226555/case-class-to-map-in-scala
private def getCCParams(cc: AnyRef) =
- (Map[String, Any]() /: cc.getClass.getDeclaredFields) {(a, f) =>
+ (Map[String, Any]() /: cc.getClass.getDeclaredFields) { (a, f) =>
f.setAccessible(true)
a + (f.getName -> f.get(cc))
}
}
-
diff --git a/src/main/scala/com/databricks/spark/sql/perf/mllib/OptionImplicits.scala b/src/main/scala/com/databricks/spark/sql/perf/mllib/OptionImplicits.scala
index ef905258..15169c5b 100644
--- a/src/main/scala/com/databricks/spark/sql/perf/mllib/OptionImplicits.scala
+++ b/src/main/scala/com/databricks/spark/sql/perf/mllib/OptionImplicits.scala
@@ -2,32 +2,29 @@ package com.databricks.spark.sql.perf.mllib
import scala.language.implicitConversions
-/**
- * Implicits to transparently convert some Option[X] to X and vice-versa.
- *
- * This is usually dangerous to do, but in our case, the config is expressed through Options and
- * it alleviates the need to manually box values.
- */
+/** Implicits to transparently convert some Option[X] to X and vice-versa.
+ *
+ * This is usually dangerous to do, but in our case, the config is expressed through Options and it
+ * alleviates the need to manually box values.
+ */
object OptionImplicits {
// The following implicits are unrolled for safety:
private def oX2X[A](x: Option[A]): A = x.get
- def checkLong(x: Option[Long]): Option[Long] = {
+ def checkLong(x: Option[Long]): Option[Long] =
x.asInstanceOf[Option[Any]] match {
case Some(u: java.lang.Integer) => Some(u.toLong)
- case Some(u: java.lang.Long) => Some(u.toLong)
- case _ => x
+ case Some(u: java.lang.Long) => Some(u.toLong)
+ case _ => x
}
- }
- def checkDouble(x: Option[Double]): Option[Double] = {
+ def checkDouble(x: Option[Double]): Option[Double] =
x.asInstanceOf[Option[Any]] match {
case Some(u: java.lang.Integer) => Some(u.toDouble)
- case Some(u: java.lang.Long) => Some(u.toDouble)
- case Some(u: java.lang.Double) => Some(u.toDouble)
- case _ => x
+ case Some(u: java.lang.Long) => Some(u.toDouble)
+ case Some(u: java.lang.Double) => Some(u.toDouble)
+ case _ => x
}
- }
implicit def oD2D(x: Option[Double]): Double = oX2X(x)
@@ -37,9 +34,9 @@ object OptionImplicits {
implicit def oL2L(x: Option[Long]): Long = oX2X(x)
- implicit def l2lo(x: Long): Option[Long] = checkLong(Option(x))
- implicit def i2lo(x: Int): Option[Long] = Option(x.toLong)
- implicit def i2io(x: Int): Option[Int] = Option(x)
+ implicit def l2lo(x: Long): Option[Long] = checkLong(Option(x))
+ implicit def i2lo(x: Int): Option[Long] = Option(x.toLong)
+ implicit def i2io(x: Int): Option[Int] = Option(x)
implicit def d2do(x: Double): Option[Double] = Option(x)
- implicit def i2do(x: Int): Option[Double] = Option(x)
-}
\ No newline at end of file
+ implicit def i2do(x: Int): Option[Double] = Option(x)
+}
diff --git a/src/main/scala/com/databricks/spark/sql/perf/mllib/ReflectionUtils.scala b/src/main/scala/com/databricks/spark/sql/perf/mllib/ReflectionUtils.scala
index 75a29496..7e688a84 100644
--- a/src/main/scala/com/databricks/spark/sql/perf/mllib/ReflectionUtils.scala
+++ b/src/main/scala/com/databricks/spark/sql/perf/mllib/ReflectionUtils.scala
@@ -6,16 +6,13 @@ import scala.reflect.runtime.universe._
/** Exposes methods to simplify implementation of classes like MLParams. */
private[perf] object ReflectionUtils {
- private def getConstructor[T: TypeTag: ClassTag](obj: T): MethodSymbol = {
+ private def getConstructor[T: TypeTag: ClassTag](obj: T): MethodSymbol =
typeOf[T].declaration(nme.CONSTRUCTOR).asMethod
- }
- /**
- * Given an instance [[obj]] of a class whose constructor arguments are all of type Option[Any],
- * returns a map of key-value pairs (argName -> argValue) where argName is the name
- * of a constructor argument with a defined (not None) value and argValue is the corresponding
- * value.
- */
+ /** Given an instance [[obj]] of a class whose constructor arguments are all of type Option[Any],
+ * returns a map of key-value pairs (argName -> argValue) where argName is the name of a
+ * constructor argument with a defined (not None) value and argValue is the corresponding value.
+ */
def getConstructorArgs[T: TypeTag: ClassTag](obj: T): Map[String, Any] = {
// Get constructor of passed-in instance
val constructor = getConstructor(obj)
@@ -23,15 +20,18 @@ private[perf] object ReflectionUtils {
constructor.paramss.flatten.flatMap { (param: Symbol) =>
// Get name and value of the constructor argument
val paramName = param.name.toString
- val getter = obj.getClass.getDeclaredField(paramName)
+ val getter = obj.getClass.getDeclaredField(paramName)
getter.setAccessible(true)
val paramValue = getter.get(obj)
// If the constructor argument is defined, include it in our output map
paramValue match {
case value: Option[Any] => if (value.isDefined) Seq(paramName -> paramValue) else Seq.empty
- case _ => throw new UnsupportedOperationException("ReflectionUtils.getConstructorArgs " +
- "can only be called on instances of classes whose constructor arguments are all of " +
- s"type Option[Any]; constructor argument ${paramName} had invalid type.")
+ case _ =>
+ throw new UnsupportedOperationException(
+ "ReflectionUtils.getConstructorArgs " +
+ "can only be called on instances of classes whose constructor arguments are all of " +
+ s"type Option[Any]; constructor argument $paramName had invalid type."
+ )
}
}.toMap
}
diff --git a/src/main/scala/com/databricks/spark/sql/perf/mllib/TreeOrForestEstimator.scala b/src/main/scala/com/databricks/spark/sql/perf/mllib/TreeOrForestEstimator.scala
index 0bf8b536..e43728d6 100644
--- a/src/main/scala/com/databricks/spark/sql/perf/mllib/TreeOrForestEstimator.scala
+++ b/src/main/scala/com/databricks/spark/sql/perf/mllib/TreeOrForestEstimator.scala
@@ -1,8 +1,11 @@
package com.databricks.spark.sql.perf.mllib
import org.apache.spark.ml.{ModelBuilderSSP, Transformer, TreeUtils}
-import org.apache.spark.ml.evaluation.{Evaluator, MulticlassClassificationEvaluator,
- RegressionEvaluator}
+import org.apache.spark.ml.evaluation.{
+ Evaluator,
+ MulticlassClassificationEvaluator,
+ RegressionEvaluator
+}
import org.apache.spark.sql.DataFrame
import com.databricks.spark.sql.perf.mllib.OptionImplicits._
@@ -10,14 +13,21 @@ import com.databricks.spark.sql.perf.mllib.data.DataGenerator
/** Base trait for BenchmarkAlgorithm objects testing a tree or forest estimator */
private[mllib] trait TreeOrForestEstimator
- extends TestFromTraining with TrainingSetFromTransformer with ScoringWithEvaluator {
+ extends TestFromTraining
+ with TrainingSetFromTransformer
+ with ScoringWithEvaluator {
self: BenchmarkAlgorithm =>
override protected def initialData(ctx: MLBenchContext) = {
import ctx.params._
val featureArity: Array[Int] = TreeOrForestEstimator.getFeatureArity(ctx)
- val data: DataFrame = DataGenerator.generateMixedFeatures(ctx.sqlContext, numExamples,
- ctx.seed(), numPartitions, featureArity)
+ val data: DataFrame = DataGenerator.generateMixedFeatures(
+ ctx.sqlContext,
+ numExamples,
+ ctx.seed(),
+ numPartitions,
+ featureArity
+ )
TreeUtils.setMetadata(data, "features", featureArity)
}
}
@@ -26,49 +36,50 @@ private[mllib] trait TreeOrForestEstimator
private[mllib] trait TreeOrForestClassifier extends TreeOrForestEstimator {
self: BenchmarkAlgorithm =>
- override protected def evaluator(ctx: MLBenchContext): Evaluator = {
+ override protected def evaluator(ctx: MLBenchContext): Evaluator =
new MulticlassClassificationEvaluator()
- }
- override protected def trueModel(ctx: MLBenchContext): Transformer = {
- ModelBuilderSSP.newDecisionTreeClassificationModel(ctx.params.depth, ctx.params.numClasses,
- TreeOrForestEstimator.getFeatureArity(ctx), ctx.seed())
- }
+ override protected def trueModel(ctx: MLBenchContext): Transformer =
+ ModelBuilderSSP.newDecisionTreeClassificationModel(
+ ctx.params.depth,
+ ctx.params.numClasses,
+ TreeOrForestEstimator.getFeatureArity(ctx),
+ ctx.seed()
+ )
}
/** Base trait for BenchmarkAlgorithm objects testing a tree or forest regressor */
private[mllib] trait TreeOrForestRegressor extends TreeOrForestEstimator {
self: BenchmarkAlgorithm =>
- override protected def evaluator(ctx: MLBenchContext): Evaluator = {
+ override protected def evaluator(ctx: MLBenchContext): Evaluator =
new RegressionEvaluator()
- }
- override protected def trueModel(ctx: MLBenchContext): Transformer = {
- ModelBuilderSSP.newDecisionTreeRegressionModel(ctx.params.depth,
- TreeOrForestEstimator.getFeatureArity(ctx), ctx.seed())
- }
+ override protected def trueModel(ctx: MLBenchContext): Transformer =
+ ModelBuilderSSP.newDecisionTreeRegressionModel(
+ ctx.params.depth,
+ TreeOrForestEstimator.getFeatureArity(ctx),
+ ctx.seed()
+ )
}
private[mllib] object TreeOrForestEstimator {
- /**
- * Get feature arity for tree and tree ensemble tests.
- * Currently, this is hard-coded as:
- * - 1/4 binary features
- * - 1/4 high-arity (20-category) features
- * - 1/2 continuous features
- *
- * @return Array of length numFeatures, where 0 indicates continuous feature and
- * value > 0 indicates a categorical feature of that arity.
- */
+ /** Get feature arity for tree and tree ensemble tests. Currently, this is hard-coded as:
+ * - 1/4 binary features
+ * - 1/4 high-arity (20-category) features
+ * - 1/2 continuous features
+ *
+ * @return
+ * Array of length numFeatures, where 0 indicates continuous feature and value > 0 indicates a
+ * categorical feature of that arity.
+ */
def getFeatureArity(ctx: MLBenchContext): Array[Int] = {
- val numFeatures = ctx.params.numFeatures
+ val numFeatures = ctx.params.numFeatures
val fourthFeatures = numFeatures / 4
- Array.fill[Int](fourthFeatures)(2) ++ // low-arity categorical
- Array.fill[Int](fourthFeatures)(20) ++ // high-arity categorical
+ Array.fill[Int](fourthFeatures)(2) ++ // low-arity categorical
+ Array.fill[Int](fourthFeatures)(20) ++ // high-arity categorical
Array.fill[Int](numFeatures - 2 * fourthFeatures)(0) // continuous
}
}
-
diff --git a/src/main/scala/com/databricks/spark/sql/perf/mllib/classification/GBTClassification.scala b/src/main/scala/com/databricks/spark/sql/perf/mllib/classification/GBTClassification.scala
index 9580fea7..4f14633e 100644
--- a/src/main/scala/com/databricks/spark/sql/perf/mllib/classification/GBTClassification.scala
+++ b/src/main/scala/com/databricks/spark/sql/perf/mllib/classification/GBTClassification.scala
@@ -14,8 +14,12 @@ object GBTClassification extends BenchmarkAlgorithm with TreeOrForestClassifier
import ctx.params._
// We add +1 to the depth to make it more likely that many iterations of boosting are needed
// to model the true tree.
- ModelBuilderSSP.newDecisionTreeClassificationModel(depth + 1, numClasses, getFeatureArity(ctx),
- ctx.seed())
+ ModelBuilderSSP.newDecisionTreeClassificationModel(
+ depth + 1,
+ numClasses,
+ getFeatureArity(ctx),
+ ctx.seed()
+ )
}
override def getPipelineStage(ctx: MLBenchContext): PipelineStage = {
diff --git a/src/main/scala/com/databricks/spark/sql/perf/mllib/classification/LinearSVC.scala b/src/main/scala/com/databricks/spark/sql/perf/mllib/classification/LinearSVC.scala
index 08f139c2..1d536b4a 100644
--- a/src/main/scala/com/databricks/spark/sql/perf/mllib/classification/LinearSVC.scala
+++ b/src/main/scala/com/databricks/spark/sql/perf/mllib/classification/LinearSVC.scala
@@ -9,8 +9,11 @@ import com.databricks.spark.sql.perf.mllib.OptionImplicits._
import com.databricks.spark.sql.perf.mllib._
import com.databricks.spark.sql.perf.mllib.data.DataGenerator
-object LinearSVC extends BenchmarkAlgorithm
- with TestFromTraining with TrainingSetFromTransformer with ScoringWithEvaluator {
+object LinearSVC
+ extends BenchmarkAlgorithm
+ with TestFromTraining
+ with TrainingSetFromTransformer
+ with ScoringWithEvaluator {
override protected def initialData(ctx: MLBenchContext) = {
import ctx.params._
@@ -19,7 +22,8 @@ object LinearSVC extends BenchmarkAlgorithm
numExamples,
ctx.seed(),
numPartitions,
- numFeatures)
+ numFeatures
+ )
}
override protected def trueModel(ctx: MLBenchContext): Transformer = {
@@ -42,4 +46,3 @@ object LinearSVC extends BenchmarkAlgorithm
override protected def evaluator(ctx: MLBenchContext): Evaluator =
new MulticlassClassificationEvaluator()
}
-
diff --git a/src/main/scala/com/databricks/spark/sql/perf/mllib/classification/LogisticRegression.scala b/src/main/scala/com/databricks/spark/sql/perf/mllib/classification/LogisticRegression.scala
index 67f0ef62..a382e855 100644
--- a/src/main/scala/com/databricks/spark/sql/perf/mllib/classification/LogisticRegression.scala
+++ b/src/main/scala/com/databricks/spark/sql/perf/mllib/classification/LogisticRegression.scala
@@ -8,9 +8,11 @@ import org.apache.spark.ml.{Estimator, ModelBuilderSSP, PipelineStage, Transform
import org.apache.spark.ml
import org.apache.spark.ml.linalg.Vectors
-
-object LogisticRegression extends BenchmarkAlgorithm
- with TestFromTraining with TrainingSetFromTransformer with ScoringWithEvaluator {
+object LogisticRegression
+ extends BenchmarkAlgorithm
+ with TestFromTraining
+ with TrainingSetFromTransformer
+ with ScoringWithEvaluator {
override protected def initialData(ctx: MLBenchContext) = {
import ctx.params._
@@ -19,7 +21,8 @@ object LogisticRegression extends BenchmarkAlgorithm
numExamples,
ctx.seed(),
numPartitions,
- numFeatures)
+ numFeatures
+ )
}
override protected def trueModel(ctx: MLBenchContext): Transformer = {
@@ -42,4 +45,3 @@ object LogisticRegression extends BenchmarkAlgorithm
override protected def evaluator(ctx: MLBenchContext): Evaluator =
new MulticlassClassificationEvaluator()
}
-
diff --git a/src/main/scala/com/databricks/spark/sql/perf/mllib/classification/NaiveBayes.scala b/src/main/scala/com/databricks/spark/sql/perf/mllib/classification/NaiveBayes.scala
index 6d648f52..34be5ce4 100644
--- a/src/main/scala/com/databricks/spark/sql/perf/mllib/classification/NaiveBayes.scala
+++ b/src/main/scala/com/databricks/spark/sql/perf/mllib/classification/NaiveBayes.scala
@@ -10,8 +10,11 @@ import com.databricks.spark.sql.perf.mllib._
import com.databricks.spark.sql.perf.mllib.data.DataGenerator
/** Object containing methods used in performance tests for (multinomial) NaiveBayesModels */
-object NaiveBayes extends BenchmarkAlgorithm
- with TestFromTraining with TrainingSetFromTransformer with ScoringWithEvaluator {
+object NaiveBayes
+ extends BenchmarkAlgorithm
+ with TestFromTraining
+ with TrainingSetFromTransformer
+ with ScoringWithEvaluator {
override protected def initialData(ctx: MLBenchContext) = {
import ctx.params._
@@ -25,7 +28,8 @@ object NaiveBayes extends BenchmarkAlgorithm
numExamples,
ctx.seed(),
numPartitions,
- featureArity)
+ featureArity
+ )
}
override protected def trueModel(ctx: MLBenchContext): Transformer = {
@@ -35,21 +39,23 @@ object NaiveBayes extends BenchmarkAlgorithm
// theta = log of class conditional probabilities, whose dimension is C (number of classes)
// by D (number of features)
val unnormalizedProbs = 0.until(numClasses).map(_ => rng.nextDouble() + 1e-5).toArray
- val logProbSum = math.log(unnormalizedProbs.sum)
- val piArray = unnormalizedProbs.map(prob => math.log(prob) - logProbSum)
+ val logProbSum = math.log(unnormalizedProbs.sum)
+ val piArray = unnormalizedProbs.map(prob => math.log(prob) - logProbSum)
// For class i, set the class-conditional probability of feature i to 0.7, and split up the
// remaining probability mass across the other features
val currClassProb = 0.7
- val thetaArray = Array.tabulate(numClasses) { i: Int =>
- val baseProbMass = (1 - currClassProb) / (numFeatures - 1)
- val probs = Array.fill[Double](numFeatures)(baseProbMass)
- probs(i) = currClassProb
- probs
- }.map(_.map(math.log))
+ val thetaArray = Array
+ .tabulate(numClasses) { i: Int =>
+ val baseProbMass = (1 - currClassProb) / (numFeatures - 1)
+ val probs = Array.fill[Double](numFeatures)(baseProbMass)
+ probs(i) = currClassProb
+ probs
+ }
+ .map(_.map(math.log))
// Initialize new Naive Bayes model
- val pi = Vectors.dense(piArray)
+ val pi = Vectors.dense(piArray)
val theta = new DenseMatrix(numClasses, numFeatures, thetaArray.flatten, true)
ModelBuilderSSP.newNaiveBayesModel(pi, theta)
diff --git a/src/main/scala/com/databricks/spark/sql/perf/mllib/classification/RandomForestClassification.scala b/src/main/scala/com/databricks/spark/sql/perf/mllib/classification/RandomForestClassification.scala
index cfb1a953..1985d5b6 100644
--- a/src/main/scala/com/databricks/spark/sql/perf/mllib/classification/RandomForestClassification.scala
+++ b/src/main/scala/com/databricks/spark/sql/perf/mllib/classification/RandomForestClassification.scala
@@ -6,7 +6,6 @@ import org.apache.spark.ml.classification.RandomForestClassifier
import com.databricks.spark.sql.perf.mllib._
import com.databricks.spark.sql.perf.mllib.OptionImplicits._
-
object RandomForestClassification extends BenchmarkAlgorithm with TreeOrForestClassifier {
override def getPipelineStage(ctx: MLBenchContext): PipelineStage = {
diff --git a/src/main/scala/com/databricks/spark/sql/perf/mllib/clustering/GaussianMixture.scala b/src/main/scala/com/databricks/spark/sql/perf/mllib/clustering/GaussianMixture.scala
index 3c684a7b..10f77065 100644
--- a/src/main/scala/com/databricks/spark/sql/perf/mllib/clustering/GaussianMixture.scala
+++ b/src/main/scala/com/databricks/spark/sql/perf/mllib/clustering/GaussianMixture.scala
@@ -12,9 +12,14 @@ object GaussianMixture extends BenchmarkAlgorithm with TestFromTraining {
override def trainingDataSet(ctx: MLBenchContext): DataFrame = {
import ctx.params._
- DataGenerator.generateGaussianMixtureData(ctx.sqlContext, numCenters = k,
- numExamples = numExamples, seed = ctx.seed(), numPartitions = numPartitions,
- numFeatures = numFeatures)
+ DataGenerator.generateGaussianMixtureData(
+ ctx.sqlContext,
+ numCenters = k,
+ numExamples = numExamples,
+ seed = ctx.seed(),
+ numPartitions = numPartitions,
+ numFeatures = numFeatures
+ )
}
override def getPipelineStage(ctx: MLBenchContext): PipelineStage = {
diff --git a/src/main/scala/com/databricks/spark/sql/perf/mllib/clustering/KMeans.scala b/src/main/scala/com/databricks/spark/sql/perf/mllib/clustering/KMeans.scala
index 9b2f2331..a8d43427 100644
--- a/src/main/scala/com/databricks/spark/sql/perf/mllib/clustering/KMeans.scala
+++ b/src/main/scala/com/databricks/spark/sql/perf/mllib/clustering/KMeans.scala
@@ -1,20 +1,25 @@
package com.databricks.spark.sql.perf.mllib.clustering
import org.apache.spark.ml
-import org.apache.spark.ml.{PipelineStage}
+import org.apache.spark.ml.PipelineStage
import org.apache.spark.sql._
import com.databricks.spark.sql.perf.mllib.OptionImplicits._
import com.databricks.spark.sql.perf.mllib.data.DataGenerator
import com.databricks.spark.sql.perf.mllib.{BenchmarkAlgorithm, MLBenchContext, TestFromTraining}
-
object KMeans extends BenchmarkAlgorithm with TestFromTraining {
override def trainingDataSet(ctx: MLBenchContext): DataFrame = {
import ctx.params._
- DataGenerator.generateGaussianMixtureData(ctx.sqlContext, k, numExamples, ctx.seed(),
- numPartitions, numFeatures)
+ DataGenerator.generateGaussianMixtureData(
+ ctx.sqlContext,
+ k,
+ numExamples,
+ ctx.seed(),
+ numPartitions,
+ numFeatures
+ )
}
override def getPipelineStage(ctx: MLBenchContext): PipelineStage = {
diff --git a/src/main/scala/com/databricks/spark/sql/perf/mllib/clustering/LDA.scala b/src/main/scala/com/databricks/spark/sql/perf/mllib/clustering/LDA.scala
index dfe9a2bc..ce4cea13 100644
--- a/src/main/scala/com/databricks/spark/sql/perf/mllib/clustering/LDA.scala
+++ b/src/main/scala/com/databricks/spark/sql/perf/mllib/clustering/LDA.scala
@@ -13,7 +13,6 @@ import org.apache.spark.ml.linalg.{Vector, Vectors}
import com.databricks.spark.sql.perf.mllib.{BenchmarkAlgorithm, MLBenchContext, TestFromTraining}
import com.databricks.spark.sql.perf.mllib.OptionImplicits._
-
object LDA extends BenchmarkAlgorithm with TestFromTraining {
// The LDA model is package private, no need to expose it.
@@ -24,13 +23,13 @@ object LDA extends BenchmarkAlgorithm with TestFromTraining {
numPartitions
)
val seed: Int = randomSeed
- val docLen = docLength.get
- val numVocab = vocabSize.get
+ val docLen = docLength.get
+ val numVocab = vocabSize.get
val data: RDD[(Long, Vector)] = rdd.mapPartitionsWithIndex { (idx, partition) =>
val rng = new Well19937c(seed ^ idx)
partition.map { docIndex =>
var currentSize = 0
- val entries = MHashMap[Int, Int]()
+ val entries = MHashMap[Int, Int]()
while (currentSize < docLen) {
val index = rng.nextInt(numVocab)
entries(index) = entries.getOrElse(index, 0) + 1
diff --git a/src/main/scala/com/databricks/spark/sql/perf/mllib/data/ItemSetGenerator.scala b/src/main/scala/com/databricks/spark/sql/perf/mllib/data/ItemSetGenerator.scala
index ec47b873..c04ba75d 100644
--- a/src/main/scala/com/databricks/spark/sql/perf/mllib/data/ItemSetGenerator.scala
+++ b/src/main/scala/com/databricks/spark/sql/perf/mllib/data/ItemSetGenerator.scala
@@ -4,17 +4,15 @@ import scala.collection.mutable.ArrayBuffer
import org.apache.spark.mllib.random.{PoissonGenerator, RandomDataGenerator}
-class ItemSetGenerator(
- val numItems: Int,
- val avgItemSetSize: Int)
- extends RandomDataGenerator[Array[String]] {
+class ItemSetGenerator(val numItems: Int, val avgItemSetSize: Int)
+ extends RandomDataGenerator[Array[String]] {
assert(avgItemSetSize > 2)
assert(numItems > 2)
- private val rng = new java.util.Random()
+ private val rng = new java.util.Random()
private val itemSetSizeRng = new PoissonGenerator(avgItemSetSize - 2)
- private val itemRng = new PoissonGenerator(numItems / 2.0)
+ private val itemRng = new PoissonGenerator(numItems / 2.0)
override def setSeed(seed: Long) {
rng.setSeed(seed)
@@ -24,15 +22,18 @@ class ItemSetGenerator(
override def nextValue(): Array[String] = {
// 1. generate size of itemset
- val size = DataGenUtil.nextPoisson(itemSetSizeRng, v => v >= 1 && v <= numItems).toInt
+ val size = DataGenUtil.nextPoisson(itemSetSizeRng, v => v >= 1 && v <= numItems).toInt
val arrayBuff = new ArrayBuffer[Int](size + 2)
// 2. generate items in the itemset
var i = 0
while (i < size) {
- val nextVal = DataGenUtil.nextPoisson(itemRng, (item: Double) => {
- item >= 0 && item < numItems && !arrayBuff.contains(item)
- }).toInt
+ val nextVal = DataGenUtil
+ .nextPoisson(
+ itemRng,
+ (item: Double) => item >= 0 && item < numItems && !arrayBuff.contains(item)
+ )
+ .toInt
arrayBuff.append(nextVal)
i += 1
}
@@ -54,6 +55,5 @@ class ItemSetGenerator(
arrayBuff.map(_.toString).toArray
}
- override def copy(): ItemSetGenerator
- = new ItemSetGenerator(numItems, avgItemSetSize)
+ override def copy(): ItemSetGenerator = new ItemSetGenerator(numItems, avgItemSetSize)
}
diff --git a/src/main/scala/com/databricks/spark/sql/perf/mllib/data/RatingGenerator.scala b/src/main/scala/com/databricks/spark/sql/perf/mllib/data/RatingGenerator.scala
index b1b197ac..58aac5d9 100644
--- a/src/main/scala/com/databricks/spark/sql/perf/mllib/data/RatingGenerator.scala
+++ b/src/main/scala/com/databricks/spark/sql/perf/mllib/data/RatingGenerator.scala
@@ -8,20 +8,20 @@ import scala.collection.mutable
class RatingGenerator(
private val numUsers: Int,
private val numProducts: Int,
- private val implicitPrefs: Boolean) extends RandomDataGenerator[Rating[Int]] {
+ private val implicitPrefs: Boolean
+) extends RandomDataGenerator[Rating[Int]] {
private val rng = new java.util.Random()
private val observed = new mutable.HashMap[(Int, Int), Boolean]()
override def nextValue(): Rating[Int] = {
- var tuple = (rng.nextInt(numUsers),rng.nextInt(numProducts))
- while (observed.getOrElse(tuple,false)){
- tuple = (rng.nextInt(numUsers),rng.nextInt(numProducts))
- }
+ var tuple = (rng.nextInt(numUsers), rng.nextInt(numProducts))
+ while (observed.getOrElse(tuple, false))
+ tuple = (rng.nextInt(numUsers), rng.nextInt(numProducts))
observed += (tuple -> true)
- val rating = if (implicitPrefs) rng.nextInt(2)*1.0 else rng.nextDouble()*5
+ val rating = if (implicitPrefs) rng.nextInt(2) * 1.0 else rng.nextDouble() * 5
new Rating(tuple._1, tuple._2, rating.toFloat)
}
diff --git a/src/main/scala/com/databricks/spark/sql/perf/mllib/data/dataGeneration.scala b/src/main/scala/com/databricks/spark/sql/perf/mllib/data/dataGeneration.scala
index d2838156..7606096b 100644
--- a/src/main/scala/com/databricks/spark/sql/perf/mllib/data/dataGeneration.scala
+++ b/src/main/scala/com/databricks/spark/sql/perf/mllib/data/dataGeneration.scala
@@ -14,44 +14,61 @@ object DataGenerator {
numExamples: Long,
seed: Long,
numPartitions: Int,
- numFeatures: Int): DataFrame = {
+ numFeatures: Int
+ ): DataFrame = {
val featureArity = Array.fill[Int](numFeatures)(0)
- val rdd: RDD[Vector] = RandomRDDs.randomRDD(sql.sparkContext,
- new FeaturesGenerator(featureArity), numExamples, numPartitions, seed)
- sql.createDataFrame(rdd.map(Tuple1.apply)).toDF("features")
+ val rdd: RDD[Vector] = RandomRDDs.randomRDD(
+ sql.sparkSession.sparkContext,
+ new FeaturesGenerator(featureArity),
+ numExamples,
+ numPartitions,
+ seed
+ )
+ sql.sparkSession.createDataFrame(rdd.map(Tuple1.apply)).toDF("features")
}
- /**
- * Generate a mix of continuous and categorical features.
- * @param featureArity Array of length numFeatures, where 0 indicates a continuous feature and
- * a value > 0 indicates a categorical feature with that arity.
- */
+ /** Generate a mix of continuous and categorical features.
+ * @param featureArity
+ * Array of length numFeatures, where 0 indicates a continuous feature and a value > 0
+ * indicates a categorical feature with that arity.
+ */
def generateMixedFeatures(
sql: SQLContext,
numExamples: Long,
seed: Long,
numPartitions: Int,
- featureArity: Array[Int]): DataFrame = {
- val rdd: RDD[Vector] = RandomRDDs.randomRDD(sql.sparkContext,
- new FeaturesGenerator(featureArity), numExamples, numPartitions, seed)
- sql.createDataFrame(rdd.map(Tuple1.apply)).toDF("features")
+ featureArity: Array[Int]
+ ): DataFrame = {
+ val rdd: RDD[Vector] = RandomRDDs.randomRDD(
+ sql.sparkSession.sparkContext,
+ new FeaturesGenerator(featureArity),
+ numExamples,
+ numPartitions,
+ seed
+ )
+ sql.sparkSession.createDataFrame(rdd.map(Tuple1.apply)).toDF("features")
}
- /**
- * Generate data from a Gaussian mixture model.
- * @param numCenters Number of clusters in mixture
- */
+ /** Generate data from a Gaussian mixture model.
+ * @param numCenters
+ * Number of clusters in mixture
+ */
def generateGaussianMixtureData(
sql: SQLContext,
numCenters: Int,
numExamples: Long,
seed: Long,
numPartitions: Int,
- numFeatures: Int): DataFrame = {
- val rdd: RDD[Vector] = RandomRDDs.randomRDD(sql.sparkContext,
- new GaussianMixtureDataGenerator(numCenters, numFeatures, seed), numExamples, numPartitions,
- seed)
- sql.createDataFrame(rdd.map(Tuple1.apply)).toDF("features")
+ numFeatures: Int
+ ): DataFrame = {
+ val rdd: RDD[Vector] = RandomRDDs.randomRDD(
+ sql.sparkSession.sparkContext,
+ new GaussianMixtureDataGenerator(numCenters, numFeatures, seed),
+ numExamples,
+ numPartitions,
+ seed
+ )
+ sql.sparkSession.createDataFrame(rdd.map(Tuple1.apply)).toDF("features")
}
def generateRatings(
@@ -62,20 +79,31 @@ object DataGenerator {
numTestExamples: Long,
implicitPrefs: Boolean,
numPartitions: Int,
- seed: Long): (DataFrame, DataFrame) = {
-
- val sc = sql.sparkContext
- val train = RandomRDDs.randomRDD(sc,
+ seed: Long
+ ): (DataFrame, DataFrame) = {
+
+ val sc = sql.sparkSession.sparkContext
+ val train = RandomRDDs
+ .randomRDD(
+ sc,
+ new RatingGenerator(numUsers, numProducts, implicitPrefs),
+ numExamples,
+ numPartitions,
+ seed
+ )
+ .cache()
+
+ val test = RandomRDDs.randomRDD(
+ sc,
new RatingGenerator(numUsers, numProducts, implicitPrefs),
- numExamples, numPartitions, seed).cache()
-
- val test = RandomRDDs.randomRDD(sc,
- new RatingGenerator(numUsers, numProducts, implicitPrefs),
- numTestExamples, numPartitions, seed + 24)
+ numTestExamples,
+ numPartitions,
+ seed + 24
+ )
- // Now get rid of duplicate ratings and remove non-existant userID's
+ // Now get rid of duplicate ratings and remove non-existent userID's
// and prodID's from the test set
- val commons: PairRDDFunctions[(Int,Int),Rating[Int]] =
+ val commons: PairRDDFunctions[(Int, Int), Rating[Int]] =
new PairRDDFunctions(train.keyBy(rating => (rating.user, rating.item)).cache())
val exact = commons.join(test.keyBy(rating => (rating.user, rating.item)))
@@ -83,15 +111,15 @@ object DataGenerator {
val trainPruned = commons.subtractByKey(exact).map(_._2).cache()
// Now get rid of users that don't exist in the train set
- val trainUsers: RDD[(Int,Rating[Int])] = trainPruned.keyBy(rating => rating.user)
- val testUsers: PairRDDFunctions[Int,Rating[Int]] =
+ val trainUsers: RDD[(Int, Rating[Int])] = trainPruned.keyBy(rating => rating.user)
+ val testUsers: PairRDDFunctions[Int, Rating[Int]] =
new PairRDDFunctions(test.keyBy(rating => rating.user))
val testWithAdditionalUsers = testUsers.subtractByKey(trainUsers)
- val userPrunedTestProds: RDD[(Int,Rating[Int])] =
+ val userPrunedTestProds: RDD[(Int, Rating[Int])] =
testUsers.subtractByKey(testWithAdditionalUsers).map(_._2).keyBy(rating => rating.item)
- val trainProds: RDD[(Int,Rating[Int])] = trainPruned.keyBy(rating => rating.item)
+ val trainProds: RDD[(Int, Rating[Int])] = trainPruned.keyBy(rating => rating.item)
val testWithAdditionalProds =
new PairRDDFunctions[Int, Rating[Int]](userPrunedTestProds).subtractByKey(trainProds)
@@ -100,7 +128,7 @@ object DataGenerator {
.subtractByKey(testWithAdditionalProds)
.map(_._2)
- (sql.createDataFrame(trainPruned), sql.createDataFrame(finalTest))
+ (sql.sparkSession.createDataFrame(trainPruned), sql.sparkSession.createDataFrame(finalTest))
}
def generateRandString(
@@ -109,10 +137,16 @@ object DataGenerator {
seed: Long,
numPartitions: Int,
distinctCount: Int,
- dataColName: String): DataFrame = {
- val rdd: RDD[String] = RandomRDDs.randomRDD(sql.sparkContext,
- new RandStringGenerator(distinctCount), numExamples, numPartitions, seed)
- sql.createDataFrame(rdd.map(Tuple1.apply)).toDF(dataColName)
+ dataColName: String
+ ): DataFrame = {
+ val rdd: RDD[String] = RandomRDDs.randomRDD(
+ sql.sparkSession.sparkContext,
+ new RandStringGenerator(distinctCount),
+ numExamples,
+ numPartitions,
+ seed
+ )
+ sql.sparkSession.createDataFrame(rdd.map(Tuple1.apply)).toDF(dataColName)
}
def generateDoc(
@@ -122,10 +156,16 @@ object DataGenerator {
numPartitions: Int,
vocabSize: Int,
avgDocLength: Int,
- dataColName: String): DataFrame = {
- val rdd: RDD[String] = RandomRDDs.randomRDD(sql.sparkContext,
- new DocGenerator(vocabSize, avgDocLength), numExamples, numPartitions, seed)
- sql.createDataFrame(rdd.map(Tuple1.apply)).toDF(dataColName)
+ dataColName: String
+ ): DataFrame = {
+ val rdd: RDD[String] = RandomRDDs.randomRDD(
+ sql.sparkSession.sparkContext,
+ new DocGenerator(vocabSize, avgDocLength),
+ numExamples,
+ numPartitions,
+ seed
+ )
+ sql.sparkSession.createDataFrame(rdd.map(Tuple1.apply)).toDF(dataColName)
}
def generateItemSet(
@@ -134,42 +174,44 @@ object DataGenerator {
seed: Long,
numPartitions: Int,
numItems: Int,
- avgItemSetSize: Int): DataFrame = {
+ avgItemSetSize: Int
+ ): DataFrame = {
val rdd: RDD[Array[String]] = RandomRDDs.randomRDD(
- sql.sparkContext,
+ sql.sparkSession.sparkContext,
new ItemSetGenerator(numItems, avgItemSetSize),
numExamples,
numPartitions,
- seed)
- sql.createDataFrame(rdd.map(Tuple1.apply)).toDF("items")
+ seed
+ )
+ sql.sparkSession.createDataFrame(rdd.map(Tuple1.apply)).toDF("items")
}
}
-
-/**
- * Generator for a feature vector which can include a mix of categorical and continuous features.
- *
- * @param featureArity Length numFeatures, where 0 indicates continuous feature and > 0
- * indicates a categorical feature of that arity.
- */
-class FeaturesGenerator(val featureArity: Array[Int])
- extends RandomDataGenerator[Vector] {
+/** Generator for a feature vector which can include a mix of categorical and continuous features.
+ *
+ * @param featureArity
+ * Length numFeatures, where 0 indicates continuous feature and > 0 indicates a categorical
+ * feature of that arity.
+ */
+class FeaturesGenerator(val featureArity: Array[Int]) extends RandomDataGenerator[Vector] {
featureArity.foreach { arity =>
- require(arity >= 0, s"FeaturesGenerator given categorical arity = $arity, " +
- s"but arity should be >= 0.")
+ require(
+ arity >= 0,
+ s"FeaturesGenerator given categorical arity = $arity, " +
+ s"but arity should be >= 0."
+ )
}
val numFeatures = featureArity.length
private val rng = new java.util.Random()
- /**
- * Generates vector with features in the order given by [[featureArity]]
- */
+ /** Generates vector with features in the order given by [[featureArity]]
+ */
override def nextValue(): Vector = {
val arr = new Array[Double](numFeatures)
- var j = 0
+ var j = 0
while (j < featureArity.length) {
if (featureArity(j) == 0)
arr(j) = 2 * rng.nextDouble() - 1 // centered uniform data
@@ -187,37 +229,33 @@ class FeaturesGenerator(val featureArity: Array[Int])
override def copy(): FeaturesGenerator = new FeaturesGenerator(featureArity)
}
+/** Generate data from a Gaussian mixture model.
+ */
+class GaussianMixtureDataGenerator(val numCenters: Int, val numFeatures: Int, val seed: Long)
+ extends RandomDataGenerator[Vector] {
-/**
- * Generate data from a Gaussian mixture model.
- */
-class GaussianMixtureDataGenerator(
- val numCenters: Int,
- val numFeatures: Int,
- val seed: Long) extends RandomDataGenerator[Vector] {
-
- private val rng = new java.util.Random(seed)
- private val rng2 = new java.util.Random(seed + 24)
+ private val rng = new java.util.Random(seed)
+ private val rng2 = new java.util.Random(seed + 24)
private val scale_factors = Array.fill(numCenters)(rng.nextInt(20) - 10)
// Have a random number of points around a cluster
private val concentrations: Seq[Double] = {
- val rand = Array.fill(numCenters)(rng.nextDouble())
+ val rand = Array.fill(numCenters)(rng.nextDouble())
val randSum = rand.sum
- val scaled = rand.map(x => x / randSum)
+ val scaled = rand.map(x => x / randSum)
- (1 to numCenters).map{i =>
+ (1 to numCenters).map { i =>
scaled.slice(0, i).sum
}
}
- private val centers = (0 until numCenters).map{i =>
+ private val centers = (0 until numCenters).map { i =>
Array.fill(numFeatures)((2 * rng.nextDouble() - 1) * scale_factors(i))
}
override def nextValue(): Vector = {
val pick_center_rand = rng2.nextDouble()
- val center = centers(concentrations.indexWhere(p => pick_center_rand <= p))
+ val center = centers(concentrations.indexWhere(p => pick_center_rand <= p))
Vectors.dense(Array.tabulate(numFeatures)(i => center(i) + rng2.nextGaussian()))
}
@@ -230,14 +268,12 @@ class GaussianMixtureDataGenerator(
new GaussianMixtureDataGenerator(numCenters, numFeatures, seed)
}
-class RandStringGenerator(
- distinctCount: Int) extends RandomDataGenerator[String] {
+class RandStringGenerator(distinctCount: Int) extends RandomDataGenerator[String] {
private val rng = new java.util.Random()
- override def nextValue(): String = {
+ override def nextValue(): String =
rng.nextInt(distinctCount).toString
- }
override def setSeed(seed: Long) {
rng.setSeed(seed)
@@ -246,12 +282,10 @@ class RandStringGenerator(
override def copy(): RandStringGenerator = new RandStringGenerator(distinctCount)
}
-class DocGenerator(
- vocabSize: Int,
- avgDocLength: Int,
- maxDocLength: Int = 65535) extends RandomDataGenerator[String] {
+class DocGenerator(vocabSize: Int, avgDocLength: Int, maxDocLength: Int = 65535)
+ extends RandomDataGenerator[String] {
- private val wordRng = new java.util.Random()
+ private val wordRng = new java.util.Random()
private val docLengthRng = new PoissonGenerator(avgDocLength)
override def setSeed(seed: Long) {
@@ -261,7 +295,7 @@ class DocGenerator(
override def nextValue(): String = {
val docLength = DataGenUtil.nextPoisson(docLengthRng, v => v > 0 && v <= maxDocLength).toInt
- val sb = new StringBuffer()
+ val sb = new StringBuffer()
var i = 0
while (i < docLength) {
@@ -279,9 +313,7 @@ class DocGenerator(
object DataGenUtil {
def nextPoisson(rng: PoissonGenerator, condition: Double => Boolean): Double = {
var value = 0.0
- do {
- value = rng.nextValue()
- } while (!condition(value))
+ do value = rng.nextValue() while (!condition(value))
value
}
-}
\ No newline at end of file
+}
diff --git a/src/main/scala/com/databricks/spark/sql/perf/mllib/feature/Bucketizer.scala b/src/main/scala/com/databricks/spark/sql/perf/mllib/feature/Bucketizer.scala
index 789aba9e..3a509d22 100644
--- a/src/main/scala/com/databricks/spark/sql/perf/mllib/feature/Bucketizer.scala
+++ b/src/main/scala/com/databricks/spark/sql/perf/mllib/feature/Bucketizer.scala
@@ -19,19 +19,33 @@ object Bucketizer extends BenchmarkAlgorithm with TestFromTraining with UnaryTra
import ctx.sqlContext.implicits._
val rng = ctx.newGenerator()
// For a bucketizer, training data consists of a single column of random doubles
- DataGenerator.generateContinuousFeatures(ctx.sqlContext,
- numExamples, ctx.seed(), numPartitions, numFeatures = 1).rdd.map { case Row(vec: Vector) =>
- vec(0) // extract the single generated double value for each row
- }.toDF(inputCol)
+ DataGenerator
+ .generateContinuousFeatures(
+ ctx.sqlContext,
+ numExamples,
+ ctx.seed(),
+ numPartitions,
+ numFeatures = 1
+ )
+ .rdd
+ .map {
+ case Row(vec: Vector) =>
+ vec(0) // extract the single generated double value for each row
+ }
+ .toDF(inputCol)
}
override def getPipelineStage(ctx: MLBenchContext): PipelineStage = {
import ctx.params._
val rng = ctx.newGenerator()
// Generate an array of (finite) splitting points in [-1, 1) for the Bucketizer
- val splitPoints = 0.until(bucketizerNumBuckets - 1).map { _ =>
- 2 * rng.nextDouble() - 1
- }.sorted.toArray
+ val splitPoints = 0
+ .until(bucketizerNumBuckets - 1)
+ .map { _ =>
+ 2 * rng.nextDouble() - 1
+ }
+ .sorted
+ .toArray
// Final array of splits contains +/- infinity
val splits = Array(Double.NegativeInfinity) ++ splitPoints ++ Array(Double.PositiveInfinity)
new ml.feature.Bucketizer()
diff --git a/src/main/scala/com/databricks/spark/sql/perf/mllib/feature/HashingTF.scala b/src/main/scala/com/databricks/spark/sql/perf/mllib/feature/HashingTF.scala
index 5fb7d76a..536a8740 100644
--- a/src/main/scala/com/databricks/spark/sql/perf/mllib/feature/HashingTF.scala
+++ b/src/main/scala/com/databricks/spark/sql/perf/mllib/feature/HashingTF.scala
@@ -15,7 +15,7 @@ object HashingTF extends BenchmarkAlgorithm with TestFromTraining with UnaryTran
// Sample a random sentence of length up to maxLen from the provided array of words
private def randomSentence(rng: Random, maxLen: Int, dictionary: Array[String]): Array[String] = {
- val length = rng.nextInt(maxLen - 1) + 1
+ val length = rng.nextInt(maxLen - 1) + 1
val dictLength = dictionary.length
Array.tabulate[String](length)(_ => dictionary(rng.nextInt(dictLength)))
}
@@ -26,9 +26,15 @@ object HashingTF extends BenchmarkAlgorithm with TestFromTraining with UnaryTran
// each string is selected from a pool of vocabSize strings
// The expected # of occurrences of each word in our vocabulary is
// (docLength * numExamples) / vocabSize
- val df = DataGenerator.generateDoc(ctx.sqlContext, numExamples = numExamples, seed = ctx.seed(),
- numPartitions = numPartitions, vocabSize = vocabSize, avgDocLength = docLength,
- dataColName = inputCol)
+ val df = DataGenerator.generateDoc(
+ ctx.sqlContext,
+ numExamples = numExamples,
+ seed = ctx.seed(),
+ numPartitions = numPartitions,
+ vocabSize = vocabSize,
+ avgDocLength = docLength,
+ dataColName = inputCol
+ )
df.withColumn(inputCol, split(df(inputCol), " "))
}
diff --git a/src/main/scala/com/databricks/spark/sql/perf/mllib/feature/OneHotEncoder.scala b/src/main/scala/com/databricks/spark/sql/perf/mllib/feature/OneHotEncoder.scala
index 9ad4ceba..aacd8882 100644
--- a/src/main/scala/com/databricks/spark/sql/perf/mllib/feature/OneHotEncoder.scala
+++ b/src/main/scala/com/databricks/spark/sql/perf/mllib/feature/OneHotEncoder.scala
@@ -16,19 +16,23 @@ object OneHotEncoder extends BenchmarkAlgorithm with TestFromTraining with Unary
import ctx.params._
import ctx.sqlContext.implicits._
- DataGenerator.generateMixedFeatures(
- ctx.sqlContext,
- numExamples,
- ctx.seed(),
- numPartitions,
- Array.fill(1)(featureArity.get)
- ).rdd.map { case Row(vec: Vector) =>
- vec(0) // extract the single generated double value for each row
- }.toDF(inputCol)
+ DataGenerator
+ .generateMixedFeatures(
+ ctx.sqlContext,
+ numExamples,
+ ctx.seed(),
+ numPartitions,
+ Array.fill(1)(featureArity.get)
+ )
+ .rdd
+ .map {
+ case Row(vec: Vector) =>
+ vec(0) // extract the single generated double value for each row
+ }
+ .toDF(inputCol)
}
- override def getPipelineStage(ctx: MLBenchContext): PipelineStage = {
+ override def getPipelineStage(ctx: MLBenchContext): PipelineStage =
new ml.feature.OneHotEncoder()
.setInputCol(inputCol)
- }
}
diff --git a/src/main/scala/com/databricks/spark/sql/perf/mllib/feature/QuantileDiscretizer.scala b/src/main/scala/com/databricks/spark/sql/perf/mllib/feature/QuantileDiscretizer.scala
index cf32b0f9..00d08256 100644
--- a/src/main/scala/com/databricks/spark/sql/perf/mllib/feature/QuantileDiscretizer.scala
+++ b/src/main/scala/com/databricks/spark/sql/perf/mllib/feature/QuantileDiscretizer.scala
@@ -16,15 +16,20 @@ object QuantileDiscretizer extends BenchmarkAlgorithm with TestFromTraining with
import ctx.params._
import ctx.sqlContext.implicits._
- DataGenerator.generateContinuousFeatures(
- ctx.sqlContext,
- numExamples,
- ctx.seed(),
- numPartitions,
- 1
- ).rdd.map { case Row(vec: Vector) =>
- vec(0) // extract the single generated double value for each row
- }.toDF(inputCol)
+ DataGenerator
+ .generateContinuousFeatures(
+ ctx.sqlContext,
+ numExamples,
+ ctx.seed(),
+ numPartitions,
+ 1
+ )
+ .rdd
+ .map {
+ case Row(vec: Vector) =>
+ vec(0) // extract the single generated double value for each row
+ }
+ .toDF(inputCol)
}
override def getPipelineStage(ctx: MLBenchContext): PipelineStage = {
diff --git a/src/main/scala/com/databricks/spark/sql/perf/mllib/feature/StringIndexer.scala b/src/main/scala/com/databricks/spark/sql/perf/mllib/feature/StringIndexer.scala
index 852cefa4..ca42f773 100644
--- a/src/main/scala/com/databricks/spark/sql/perf/mllib/feature/StringIndexer.scala
+++ b/src/main/scala/com/databricks/spark/sql/perf/mllib/feature/StringIndexer.scala
@@ -16,12 +16,14 @@ object StringIndexer extends BenchmarkAlgorithm with TestFromTraining with Unary
import ctx.params._
import ctx.sqlContext.implicits._
- DataGenerator.generateRandString(ctx.sqlContext,
+ DataGenerator.generateRandString(
+ ctx.sqlContext,
numExamples,
ctx.seed(),
numPartitions,
vocabSize,
- inputCol)
+ inputCol
+ )
}
override def getPipelineStage(ctx: MLBenchContext): PipelineStage = {
diff --git a/src/main/scala/com/databricks/spark/sql/perf/mllib/feature/Tokenizer.scala b/src/main/scala/com/databricks/spark/sql/perf/mllib/feature/Tokenizer.scala
index aa066661..b3e863e9 100644
--- a/src/main/scala/com/databricks/spark/sql/perf/mllib/feature/Tokenizer.scala
+++ b/src/main/scala/com/databricks/spark/sql/perf/mllib/feature/Tokenizer.scala
@@ -23,11 +23,11 @@ object Tokenizer extends BenchmarkAlgorithm with TestFromTraining with UnaryTran
numPartitions,
vocabSize,
docLength,
- inputCol)
+ inputCol
+ )
}
- override def getPipelineStage(ctx: MLBenchContext): PipelineStage = {
+ override def getPipelineStage(ctx: MLBenchContext): PipelineStage =
new ml.feature.Tokenizer()
.setInputCol(inputCol)
- }
}
diff --git a/src/main/scala/com/databricks/spark/sql/perf/mllib/feature/UnaryTransformer.scala b/src/main/scala/com/databricks/spark/sql/perf/mllib/feature/UnaryTransformer.scala
index bd7b3cc3..23c0afd1 100644
--- a/src/main/scala/com/databricks/spark/sql/perf/mllib/feature/UnaryTransformer.scala
+++ b/src/main/scala/com/databricks/spark/sql/perf/mllib/feature/UnaryTransformer.scala
@@ -2,6 +2,6 @@ package com.databricks.spark.sql.perf.mllib.feature
/** Trait defining common state/methods for featurizers taking a single input col */
private[feature] trait UnaryTransformer {
- private[feature] val inputCol = "inputCol"
+ private[feature] val inputCol = "inputCol"
private[feature] val outputCol = "outputCol"
}
diff --git a/src/main/scala/com/databricks/spark/sql/perf/mllib/feature/VectorAssembler.scala b/src/main/scala/com/databricks/spark/sql/perf/mllib/feature/VectorAssembler.scala
index 66897d97..dc16ee47 100644
--- a/src/main/scala/com/databricks/spark/sql/perf/mllib/feature/VectorAssembler.scala
+++ b/src/main/scala/com/databricks/spark/sql/perf/mllib/feature/VectorAssembler.scala
@@ -13,29 +13,31 @@ import com.databricks.spark.sql.perf.mllib.{BenchmarkAlgorithm, MLBenchContext,
/** Object for testing VectorAssembler performance */
object VectorAssembler extends BenchmarkAlgorithm with TestFromTraining {
- private def getInputCols(numInputCols: Int): Array[String] = {
- Array.tabulate(numInputCols)(i => s"c${i}")
- }
+ private def getInputCols(numInputCols: Int): Array[String] =
+ Array.tabulate(numInputCols)(i => s"c$i")
override def trainingDataSet(ctx: MLBenchContext): DataFrame = {
import ctx.params._
- require(numInputCols.get <= numFeatures.get,
- s"numInputCols (${numInputCols}) cannot be greater than numFeatures (${numFeatures}).")
+ require(
+ numInputCols.get <= numFeatures.get,
+ s"numInputCols ($numInputCols) cannot be greater than numFeatures ($numFeatures)."
+ )
val df = DataGenerator.generateContinuousFeatures(
ctx.sqlContext,
numExamples,
ctx.seed(),
numPartitions,
- numFeatures)
+ numFeatures
+ )
val slice = udf { (v: Vector, numSlices: Int) =>
val data = v.toArray
- val n = data.length.toLong
+ val n = data.length.toLong
(0 until numSlices).map { i =>
val start = ((i * n) / numSlices).toInt
- val end = ((i + 1) * n / numSlices).toInt
+ val end = ((i + 1) * n / numSlices).toInt
Vectors.dense(data.slice(start, end))
}
}
diff --git a/src/main/scala/com/databricks/spark/sql/perf/mllib/feature/Word2Vec.scala b/src/main/scala/com/databricks/spark/sql/perf/mllib/feature/Word2Vec.scala
index a59d29e5..199baba4 100644
--- a/src/main/scala/com/databricks/spark/sql/perf/mllib/feature/Word2Vec.scala
+++ b/src/main/scala/com/databricks/spark/sql/perf/mllib/feature/Word2Vec.scala
@@ -31,22 +31,20 @@ object Word2Vec extends BenchmarkAlgorithm with TestFromTraining {
df.select(split(col("text"), " ").as("text"))
}
- override def getPipelineStage(ctx: MLBenchContext): PipelineStage = {
+ override def getPipelineStage(ctx: MLBenchContext): PipelineStage =
new ml.feature.Word2Vec().setInputCol("text")
- }
override def testAdditionalMethods(
ctx: MLBenchContext,
- model: Transformer): Map[String, () => _] = {
+ model: Transformer
+ ): Map[String, () => _] = {
import ctx.params._
- val rng = new Random(ctx.seed())
+ val rng = new Random(ctx.seed())
val word2vecModel = model.asInstanceOf[Word2VecModel]
- val testWord = Vectors.dense(Array.fill(word2vecModel.getVectorSize)(rng.nextGaussian()))
+ val testWord = Vectors.dense(Array.fill(word2vecModel.getVectorSize)(rng.nextGaussian()))
- Map("findSynonyms" -> (() => {
- word2vecModel.findSynonyms(testWord, numSynonymsToFind)
- }))
+ Map("findSynonyms" -> (() => word2vecModel.findSynonyms(testWord, numSynonymsToFind)))
}
}
diff --git a/src/main/scala/com/databricks/spark/sql/perf/mllib/fpm/FPGrowth.scala b/src/main/scala/com/databricks/spark/sql/perf/mllib/fpm/FPGrowth.scala
index 691bf5bd..ec519e18 100644
--- a/src/main/scala/com/databricks/spark/sql/perf/mllib/fpm/FPGrowth.scala
+++ b/src/main/scala/com/databricks/spark/sql/perf/mllib/fpm/FPGrowth.scala
@@ -9,7 +9,6 @@ import com.databricks.spark.sql.perf.mllib._
import com.databricks.spark.sql.perf.mllib.OptionImplicits._
import com.databricks.spark.sql.perf.mllib.data.DataGenerator
-
/** Object containing methods used in performance tests for FPGrowth */
object FPGrowth extends BenchmarkAlgorithm with TestFromTraining {
@@ -22,21 +21,20 @@ object FPGrowth extends BenchmarkAlgorithm with TestFromTraining {
ctx.seed(),
numPartitions,
numItems,
- itemSetSize)
+ itemSetSize
+ )
}
- override def getPipelineStage(ctx: MLBenchContext): PipelineStage = {
+ override def getPipelineStage(ctx: MLBenchContext): PipelineStage =
new ml.fpm.FPGrowth()
.setItemsCol("items")
- }
override def testAdditionalMethods(
ctx: MLBenchContext,
- model: Transformer): Map[String, () => _] = {
+ model: Transformer
+ ): Map[String, () => _] = {
val fpModel = model.asInstanceOf[FPGrowthModel]
- Map("associationRules" -> (() => {
- fpModel.associationRules.count()
- }))
+ Map("associationRules" -> (() => fpModel.associationRules.count()))
}
}
diff --git a/src/main/scala/com/databricks/spark/sql/perf/mllib/recommendation/ALS.scala b/src/main/scala/com/databricks/spark/sql/perf/mllib/recommendation/ALS.scala
index 9c21947b..609fc515 100644
--- a/src/main/scala/com/databricks/spark/sql/perf/mllib/recommendation/ALS.scala
+++ b/src/main/scala/com/databricks/spark/sql/perf/mllib/recommendation/ALS.scala
@@ -7,34 +7,44 @@ import org.apache.spark.sql._
import com.databricks.spark.sql.perf.mllib.OptionImplicits._
import com.databricks.spark.sql.perf.mllib.data.DataGenerator
-import com.databricks.spark.sql.perf.mllib.{BenchmarkAlgorithm, MLBenchContext, ScoringWithEvaluator}
+import com.databricks.spark.sql.perf.mllib.{
+ BenchmarkAlgorithm,
+ MLBenchContext,
+ ScoringWithEvaluator
+}
object ALS extends BenchmarkAlgorithm with ScoringWithEvaluator {
override def trainingDataSet(ctx: MLBenchContext): DataFrame = {
import ctx.params._
- DataGenerator.generateRatings(
- ctx.sqlContext,
- numUsers,
- numItems,
- numExamples,
- numTestExamples,
- implicitPrefs = false,
- numPartitions,
- ctx.seed())._1
+ DataGenerator
+ .generateRatings(
+ ctx.sqlContext,
+ numUsers,
+ numItems,
+ numExamples,
+ numTestExamples,
+ implicitPrefs = false,
+ numPartitions,
+ ctx.seed()
+ )
+ ._1
}
override def testDataSet(ctx: MLBenchContext): DataFrame = {
import ctx.params._
- DataGenerator.generateRatings(
- ctx.sqlContext,
- numUsers,
- numItems,
- numExamples,
- numTestExamples,
- implicitPrefs = false,
- numPartitions,
- ctx.seed())._2
+ DataGenerator
+ .generateRatings(
+ ctx.sqlContext,
+ numUsers,
+ numItems,
+ numExamples,
+ numTestExamples,
+ implicitPrefs = false,
+ numPartitions,
+ ctx.seed()
+ )
+ ._2
}
override def getPipelineStage(ctx: MLBenchContext): PipelineStage = {
@@ -47,7 +57,6 @@ object ALS extends BenchmarkAlgorithm with ScoringWithEvaluator {
.setMaxIter(maxIter)
}
- override protected def evaluator(ctx: MLBenchContext): Evaluator = {
+ override protected def evaluator(ctx: MLBenchContext): Evaluator =
new RegressionEvaluator().setLabelCol("rating")
- }
}
diff --git a/src/main/scala/com/databricks/spark/sql/perf/mllib/regression/DecisionTreeRegression.scala b/src/main/scala/com/databricks/spark/sql/perf/mllib/regression/DecisionTreeRegression.scala
index 126ffe4d..dbae2fde 100644
--- a/src/main/scala/com/databricks/spark/sql/perf/mllib/regression/DecisionTreeRegression.scala
+++ b/src/main/scala/com/databricks/spark/sql/perf/mllib/regression/DecisionTreeRegression.scala
@@ -6,7 +6,6 @@ import org.apache.spark.ml.regression.DecisionTreeRegressor
import com.databricks.spark.sql.perf.mllib.OptionImplicits._
import com.databricks.spark.sql.perf.mllib._
-
object DecisionTreeRegression extends BenchmarkAlgorithm with TreeOrForestRegressor {
override def getPipelineStage(ctx: MLBenchContext): PipelineStage = {
diff --git a/src/main/scala/com/databricks/spark/sql/perf/mllib/regression/GBTRegression.scala b/src/main/scala/com/databricks/spark/sql/perf/mllib/regression/GBTRegression.scala
index e78d2eb6..dbc9cc12 100644
--- a/src/main/scala/com/databricks/spark/sql/perf/mllib/regression/GBTRegression.scala
+++ b/src/main/scala/com/databricks/spark/sql/perf/mllib/regression/GBTRegression.scala
@@ -4,8 +4,11 @@ import org.apache.spark.ml.PipelineStage
import org.apache.spark.ml.regression.GBTRegressor
import com.databricks.spark.sql.perf.mllib.OptionImplicits._
-import com.databricks.spark.sql.perf.mllib.{BenchmarkAlgorithm, MLBenchContext,
- TreeOrForestRegressor}
+import com.databricks.spark.sql.perf.mllib.{
+ BenchmarkAlgorithm,
+ MLBenchContext,
+ TreeOrForestRegressor
+}
object GBTRegression extends BenchmarkAlgorithm with TreeOrForestRegressor {
override def getPipelineStage(ctx: MLBenchContext): PipelineStage = {
@@ -15,4 +18,4 @@ object GBTRegression extends BenchmarkAlgorithm with TreeOrForestRegressor {
.setMaxIter(maxIter)
.setSeed(ctx.seed())
}
-}
\ No newline at end of file
+}
diff --git a/src/main/scala/com/databricks/spark/sql/perf/mllib/regression/GLMRegression.scala b/src/main/scala/com/databricks/spark/sql/perf/mllib/regression/GLMRegression.scala
index c2761a0b..16fca1b2 100644
--- a/src/main/scala/com/databricks/spark/sql/perf/mllib/regression/GLMRegression.scala
+++ b/src/main/scala/com/databricks/spark/sql/perf/mllib/regression/GLMRegression.scala
@@ -9,9 +9,11 @@ import com.databricks.spark.sql.perf.mllib.OptionImplicits._
import com.databricks.spark.sql.perf.mllib._
import com.databricks.spark.sql.perf.mllib.data.DataGenerator
-
-object GLMRegression extends BenchmarkAlgorithm with TestFromTraining with
- TrainingSetFromTransformer with ScoringWithEvaluator {
+object GLMRegression
+ extends BenchmarkAlgorithm
+ with TestFromTraining
+ with TrainingSetFromTransformer
+ with ScoringWithEvaluator {
override protected def initialData(ctx: MLBenchContext) = {
import ctx.params._
@@ -20,7 +22,8 @@ object GLMRegression extends BenchmarkAlgorithm with TestFromTraining with
numExamples,
ctx.seed(),
numPartitions,
- numFeatures)
+ numFeatures
+ )
}
override protected def trueModel(ctx: MLBenchContext): Transformer = {
@@ -30,7 +33,7 @@ object GLMRegression extends BenchmarkAlgorithm with TestFromTraining with
Vectors.dense(Array.fill[Double](ctx.params.numFeatures)(2 * rng.nextDouble() - 1))
// Small intercept to prevent some skew in the data.
val intercept = 0.01 * (2 * rng.nextDouble - 1)
- val m = ModelBuilderSSP.newGLR(coefficients, intercept)
+ val m = ModelBuilderSSP.newGLR(coefficients, intercept)
m.set(m.link, link.get)
m.set(m.family, family.get)
m
diff --git a/src/main/scala/com/databricks/spark/sql/perf/mllib/regression/LinearRegression.scala b/src/main/scala/com/databricks/spark/sql/perf/mllib/regression/LinearRegression.scala
index c2882ce8..29a3b735 100644
--- a/src/main/scala/com/databricks/spark/sql/perf/mllib/regression/LinearRegression.scala
+++ b/src/main/scala/com/databricks/spark/sql/perf/mllib/regression/LinearRegression.scala
@@ -9,9 +9,11 @@ import com.databricks.spark.sql.perf.mllib.OptionImplicits._
import com.databricks.spark.sql.perf.mllib._
import com.databricks.spark.sql.perf.mllib.data.DataGenerator
-
-object LinearRegression extends BenchmarkAlgorithm with TestFromTraining with
- TrainingSetFromTransformer with ScoringWithEvaluator {
+object LinearRegression
+ extends BenchmarkAlgorithm
+ with TestFromTraining
+ with TrainingSetFromTransformer
+ with ScoringWithEvaluator {
override protected def initialData(ctx: MLBenchContext) = {
import ctx.params._
@@ -20,7 +22,8 @@ object LinearRegression extends BenchmarkAlgorithm with TestFromTraining with
numExamples,
ctx.seed(),
numPartitions,
- numFeatures)
+ numFeatures
+ )
}
override protected def trueModel(ctx: MLBenchContext): Transformer = {
diff --git a/src/main/scala/com/databricks/spark/sql/perf/mllib/regression/RandomForestRegression.scala b/src/main/scala/com/databricks/spark/sql/perf/mllib/regression/RandomForestRegression.scala
index c9ed4e8d..33ab1f21 100644
--- a/src/main/scala/com/databricks/spark/sql/perf/mllib/regression/RandomForestRegression.scala
+++ b/src/main/scala/com/databricks/spark/sql/perf/mllib/regression/RandomForestRegression.scala
@@ -4,8 +4,11 @@ import org.apache.spark.ml.PipelineStage
import org.apache.spark.ml.regression.RandomForestRegressor
import com.databricks.spark.sql.perf.mllib.OptionImplicits._
-import com.databricks.spark.sql.perf.mllib.{BenchmarkAlgorithm, MLBenchContext,
- TreeOrForestRegressor}
+import com.databricks.spark.sql.perf.mllib.{
+ BenchmarkAlgorithm,
+ MLBenchContext,
+ TreeOrForestRegressor
+}
object RandomForestRegression extends BenchmarkAlgorithm with TreeOrForestRegressor {
override def getPipelineStage(ctx: MLBenchContext): PipelineStage = {
diff --git a/src/main/scala/com/databricks/spark/sql/perf/mllib/yaml.scala b/src/main/scala/com/databricks/spark/sql/perf/mllib/yaml.scala
index edd54a7c..35c93a49 100644
--- a/src/main/scala/com/databricks/spark/sql/perf/mllib/yaml.scala
+++ b/src/main/scala/com/databricks/spark/sql/perf/mllib/yaml.scala
@@ -8,105 +8,99 @@ import scala.io.Source
import scala.reflect._
import scala.reflect.runtime.universe._
-import scala.util.{Try => STry, Success, Failure}
+import scala.util.{Failure, Success, Try => STry}
import org.yaml.snakeyaml.Yaml
-import com.databricks.spark.sql.perf.{MLParams}
+import com.databricks.spark.sql.perf.MLParams
-
-/**
- * The configuration information generated from reading a YAML file.
- *
- * @param output the output direct
- */
+/** The configuration information generated from reading a YAML file.
+ *
+ * @param output
+ * the output directory
+ */
case class YamlConfig(
- output: String = "/tmp/result",
- timeout: Duration = 20.minutes,
- runnableBenchmarks: Seq[MLTest])
+ output: String = "/tmp/result",
+ timeout: Duration = 20.minutes,
+ runnableBenchmarks: Seq[MLTest]
+)
object YamlConfig {
- /**
- * Reads a string (assumed to contain a yaml description) and returns the configuration.
- */
+ /** Reads a string (assumed to contain a yaml description) and returns the configuration.
+ */
def readString(s: String): YamlConfig = {
println(s)
- val yaml = new Yaml()
- val m = dict(yaml.load(s))
+ val yaml = new Yaml()
+ val m = dict(yaml.load(s))
val common = m.get("common").map(dict).getOrElse(Map.empty)
println("common")
println(m)
val exps = m("benchmarks")
- .asInstanceOf[AL[Map[String, Any]]].asScala.map(dict).toSeq
+ .asInstanceOf[AL[Map[String, Any]]]
+ .asScala
+ .map(dict)
+ .toSeq
println("exps:")
println(exps)
val experiments = exps.flatMap { sd =>
- val name = sd("name").toString
- val params = sd.get("params").map(dict).getOrElse(Map.empty)
+ val name = sd("name").toString
+ val params = sd.get("params").map(dict).getOrElse(Map.empty)
val expParams = cartesian(common ++ params)
for (c <- expParams) yield name -> c
}
println("exp parsed")
println(experiments)
- val e2 = experiments.map { case (n, e) =>
- val e2 = ccFromMap.fromMap[MLParams](e, strict=true)
- val s = ccFromMap.loadExperiment(n).getOrElse {
- throw new Exception(s"Cannot find algorithm $n in the standard benchmark algorithms")
- }
- MLTest(s, e2)
+ val e2 = experiments.map {
+ case (n, e) =>
+ val e2 = ccFromMap.fromMap[MLParams](e, strict = true)
+ val s = ccFromMap.loadExperiment(n).getOrElse {
+ throw new Exception(s"Cannot find algorithm $n in the standard benchmark algorithms")
+ }
+ MLTest(s, e2)
}
var c = YamlConfig(runnableBenchmarks = e2)
- for (output <- m.get("output")) {
+ for (output <- m.get("output"))
c = c.copy(output = output.toString)
- }
- for (x <- m.get("timeoutSeconds")) {
+ for (x <- m.get("timeoutSeconds"))
c = c.copy(timeout = x.toString.toInt.seconds)
- }
c
}
- /**
- * Reads a file (assumed to contain a yaml config).
- */
- def readFile(filename: String): YamlConfig = {
+ /** Reads a file (assumed to contain a yaml config).
+ */
+ def readFile(filename: String): YamlConfig =
readString(Source.fromFile(filename).mkString)
- }
// Converts a java dictionary to a scala map.
- private def dict[T](d: T): Map[String, Any] = {
+ private def dict[T](d: T): Map[String, Any] =
d.asInstanceOf[java.util.Map[String, Any]].asScala.toMap
- }
- /**
- * Given keys that may be lists, builds the cartesian product of all the values into defined
- * options.
- *
- * For example: {a: [1,2], b: [3,4]} -> {a: 1, b: 3}, {a: 1, b:4}, {a:2, b:3}, ...
- *
- * @return
- */
- private def cartesian(m: Map[String, Any]): Seq[Map[String, Any]] = {
+ /** Given keys that may be lists, builds the cartesian product of all the values into defined
+ * options.
+ *
+ * For example: {a: [1,2], b: [3,4]} -> {a: 1, b: 3}, {a: 1, b:4}, {a:2, b:3}, ...
+ *
+ * @return
+ */
+ private def cartesian(m: Map[String, Any]): Seq[Map[String, Any]] =
if (m.isEmpty) {
Seq(m)
} else {
- val k = m.keys.head
+ val k = m.keys.head
val sub = m - k
- val l = cartesian(sub)
+ val l = cartesian(sub)
m(k) match {
case a: AL[_] =>
for {
- x <- a.asScala.toSeq
+ x <- a.asScala.toSeq
m2 <- l
- } yield {
- m2 ++ Map(k -> x.asInstanceOf[Any])
- }
+ } yield m2 ++ Map(k -> x.asInstanceOf[Any])
case _ =>
val v = m(k)
- l.map { m => m ++ Map(k -> v) }
+ l.map(m => m ++ Map(k -> v))
}
}
- }
}
@@ -115,35 +109,42 @@ object ccFromMap {
// Builds a case class from a map.
// (taken from stack overflow)
// if strict, will report an error if some unknown arguments are passed to the constructor
- def fromMap[T: TypeTag: ClassTag](m: Map[String,_], strict: Boolean) = {
+ def fromMap[T: TypeTag: ClassTag](m: Map[String, _], strict: Boolean) = {
scala.reflect.runtime.universe
- val rm = runtimeMirror(classTag[T].runtimeClass.getClassLoader)
- val classTest = typeOf[T].typeSymbol.asClass
- val classMirror = rm.reflectClass(classTest)
- val constructor = typeOf[T].declaration(nme.CONSTRUCTOR).asMethod
+ val rm = runtimeMirror(classTag[T].runtimeClass.getClassLoader)
+ val classTest = typeOf[T].typeSymbol.asClass
+ val classMirror = rm.reflectClass(classTest)
+ val constructor = typeOf[T].declaration(nme.CONSTRUCTOR).asMethod
val constructorMirror = classMirror.reflectConstructor(constructor)
val constructorArgNames = constructor.paramss.flatten.map(_.name.toString).toSet
- val extraElements = m.keySet -- constructorArgNames
+ val extraElements = m.keySet -- constructorArgNames
if (extraElements.nonEmpty) {
- throw new Exception(s"Found extra arguments when instantiating an object of " +
- s"class ${classTest.asClass.toString}:" +
- s" ${extraElements.toSeq.sorted}")
+ throw new Exception(
+ s"Found extra arguments when instantiating an object of " +
+ s"class ${classTest.asClass.toString}:" +
+ s" ${extraElements.toSeq.sorted}"
+ )
}
- val constructorArgs = constructor.paramss.flatten.map( (param: Symbol) => {
+ val constructorArgs = constructor.paramss.flatten.map { (param: Symbol) =>
val paramName = param.name.toString
- if(param.typeSignature <:< typeOf[Option[Long]])
+ if (param.typeSignature <:< typeOf[Option[Long]])
OptionImplicits.checkLong(m.get(paramName).asInstanceOf[Option[Long]])
- else if(param.typeSignature <:< typeOf[Option[Double]])
+ else if (param.typeSignature <:< typeOf[Option[Double]])
OptionImplicits.checkDouble(m.get(paramName).asInstanceOf[Option[Double]])
- else if(param.typeSignature <:< typeOf[Option[Any]])
+ else if (param.typeSignature <:< typeOf[Option[Any]])
m.get(paramName)
else
- m.get(paramName).getOrElse(throw new IllegalArgumentException("Map is missing required parameter named " + paramName))
- })
+ m.get(paramName)
+ .getOrElse(
+ throw new IllegalArgumentException(
+ "Map is missing required parameter named " + paramName
+ )
+ )
+ }
- val res = constructorMirror(constructorArgs:_*).asInstanceOf[T]
+ val res = constructorMirror(constructorArgs: _*).asInstanceOf[T]
res
}
@@ -152,7 +153,7 @@ object ccFromMap {
val rm = runtimeMirror(getClass.getClassLoader)
try {
val module = rm.staticModule("com.databricks.spark.sql.perf.mllib." + name)
- val obj = rm.reflectModule(module)
+ val obj = rm.reflectModule(module)
Success(obj.instance.asInstanceOf[BenchmarkAlgorithm])
} catch {
case x: scala.reflect.internal.MissingRequirementError =>
@@ -167,10 +168,10 @@ object ccFromMap {
def loadExperiment(
name: String,
- searchPackages: Seq[String] = defaultPackages): Option[BenchmarkAlgorithm] = {
+ searchPackages: Seq[String] = defaultPackages
+ ): Option[BenchmarkAlgorithm] =
searchPackages.view.flatMap { p =>
val n = if (p.isEmpty) name else s"$p.$name"
load(n).toOption
- } .headOption
- }
+ }.headOption
}
diff --git a/src/main/scala/com/databricks/spark/sql/perf/package.scala b/src/main/scala/com/databricks/spark/sql/perf/package.scala
index 080d0243..563b32d7 100644
--- a/src/main/scala/com/databricks/spark/sql/perf/package.scala
+++ b/src/main/scala/com/databricks/spark/sql/perf/package.scala
@@ -4,5 +4,7 @@ import org.apache.spark.sql.functions._
package object perf {
val runtime =
- (col("result.analysisTime") + col("result.optimizationTime") + col("result.planningTime") + col("result.executionTime")).as("runtime")
-}
\ No newline at end of file
+ (col("result.analysisTime") + col("result.optimizationTime") + col("result.planningTime") + col(
+ "result.executionTime"
+ )).as("runtime")
+}
diff --git a/src/main/scala/com/databricks/spark/sql/perf/results.scala b/src/main/scala/com/databricks/spark/sql/perf/results.scala
index 28d72263..9fcd7c91 100644
--- a/src/main/scala/com/databricks/spark/sql/perf/results.scala
+++ b/src/main/scala/com/databricks/spark/sql/perf/results.scala
@@ -18,62 +18,84 @@ package com.databricks.spark.sql.perf
import com.databricks.spark.sql.perf.mllib.ReflectionUtils
-/**
- * The performance results of all given queries for a single iteration.
- *
- * @param timestamp The timestamp indicates when the entire experiment is started.
- * @param iteration The index number of the current iteration.
- * @param tags Tags of this iteration (variations are stored at here).
- * @param configuration Configuration properties of this iteration.
- * @param results The performance results of queries for this iteration.
- */
+/** The performance results of all given queries for a single iteration.
+ *
+ * @param timestamp
+ * The timestamp indicates when the entire experiment is started.
+ * @param iteration
+ * The index number of the current iteration.
+ * @param tags
+ * Tags of this iteration (variations are stored here).
+ * @param configuration
+ * Configuration properties of this iteration.
+ * @param results
+ * The performance results of queries for this iteration.
+ */
case class ExperimentRun(
timestamp: Long,
iteration: Int,
tags: Map[String, String],
configuration: BenchmarkConfiguration,
- results: Seq[BenchmarkResult])
+ results: Seq[BenchmarkResult]
+)
-/**
- * The configuration used for an iteration of an experiment.
- *
- * @param sparkVersion The version of Spark.
- * @param sqlConf All configuration properties related to Spark SQL.
- * @param sparkConf All configuration properties of Spark.
- * @param defaultParallelism The default parallelism of the cluster.
- * Usually, it is the number of cores of the cluster.
- */
+/** The configuration used for an iteration of an experiment.
+ *
+ * @param sparkVersion
+ * The version of Spark.
+ * @param sqlConf
+ * All configuration properties related to Spark SQL.
+ * @param sparkConf
+ * All configuration properties of Spark.
+ * @param defaultParallelism
+ * The default parallelism of the cluster. Usually, it is the number of cores of the cluster.
+ */
case class BenchmarkConfiguration(
sparkVersion: String = org.apache.spark.SPARK_VERSION,
sqlConf: Map[String, String],
sparkConf: Map[String, String],
defaultParallelism: Int,
- buildInfo: Map[String, String])
+ buildInfo: Map[String, String]
+)
-/**
- * The result of a query.
- *
- * @param name The name of the query.
- * @param mode The ExecutionMode of this run.
- * @param parameters Additional parameters that describe this query.
- * @param joinTypes The type of join operations in the query.
- * @param tables The tables involved in the query.
- * @param parsingTime The time used to parse the query.
- * @param analysisTime The time used to analyze the query.
- * @param optimizationTime The time used to optimize the query.
- * @param planningTime The time used to plan the query.
- * @param executionTime The time used to execute the query.
- * @param result the result of this run. It is not necessarily the result of the query.
- * For example, it can be the number of rows generated by this query or
- * the sum of hash values of rows generated by this query.
- * @param breakDown The breakdown results of the query plan tree.
- * @param queryExecution The query execution plan.
- * @param failure The failure message.
- * @param mlResult The result metrics specific to MLlib.
- * @param benchmarkId An optional ID to identify a series of benchmark runs.
- * In ML, this is generated based on the benchmark name and
- * the hash value of params.
- */
+/** The result of a query.
+ *
+ * @param name
+ * The name of the query.
+ * @param mode
+ * The ExecutionMode of this run.
+ * @param parameters
+ * Additional parameters that describe this query.
+ * @param joinTypes
+ * The type of join operations in the query.
+ * @param tables
+ * The tables involved in the query.
+ * @param parsingTime
+ * The time used to parse the query.
+ * @param analysisTime
+ * The time used to analyze the query.
+ * @param optimizationTime
+ * The time used to optimize the query.
+ * @param planningTime
+ * The time used to plan the query.
+ * @param executionTime
+ * The time used to execute the query.
+ * @param result
+ * the result of this run. It is not necessarily the result of the query. For example, it can be
+ * the number of rows generated by this query or the sum of hash values of rows generated by this
+ * query.
+ * @param breakDown
+ * The breakdown results of the query plan tree.
+ * @param queryExecution
+ * The query execution plan.
+ * @param failure
+ * The failure message.
+ * @param mlResult
+ * The result metrics specific to MLlib.
+ * @param benchmarkId
+ * An optional ID to identify a series of benchmark runs. In ML, this is generated based on the
+ * benchmark name and the hash value of params.
+ */
case class BenchmarkResult(
name: String,
mode: String,
@@ -90,34 +112,37 @@ case class BenchmarkResult(
queryExecution: Option[String] = None,
failure: Option[Failure] = None,
mlResult: Option[Array[MLMetric]] = None,
- benchmarkId: Option[String] = None)
+ benchmarkId: Option[String] = None
+)
-/**
- * The execution time of a subtree of the query plan tree of a specific query.
- *
- * @param nodeName The name of the top physical operator of the subtree.
- * @param nodeNameWithArgs The name and arguments of the top physical operator of the subtree.
- * @param index The index of the top physical operator of the subtree
- * in the original query plan tree. The index starts from 0
- * (0 represents the top physical operator of the original query plan tree).
- * @param executionTime The execution time of the subtree.
- */
+/** The execution time of a subtree of the query plan tree of a specific query.
+ *
+ * @param nodeName
+ * The name of the top physical operator of the subtree.
+ * @param nodeNameWithArgs
+ * The name and arguments of the top physical operator of the subtree.
+ * @param index
+ * The index of the top physical operator of the subtree in the original query plan tree. The
+ * index starts from 0 (0 represents the top physical operator of the original query plan tree).
+ * @param executionTime
+ * The execution time of the subtree.
+ */
case class BreakdownResult(
nodeName: String,
nodeNameWithArgs: String,
index: Int,
children: Seq[Int],
executionTime: Double,
- delta: Double)
+ delta: Double
+)
case class Failure(className: String, message: String)
-/**
- * Class wrapping parameters for ML tests.
- *
- * KEEP CONSTRUCTOR ARGUMENTS SORTED BY NAME.
- * It simplifies lookup when checking if a parameter is here already.
- */
+/** Class wrapping parameters for ML tests.
+ *
+ * KEEP CONSTRUCTOR ARGUMENTS SORTED BY NAME. It simplifies lookup when checking if a parameter is
+ * here already.
+ */
class MLParams(
// *** Common to all algorithms ***
val randomSeed: Option[Int] = Some(42),
@@ -148,12 +173,12 @@ class MLParams(
val rank: Option[Int] = None,
val smoothing: Option[Double] = None,
val tol: Option[Double] = None,
- val vocabSize: Option[Int] = None) {
+ val vocabSize: Option[Int] = None
+) {
- /**
- * Returns a map of param names to string representations of their values. Only params that
- * were defined (i.e., not equal to None) are included in the map.
- */
+ /** Returns a map of param names to string representations of their values. Only params that were
+ * defined (i.e., not equal to None) are included in the map.
+ */
def toMap: Map[String, String] = {
// Only outputs params that have values
val allParams = ReflectionUtils.getConstructorArgs(this)
@@ -196,7 +221,8 @@ class MLParams(
rank: Option[Int] = rank,
smoothing: Option[Double] = smoothing,
tol: Option[Double] = tol,
- vocabSize: Option[Int] = vocabSize): MLParams = {
+ vocabSize: Option[Int] = vocabSize
+ ): MLParams =
new MLParams(
randomSeed = randomSeed,
numExamples = numExamples,
@@ -225,26 +251,25 @@ class MLParams(
rank = rank,
smoothing = smoothing,
tol = tol,
- vocabSize = vocabSize)
- }
+ vocabSize = vocabSize
+ )
}
object MLParams {
val empty = new MLParams()
}
-/**
- * Metrics specific to MLlib benchmark.
- *
- * @param metricName the name of the metric
- * @param metricValue the value of the metric
- * @param isLargerBetter the indicator showing whether larger metric value is better
- */
-case class MLMetric(
- metricName: String,
- metricValue: Double,
- isLargerBetter: Boolean)
+/** Metrics specific to MLlib benchmark.
+ *
+ * @param metricName
+ * the name of the metric
+ * @param metricValue
+ * the value of the metric
+ * @param isLargerBetter
+ * the indicator showing whether larger metric value is better
+ */
+case class MLMetric(metricName: String, metricValue: Double, isLargerBetter: Boolean)
object MLMetric {
val Invalid = MLMetric("Invalid", 0.0, false)
-}
\ No newline at end of file
+}
diff --git a/src/main/scala/com/databricks/spark/sql/perf/tpcds/GenTPCDSData.scala b/src/main/scala/com/databricks/spark/sql/perf/tpcds/GenTPCDSData.scala
index d3414844..332158b8 100644
--- a/src/main/scala/com/databricks/spark/sql/perf/tpcds/GenTPCDSData.scala
+++ b/src/main/scala/com/databricks/spark/sql/perf/tpcds/GenTPCDSData.scala
@@ -31,23 +31,22 @@ case class GenTPCDSDataConfig(
clusterByPartitionColumns: Boolean = true,
filterOutNullPartitionValues: Boolean = true,
tableFilter: String = "",
- numPartitions: Int = 100)
+ numPartitions: Int = 100
+)
-/**
- * Gen TPCDS data.
- * To run this:
- * {{{
- * build/sbt "test:runMain <this class> -d <dsdgenDir> -s <scaleFactor> -l <location> -f <format>"
- * }}}
- */
+/** Gen TPCDS data. To run this:
+ * {{{
+ * build/sbt "test:runMain <this class> -d <dsdgenDir> -s <scaleFactor> -l <location> -f <format>"
+ * }}}
+ */
object GenTPCDSData {
def main(args: Array[String]): Unit = {
val parser = new scopt.OptionParser[GenTPCDSDataConfig]("Gen-TPC-DS-data") {
opt[String]('m', "master")
- .action { (x, c) => c.copy(master = x) }
+ .action((x, c) => c.copy(master = x))
.text("the Spark master to use, default to local[*]")
opt[String]('d', "dsdgenDir")
- .action { (x, c) => c.copy(dsdgenDir = x) }
+ .action((x, c) => c.copy(dsdgenDir = x))
.text("location of dsdgen")
.required()
opt[String]('s', "scaleFactor")
@@ -58,7 +57,7 @@ object GenTPCDSData {
.text("root directory of location to create data in")
opt[String]('f', "format")
.action((x, c) => c.copy(format = x))
- .text("valid spark format, Parquet, ORC ...")
+ .text("valid spark format, Parquet, ORC, Delta, Iceberg ...")
opt[Boolean]('i', "useDoubleForDecimal")
.action((x, c) => c.copy(useDoubleForDecimal = x))
.text("true to replace DecimalType with DoubleType")
@@ -102,11 +101,13 @@ object GenTPCDSData {
.master(config.master)
.getOrCreate()
- val tables = new TPCDSTables(spark.sqlContext,
+ val tables = new TPCDSTables(
+ spark.sqlContext,
dsdgenDir = config.dsdgenDir,
scaleFactor = config.scaleFactor,
useDoubleForDecimal = config.useDoubleForDecimal,
- useStringForDate = config.useStringForDate)
+ useStringForDate = config.useStringForDate
+ )
tables.genData(
location = config.location,
@@ -116,6 +117,7 @@ object GenTPCDSData {
clusterByPartitionColumns = config.clusterByPartitionColumns,
filterOutNullPartitionValues = config.filterOutNullPartitionValues,
tableFilter = config.tableFilter,
- numPartitions = config.numPartitions)
+ numPartitions = config.numPartitions
+ )
}
}
diff --git a/src/main/scala/com/databricks/spark/sql/perf/tpcds/ImpalaKitQueries.scala b/src/main/scala/com/databricks/spark/sql/perf/tpcds/ImpalaKitQueries.scala
index 5ef20344..d431d1a4 100644
--- a/src/main/scala/com/databricks/spark/sql/perf/tpcds/ImpalaKitQueries.scala
+++ b/src/main/scala/com/databricks/spark/sql/perf/tpcds/ImpalaKitQueries.scala
@@ -16,7 +16,7 @@
package com.databricks.spark.sql.perf.tpcds
-import com.databricks.spark.sql.perf.{ExecutionMode, Benchmark}
+import com.databricks.spark.sql.perf.{Benchmark, ExecutionMode}
trait ImpalaKitQueries extends Benchmark {
@@ -25,7 +25,9 @@ trait ImpalaKitQueries extends Benchmark {
// Queries are from
// https://github.com/cloudera/impala-tpcds-kit/tree/master/queries-sql92-modified/queries
val queries = Seq(
- ("q19", """
+ (
+ "q19",
+ """
|-- start query 1 in stream 0 using template query19.tpl
|select
| i_brand_id,
@@ -60,9 +62,11 @@ trait ImpalaKitQueries extends Benchmark {
| i_manufact
|limit 100
|-- end query 1 in stream 0 using template query19.tpl
- """.stripMargin),
-
- ("q27", """
+ """.stripMargin
+ ),
+ (
+ "q27",
+ """
|-- start query 1 in stream 0 using template query27.tpl
|select
| i_item_id,
@@ -95,9 +99,11 @@ trait ImpalaKitQueries extends Benchmark {
| s_state
|limit 100
|-- end query 1 in stream 0 using template query27.tpl
- """.stripMargin),
-
- ("q3", """
+ """.stripMargin
+ ),
+ (
+ "q3",
+ """
|-- start query 1 in stream 0 using template query3.tpl
|select
| dt.d_year,
@@ -139,9 +145,11 @@ trait ImpalaKitQueries extends Benchmark {
| brand_id
|-- end query 1 in stream 0 using template query3.tpl
|limit 100
- """.stripMargin),
-
- ("q34", """
+ """.stripMargin
+ ),
+ (
+ "q34",
+ """
|-- start query 1 in stream 0 using template query34.tpl
|select
| c_last_name,
@@ -186,9 +194,11 @@ trait ImpalaKitQueries extends Benchmark {
| cnt
|limit 1000
|-- end query 1 in stream 0 using template query34.tpl
- """.stripMargin),
-
- ("q42", """
+ """.stripMargin
+ ),
+ (
+ "q42",
+ """
|-- start query 1 in stream 0 using template query42.tpl
|select
| d_year,
@@ -217,9 +227,11 @@ trait ImpalaKitQueries extends Benchmark {
| i_category
|limit 100
|-- end query 1 in stream 0 using template query42.tpl
- """.stripMargin),
-
- ("q43", """
+ """.stripMargin
+ ),
+ (
+ "q43",
+ """
|-- start query 1 in stream 0 using template query43.tpl
|select
| s_store_name,
@@ -255,9 +267,11 @@ trait ImpalaKitQueries extends Benchmark {
| sat_sales
|limit 100
|-- end query 1 in stream 0 using template query43.tpl
- """.stripMargin),
-
- ("q46", """
+ """.stripMargin
+ ),
+ (
+ "q46",
+ """
|-- start query 1 in stream 0 using template query46.tpl
|select
| c_last_name,
@@ -333,9 +347,11 @@ trait ImpalaKitQueries extends Benchmark {
| ss_ticket_number
|limit 100
|-- end query 1 in stream 0 using template query46.tpl
- """.stripMargin),
-
- ("q52", """
+ """.stripMargin
+ ),
+ (
+ "q52",
+ """
|-- start query 1 in stream 0 using template query52.tpl
|select
| d_year,
@@ -362,9 +378,11 @@ trait ImpalaKitQueries extends Benchmark {
| i_brand_id
|limit 100
|-- end query 1 in stream 0 using template query52.tpl
- """.stripMargin),
-
- ("q53", """
+ """.stripMargin
+ ),
+ (
+ "q53",
+ """
|-- start query 1 in stream 0 using template query53.tpl
|select
| *
@@ -405,9 +423,11 @@ trait ImpalaKitQueries extends Benchmark {
| i_manufact_id
|limit 100
|-- end query 1 in stream 0 using template query53.tpl
- """.stripMargin),
-
- ("q55", """
+ """.stripMargin
+ ),
+ (
+ "q55",
+ """
|-- start query 1 in stream 0 using template query55.tpl
|select
| i_brand_id,
@@ -431,9 +451,11 @@ trait ImpalaKitQueries extends Benchmark {
| i_brand_id
|limit 100
|-- end query 1 in stream 0 using template query55.tpl
- """.stripMargin),
-
- ("q59", """
+ """.stripMargin
+ ),
+ (
+ "q59",
+ """
|-- start query 1 in stream 0 using template query59.tpl
|select
| s_store_name1,
@@ -531,9 +553,11 @@ trait ImpalaKitQueries extends Benchmark {
| d_week_seq1
|limit 100
|-- end query 1 in stream 0 using template query59.tpl
- """.stripMargin),
-
- ("q63", """
+ """.stripMargin
+ ),
+ (
+ "q63",
+ """
|-- start query 1 in stream 0 using template query63.tpl
|select
| *
@@ -574,9 +598,11 @@ trait ImpalaKitQueries extends Benchmark {
| sum_sales
|limit 100
|-- end query 1 in stream 0 using template query63.tpl
- """.stripMargin),
-
- ("q65", """
+ """.stripMargin
+ ),
+ (
+ "q65",
+ """
|--q65
|-- start query 1 in stream 0 using template query65.tpl
|select
@@ -634,9 +660,11 @@ trait ImpalaKitQueries extends Benchmark {
| i_item_desc
|limit 100
|-- end query 1 in stream 0 using template query65.tpl
- """.stripMargin),
-
- ("q68", """
+ """.stripMargin
+ ),
+ (
+ "q68",
+ """
|-- start query 1 in stream 0 using template query68.tpl
|select
| c_last_name,
@@ -693,9 +721,11 @@ trait ImpalaKitQueries extends Benchmark {
| ss_ticket_number
|limit 100
|-- end query 1 in stream 0 using template query68.tpl
- """.stripMargin),
-
- ("q7", """
+ """.stripMargin
+ ),
+ (
+ "q7",
+ """
|-- start query 1 in stream 0 using template query7.tpl
|select
| i_item_id,
@@ -724,9 +754,11 @@ trait ImpalaKitQueries extends Benchmark {
| i_item_id
|limit 100
|-- end query 1 in stream 0 using template query7.tpl
- """.stripMargin),
-
- ("q73", """
+ """.stripMargin
+ ),
+ (
+ "q73",
+ """
|-- start query 1 in stream 0 using template query73.tpl
|select
| c_last_name,
@@ -775,9 +807,11 @@ trait ImpalaKitQueries extends Benchmark {
| cnt desc
|limit 1000
|-- end query 1 in stream 0 using template query73.tpl
- """.stripMargin),
-
- ("q79", """
+ """.stripMargin
+ ),
+ (
+ "q79",
+ """
|-- start query 1 in stream 0 using template query79.tpl
|select
| c_last_name,
@@ -823,9 +857,11 @@ trait ImpalaKitQueries extends Benchmark {
| profit
|limit 100
|-- end query 1 in stream 0 using template query79.tpl
- """.stripMargin),
-
- ("q8", """
+ """.stripMargin
+ ),
+ (
+ "q8",
+ """
|-- start query 8 in stream 0 using template query8.tpl
|select s_store_name
| ,sum(ss_net_profit)
@@ -885,9 +921,11 @@ trait ImpalaKitQueries extends Benchmark {
| order by s_store_name
|limit 100
|-- end query 8 in stream 0 using template query8.tpl
- """.stripMargin),
-
- ("q82", """
+ """.stripMargin
+ ),
+ (
+ "q82",
+ """
|-- start query 1 in stream 0 using template query82.tpl
|select
| i_item_id,
@@ -912,9 +950,11 @@ trait ImpalaKitQueries extends Benchmark {
| i_item_id
|limit 100
|-- end query 1 in stream 0 using template query82.tpl
- """.stripMargin),
-
- ("q89", """
+ """.stripMargin
+ ),
+ (
+ "q89",
+ """
|-- start query 1 in stream 0 using template query89.tpl
|select
| *
@@ -958,9 +998,11 @@ trait ImpalaKitQueries extends Benchmark {
| s_store_name
|limit 100
|-- end query 1 in stream 0 using template query89.tpl
- """.stripMargin),
-
- ("q98", """
+ """.stripMargin
+ ),
+ (
+ "q98",
+ """
|-- start query 1 in stream 0 using template query98.tpl
|select
| i_item_desc,
@@ -995,9 +1037,11 @@ trait ImpalaKitQueries extends Benchmark {
| -- revenueratio
|limit 1000
|-- end query 1 in stream 0 using template query98.tpl
- """.stripMargin),
-
- ("ss_max", """
+ """.stripMargin
+ ),
+ (
+ "ss_max",
+ """
|select
| count(*) as total,
| count(ss_sold_date_sk) as not_null_total,
@@ -1012,14 +1056,17 @@ trait ImpalaKitQueries extends Benchmark {
| max(ss_store_sk) as max_ss_store_sk,
| max(ss_promo_sk) as max_ss_promo_sk
|from store_sales
- """.stripMargin)
+ """.stripMargin
+ )
).map {
case (name, sqlText) => Query(name, sqlText, description = "", executionMode = CollectResults)
}
val queriesMap = queries.map(q => q.name -> q).toMap
val originalQueries = Seq(
- ("q3", """
+ (
+ "q3",
+ """
select d_year
,item.i_brand_id brand_id
,item.i_brand brand
@@ -1036,9 +1083,11 @@ trait ImpalaKitQueries extends Benchmark {
order by d_year
,sum_agg desc
,brand_id
- limit 100"""),
-
- ("q7", """
+ limit 100"""
+ ),
+ (
+ "q7",
+ """
select i_item_id,
avg(ss_quantity) agg1,
avg(ss_list_price) agg2,
@@ -1057,9 +1106,11 @@ trait ImpalaKitQueries extends Benchmark {
d_year = 1998
group by i_item_id
order by i_item_id
- limit 100"""),
-
- ("q19", """
+ limit 100"""
+ ),
+ (
+ "q19",
+ """
select i_brand_id, i_brand, i_manufact_id, i_manufact,
sum(ss_ext_sales_price) as ext_price
from date_dim
@@ -1082,9 +1133,11 @@ trait ImpalaKitQueries extends Benchmark {
,i_brand_id
,i_manufact_id
,i_manufact
- limit 100"""),
-
- ("q27", """
+ limit 100"""
+ ),
+ (
+ "q27",
+ """
select i_item_id,
s_state,
avg(ss_quantity) agg1,
@@ -1105,9 +1158,11 @@ trait ImpalaKitQueries extends Benchmark {
group by i_item_id, s_state
order by i_item_id
,s_state
- limit 100"""),
-
- ("q34", """
+ limit 100"""
+ ),
+ (
+ "q34",
+ """
select c_last_name
,c_first_name
,c_salutation
@@ -1143,9 +1198,11 @@ trait ImpalaKitQueries extends Benchmark {
c_salutation,
c_preferred_cust_flag desc,
ss_ticket_number,
- cnt"""),
-
- ("q42", """
+ cnt"""
+ ),
+ (
+ "q42",
+ """
select d_year
,item.i_category_id
,item.i_category
@@ -1163,9 +1220,11 @@ trait ImpalaKitQueries extends Benchmark {
order by s desc,d_year
,i_category_id
,i_category
- limit 100"""),
-
- ("q43", """
+ limit 100"""
+ ),
+ (
+ "q43",
+ """
select s_store_name, s_store_id,
sum(case when (d_day_name='Sunday') then ss_sales_price else null end) sun_sales,
sum(case when (d_day_name='Monday') then ss_sales_price else null end) mon_sales,
@@ -1182,9 +1241,11 @@ trait ImpalaKitQueries extends Benchmark {
d_year = 1998
group by s_store_name, s_store_id
order by s_store_name, s_store_id,sun_sales,mon_sales,tue_sales,wed_sales,thu_sales,fri_sales,sat_sales
- limit 100"""),
-
- ("q46", """
+ limit 100"""
+ ),
+ (
+ "q46",
+ """
select c_last_name
,c_first_name
,ca_city
@@ -1218,9 +1279,11 @@ trait ImpalaKitQueries extends Benchmark {
,ca_city
,bought_city
,ss_ticket_number
- limit 100"""),
-
- ("q52", """
+ limit 100"""
+ ),
+ (
+ "q52",
+ """
select d_year
,item.i_brand_id brand_id
,item.i_brand brand
@@ -1238,9 +1301,11 @@ trait ImpalaKitQueries extends Benchmark {
order by d_year
,ext_price desc
,brand_id
- limit 100"""),
-
- ("q55", """
+ limit 100"""
+ ),
+ (
+ "q55",
+ """
select i_brand_id as brand_id, i_brand as brand,
sum(store_sales.ss_ext_sales_price) ext_price
from date_dim
@@ -1252,9 +1317,10 @@ trait ImpalaKitQueries extends Benchmark {
and d_year=2001
group by i_brand, i_brand_id
order by ext_price desc, brand_id
- limit 100 """),
-
- ("q59",
+ limit 100 """
+ ),
+ (
+ "q59",
"""
|select
| s_store_name1,
@@ -1355,9 +1421,11 @@ trait ImpalaKitQueries extends Benchmark {
| s_store_id1,
| d_week_seq1
|limit 100
- """.stripMargin),
-
- ("q68", """
+ """.stripMargin
+ ),
+ (
+ "q68",
+ """
select c_last_name ,c_first_name ,ca_city
,bought_city ,ss_ticket_number ,extended_price
,extended_tax ,list_price
@@ -1387,9 +1455,11 @@ trait ImpalaKitQueries extends Benchmark {
customer_address.ca_city <> dn.bought_city
order by c_last_name
,ss_ticket_number
- limit 100"""),
-
- ("q73", """
+ limit 100"""
+ ),
+ (
+ "q73",
+ """
select c_last_name
,c_first_name
,c_salutation
@@ -1416,9 +1486,11 @@ trait ImpalaKitQueries extends Benchmark {
JOIN customer ON dj.ss_customer_sk = customer.c_customer_sk
where
cnt between 5 and 10
- order by cnt desc"""),
-
- ("q79", """
+ order by cnt desc"""
+ ),
+ (
+ "q79",
+ """
select
c_last_name,c_first_name,substr(s_city,1,30) as s_city,ss_ticket_number,amt,profit
from
@@ -1439,9 +1511,10 @@ trait ImpalaKitQueries extends Benchmark {
group by ss_ticket_number,ss_customer_sk,ss_addr_sk,store.s_city) ms
JOIN customer on ms.ss_customer_sk = customer.c_customer_sk
order by c_last_name,c_first_name,s_city, profit
- limit 100"""),
-
- ("qSsMax",
+ limit 100"""
+ ),
+ (
+ "qSsMax",
"""
|select
| count(*) as total,
@@ -1457,14 +1530,16 @@ trait ImpalaKitQueries extends Benchmark {
| max(ss_store_sk) as max_ss_store_sk,
| max(ss_promo_sk) as max_ss_promo_sk
|from store_sales
- """.stripMargin)
- ).map { case (name, sqlText) =>
- Query(name, sqlText, description = "original query", executionMode = CollectResults)
+ """.stripMargin
+ )
+ ).map {
+ case (name, sqlText) =>
+ Query(name, sqlText, description = "original query", executionMode = CollectResults)
}
val interactiveQueries =
Seq("q19", "q42", "q52", "q55", "q63", "q68", "q73", "q98").map(queriesMap)
- val reportingQueries = Seq("q3","q7", "q27","q43", "q53", "q89").map(queriesMap)
- val deepAnalyticQueries = Seq("q34", "q46", "q59", "q65", "q79", "ss_max").map(queriesMap)
- val impalaKitQueries = interactiveQueries ++ reportingQueries ++ deepAnalyticQueries
+ val reportingQueries = Seq("q3", "q7", "q27", "q43", "q53", "q89").map(queriesMap)
+ val deepAnalyticQueries = Seq("q34", "q46", "q59", "q65", "q79", "ss_max").map(queriesMap)
+ val impalaKitQueries = interactiveQueries ++ reportingQueries ++ deepAnalyticQueries
}
diff --git a/src/main/scala/com/databricks/spark/sql/perf/tpcds/SimpleQueries.scala b/src/main/scala/com/databricks/spark/sql/perf/tpcds/SimpleQueries.scala
index 1f7f3554..1cb1644e 100644
--- a/src/main/scala/com/databricks/spark/sql/perf/tpcds/SimpleQueries.scala
+++ b/src/main/scala/com/databricks/spark/sql/perf/tpcds/SimpleQueries.scala
@@ -16,32 +16,38 @@
package com.databricks.spark.sql.perf.tpcds
-import com.databricks.spark.sql.perf.{ExecutionMode, Benchmark}
+import com.databricks.spark.sql.perf.{Benchmark, ExecutionMode}
trait SimpleQueries extends Benchmark {
import ExecutionMode._
- val targetedPerfQueries = Seq(
- // Query to measure scan performance.
- ("stores-sales-scan",
- """
+ val targetedPerfQueries = Seq(
+ // Query to measure scan performance.
+ (
+ "stores-sales-scan",
+ """
|select * from store_sales where ss_item_sk = 1
- """.stripMargin),
- ("fact-fact-join",
- """
+ """.stripMargin
+ ),
+ (
+ "fact-fact-join",
+ """
| select count(*) from store_sales
| join store_returns
| on store_sales.ss_item_sk = store_returns.sr_item_sk
| and store_sales.ss_ticket_number = store_returns.sr_ticket_number
- """.stripMargin)
- ).map { case (name, sqlText) =>
- Query(name = name, sqlText = sqlText, description = "", executionMode = ForeachResults)
- }
+ """.stripMargin
+ )
+ ).map {
+ case (name, sqlText) =>
+ Query(name = name, sqlText = sqlText, description = "", executionMode = ForeachResults)
+ }
- val q7Derived = Seq(
- ("q7-simpleScan",
- """
+ val q7Derived = Seq(
+ (
+ "q7-simpleScan",
+ """
|select
| ss_quantity,
| ss_list_price,
@@ -54,9 +60,11 @@ trait SimpleQueries extends Benchmark {
|from store_sales
|where
| ss_sold_date_sk between 2450815 and 2451179
- """.stripMargin),
-
- ("q7-twoMapJoins", """
+ """.stripMargin
+ ),
+ (
+ "q7-twoMapJoins",
+ """
|select
| i_item_id,
| ss_quantity,
@@ -74,9 +82,11 @@ trait SimpleQueries extends Benchmark {
| and cd_marital_status = 'W'
| and cd_education_status = 'Primary'
| and ss_sold_date_sk between 2450815 and 2451179 -- partition key filter
- """.stripMargin),
-
- ("q7-fourMapJoins", """
+ """.stripMargin
+ ),
+ (
+ "q7-fourMapJoins",
+ """
|select
| i_item_id,
| ss_quantity,
@@ -98,9 +108,11 @@ trait SimpleQueries extends Benchmark {
| and d_year = 1998
| -- and ss_date between '1998-01-01' and '1998-12-31'
| and ss_sold_date_sk between 2450815 and 2451179 -- partition key filter
- """.stripMargin),
-
- ("q7-noOrderBy", """
+ """.stripMargin
+ ),
+ (
+ "q7-noOrderBy",
+ """
|select
| i_item_id,
| avg(ss_quantity) agg1,
@@ -124,9 +136,11 @@ trait SimpleQueries extends Benchmark {
| and ss_sold_date_sk between 2450815 and 2451179 -- partition key filter
|group by
| i_item_id
- """.stripMargin),
-
- ("q7", """
+ """.stripMargin
+ ),
+ (
+ "q7",
+ """
|-- start query 1 in stream 0 using template query7.tpl
|select
| i_item_id,
@@ -155,9 +169,11 @@ trait SimpleQueries extends Benchmark {
| i_item_id
|limit 100
|-- end query 1 in stream 0 using template query7.tpl
- """.stripMargin),
-
- ("store_sales-selfjoin-1", """
+ """.stripMargin
+ ),
+ (
+ "store_sales-selfjoin-1",
+ """
|-- The join condition will yield many matches.
|select
| t1.ss_quantity,
@@ -170,10 +186,11 @@ trait SimpleQueries extends Benchmark {
|from store_sales t1 join store_sales t2 on t1.ss_item_sk = t2.ss_item_sk
|where
| t1.ss_sold_date_sk between 2450815 and 2451179
- """.stripMargin),
-
-
- ("store_sales-selfjoin-2", """
+ """.stripMargin
+ ),
+ (
+ "store_sales-selfjoin-2",
+ """
|-- We ust comound primary key as the join condition. The size of output is comparable with the input table.
|select
| t1.ss_quantity,
@@ -186,8 +203,10 @@ trait SimpleQueries extends Benchmark {
|from store_sales t1 join store_sales t2 on t1.ss_item_sk = t2.ss_item_sk and t1.ss_ticket_number = t2.ss_ticket_number
|where
| t1.ss_sold_date_sk between 2450815 and 2451179
- """.stripMargin)
- ).map { case (name, sqlText) =>
- Query(name = name, sqlText = sqlText, description = "", executionMode = ForeachResults)
- }
+ """.stripMargin
+ )
+ ).map {
+ case (name, sqlText) =>
+ Query(name = name, sqlText = sqlText, description = "", executionMode = ForeachResults)
+ }
}
diff --git a/src/main/scala/com/databricks/spark/sql/perf/tpcds/TPCDS.scala b/src/main/scala/com/databricks/spark/sql/perf/tpcds/TPCDS.scala
index 2f173f0e..1fb69a01 100644
--- a/src/main/scala/com/databricks/spark/sql/perf/tpcds/TPCDS.scala
+++ b/src/main/scala/com/databricks/spark/sql/perf/tpcds/TPCDS.scala
@@ -21,18 +21,18 @@ import com.databricks.spark.sql.perf._
import org.apache.spark.SparkContext
import org.apache.spark.sql.{SQLContext, SparkSession}
-/**
- * TPC-DS benchmark's dataset.
- *
- * @param sqlContext An existing SQLContext.
- */
+/** TPC-DS benchmark's dataset.
+ *
+ * @param sqlContext
+ * An existing SQLContext.
+ */
class TPCDS(@transient sqlContext: SQLContext)
- extends Benchmark(sqlContext)
- with ImpalaKitQueries
- with SimpleQueries
- with Tpcds_1_4_Queries
- with Tpcds_2_4_Queries
- with Serializable {
+ extends Benchmark(sqlContext)
+ with ImpalaKitQueries
+ with SimpleQueries
+ with Tpcds_1_4_Queries
+ with Tpcds_2_4_Queries
+ with Serializable {
def this() = this(SparkSession.builder.getOrCreate().sqlContext)
@@ -50,17 +50,16 @@ class TPCDS(@transient sqlContext: SQLContext)
println(setQuery)
sql(setQuery)
}
- */
-
- /**
- * Simple utilities to run the queries without persisting the results.
*/
+
+ /** Simple utilities to run the queries without persisting the results.
+ */
def explain(queries: Seq[Query], showPlan: Boolean = false): Unit = {
val succeeded = mutable.ArrayBuffer.empty[String]
queries.foreach { q =>
println(s"Query: ${q.name}")
try {
- val df = sqlContext.sql(q.sqlText.get)
+ val df = spark.sql(q.sqlText.get)
if (showPlan) {
df.explain()
} else {
@@ -80,28 +79,27 @@ class TPCDS(@transient sqlContext: SQLContext)
val succeeded = mutable.ArrayBuffer.empty[String]
queries.foreach { q =>
println(s"Query: ${q.name}")
- val start = System.currentTimeMillis()
- val df = sqlContext.sql(q.sqlText.get)
- var failed = false
+ val start = System.currentTimeMillis()
+ val df = spark.sql(q.sqlText.get)
+ var failed = false
val jobgroup = s"benchmark ${q.name}"
val t = new Thread("query runner") {
- override def run(): Unit = {
+ override def run(): Unit =
try {
- sqlContext.sparkContext.setJobGroup(jobgroup, jobgroup, true)
+ sparkContext.setJobGroup(jobgroup, jobgroup, true)
df.show(numRows)
} catch {
case e: Exception =>
println("Failed to run: " + e)
failed = true
}
- }
}
t.setDaemon(true)
t.start()
t.join(timeout)
if (t.isAlive) {
println(s"Timeout after $timeout seconds")
- sqlContext.sparkContext.cancelJobGroup(jobgroup)
+ sparkContext.cancelJobGroup(jobgroup)
t.interrupt()
} else {
if (!failed) {
@@ -115,6 +113,3 @@ class TPCDS(@transient sqlContext: SQLContext)
println(succeeded.map("\"" + _ + "\""))
}
}
-
-
-
diff --git a/src/main/scala/com/databricks/spark/sql/perf/tpcds/TPCDSTables.scala b/src/main/scala/com/databricks/spark/sql/perf/tpcds/TPCDSTables.scala
index 8243cd34..7e404df6 100644
--- a/src/main/scala/com/databricks/spark/sql/perf/tpcds/TPCDSTables.scala
+++ b/src/main/scala/com/databricks/spark/sql/perf/tpcds/TPCDSTables.scala
@@ -28,7 +28,7 @@ class DSDGEN(dsdgenDir: String) extends DataGenerator {
val dsdgen = s"$dsdgenDir/dsdgen"
def generate(sparkContext: SparkContext, name: String, partitions: Int, scaleFactor: String) = {
- val generatedData = {
+ val generatedData =
sparkContext.parallelize(1 to partitions, partitions).flatMap { i =>
val localToolsDir = if (new java.io.File(dsdgen).exists) {
dsdgenDir
@@ -41,502 +41,550 @@ class DSDGEN(dsdgenDir: String) extends DataGenerator {
// Note: RNGSEED is the RNG seed used by the data generator. Right now, it is fixed to 100.
val parallel = if (partitions > 1) s"-parallel $partitions -child $i" else ""
val commands = Seq(
- "bash", "-c",
- s"cd $localToolsDir && ./dsdgen -table $name -filter Y -scale $scaleFactor -RNGSEED 100 $parallel")
+ "bash",
+ "-c",
+ s"cd $localToolsDir && ./dsdgen -table $name -filter Y -scale $scaleFactor -RNGSEED 100 $parallel"
+ )
println(commands)
BlockingLineStream(commands)
}
- }
generatedData.setName(s"$name, sf=$scaleFactor, strings")
generatedData
}
}
-
class TPCDSTables(
- sqlContext: SQLContext,
- dsdgenDir: String,
- scaleFactor: String,
- useDoubleForDecimal: Boolean = false,
- useStringForDate: Boolean = false)
- extends Tables(sqlContext, scaleFactor, useDoubleForDecimal, useStringForDate) {
- import sqlContext.implicits._
+ sqlContext: SQLContext,
+ dsdgenDir: String,
+ scaleFactor: String,
+ useDoubleForDecimal: Boolean = false,
+ useStringForDate: Boolean = false
+) extends Tables(sqlContext, scaleFactor, useDoubleForDecimal, useStringForDate) {
+ import spark.implicits._
val dataGenerator = new DSDGEN(dsdgenDir)
val tables = Seq(
- Table("catalog_sales",
+ Table(
+ "catalog_sales",
partitionColumns = "cs_sold_date_sk" :: Nil,
- 'cs_sold_date_sk .int,
- 'cs_sold_time_sk .int,
- 'cs_ship_date_sk .int,
- 'cs_bill_customer_sk .int,
- 'cs_bill_cdemo_sk .int,
- 'cs_bill_hdemo_sk .int,
- 'cs_bill_addr_sk .int,
- 'cs_ship_customer_sk .int,
- 'cs_ship_cdemo_sk .int,
- 'cs_ship_hdemo_sk .int,
- 'cs_ship_addr_sk .int,
- 'cs_call_center_sk .int,
- 'cs_catalog_page_sk .int,
- 'cs_ship_mode_sk .int,
- 'cs_warehouse_sk .int,
- 'cs_item_sk .int,
- 'cs_promo_sk .int,
- 'cs_order_number .long,
- 'cs_quantity .int,
- 'cs_wholesale_cost .decimal(7,2),
- 'cs_list_price .decimal(7,2),
- 'cs_sales_price .decimal(7,2),
- 'cs_ext_discount_amt .decimal(7,2),
- 'cs_ext_sales_price .decimal(7,2),
- 'cs_ext_wholesale_cost .decimal(7,2),
- 'cs_ext_list_price .decimal(7,2),
- 'cs_ext_tax .decimal(7,2),
- 'cs_coupon_amt .decimal(7,2),
- 'cs_ext_ship_cost .decimal(7,2),
- 'cs_net_paid .decimal(7,2),
- 'cs_net_paid_inc_tax .decimal(7,2),
- 'cs_net_paid_inc_ship .decimal(7,2),
- 'cs_net_paid_inc_ship_tax .decimal(7,2),
- 'cs_net_profit .decimal(7,2)),
- Table("catalog_returns",
+ 'cs_sold_date_sk.int,
+ 'cs_sold_time_sk.int,
+ 'cs_ship_date_sk.int,
+ 'cs_bill_customer_sk.int,
+ 'cs_bill_cdemo_sk.int,
+ 'cs_bill_hdemo_sk.int,
+ 'cs_bill_addr_sk.int,
+ 'cs_ship_customer_sk.int,
+ 'cs_ship_cdemo_sk.int,
+ 'cs_ship_hdemo_sk.int,
+ 'cs_ship_addr_sk.int,
+ 'cs_call_center_sk.int,
+ 'cs_catalog_page_sk.int,
+ 'cs_ship_mode_sk.int,
+ 'cs_warehouse_sk.int,
+ 'cs_item_sk.int,
+ 'cs_promo_sk.int,
+ 'cs_order_number.long,
+ 'cs_quantity.int,
+ 'cs_wholesale_cost.decimal(7, 2),
+ 'cs_list_price.decimal(7, 2),
+ 'cs_sales_price.decimal(7, 2),
+ 'cs_ext_discount_amt.decimal(7, 2),
+ 'cs_ext_sales_price.decimal(7, 2),
+ 'cs_ext_wholesale_cost.decimal(7, 2),
+ 'cs_ext_list_price.decimal(7, 2),
+ 'cs_ext_tax.decimal(7, 2),
+ 'cs_coupon_amt.decimal(7, 2),
+ 'cs_ext_ship_cost.decimal(7, 2),
+ 'cs_net_paid.decimal(7, 2),
+ 'cs_net_paid_inc_tax.decimal(7, 2),
+ 'cs_net_paid_inc_ship.decimal(7, 2),
+ 'cs_net_paid_inc_ship_tax.decimal(7, 2),
+ 'cs_net_profit.decimal(7, 2)
+ ),
+ Table(
+ "catalog_returns",
partitionColumns = "cr_returned_date_sk" :: Nil,
- 'cr_returned_date_sk .int,
- 'cr_returned_time_sk .int,
- 'cr_item_sk .int,
- 'cr_refunded_customer_sk .int,
- 'cr_refunded_cdemo_sk .int,
- 'cr_refunded_hdemo_sk .int,
- 'cr_refunded_addr_sk .int,
- 'cr_returning_customer_sk .int,
- 'cr_returning_cdemo_sk .int,
- 'cr_returning_hdemo_sk .int,
- 'cr_returning_addr_sk .int,
- 'cr_call_center_sk .int,
- 'cr_catalog_page_sk .int,
- 'cr_ship_mode_sk .int,
- 'cr_warehouse_sk .int,
- 'cr_reason_sk .int,
- 'cr_order_number .long,
- 'cr_return_quantity .int,
- 'cr_return_amount .decimal(7,2),
- 'cr_return_tax .decimal(7,2),
- 'cr_return_amt_inc_tax .decimal(7,2),
- 'cr_fee .decimal(7,2),
- 'cr_return_ship_cost .decimal(7,2),
- 'cr_refunded_cash .decimal(7,2),
- 'cr_reversed_charge .decimal(7,2),
- 'cr_store_credit .decimal(7,2),
- 'cr_net_loss .decimal(7,2)),
- Table("inventory",
+ 'cr_returned_date_sk.int,
+ 'cr_returned_time_sk.int,
+ 'cr_item_sk.int,
+ 'cr_refunded_customer_sk.int,
+ 'cr_refunded_cdemo_sk.int,
+ 'cr_refunded_hdemo_sk.int,
+ 'cr_refunded_addr_sk.int,
+ 'cr_returning_customer_sk.int,
+ 'cr_returning_cdemo_sk.int,
+ 'cr_returning_hdemo_sk.int,
+ 'cr_returning_addr_sk.int,
+ 'cr_call_center_sk.int,
+ 'cr_catalog_page_sk.int,
+ 'cr_ship_mode_sk.int,
+ 'cr_warehouse_sk.int,
+ 'cr_reason_sk.int,
+ 'cr_order_number.long,
+ 'cr_return_quantity.int,
+ 'cr_return_amount.decimal(7, 2),
+ 'cr_return_tax.decimal(7, 2),
+ 'cr_return_amt_inc_tax.decimal(7, 2),
+ 'cr_fee.decimal(7, 2),
+ 'cr_return_ship_cost.decimal(7, 2),
+ 'cr_refunded_cash.decimal(7, 2),
+ 'cr_reversed_charge.decimal(7, 2),
+ 'cr_store_credit.decimal(7, 2),
+ 'cr_net_loss.decimal(7, 2)
+ ),
+ Table(
+ "inventory",
partitionColumns = "inv_date_sk" :: Nil,
- 'inv_date_sk .int,
- 'inv_item_sk .int,
- 'inv_warehouse_sk .int,
- 'inv_quantity_on_hand .int),
- Table("store_sales",
+ 'inv_date_sk.int,
+ 'inv_item_sk.int,
+ 'inv_warehouse_sk.int,
+ 'inv_quantity_on_hand.int
+ ),
+ Table(
+ "store_sales",
partitionColumns = "ss_sold_date_sk" :: Nil,
- 'ss_sold_date_sk .int,
- 'ss_sold_time_sk .int,
- 'ss_item_sk .int,
- 'ss_customer_sk .int,
- 'ss_cdemo_sk .int,
- 'ss_hdemo_sk .int,
- 'ss_addr_sk .int,
- 'ss_store_sk .int,
- 'ss_promo_sk .int,
- 'ss_ticket_number .long,
- 'ss_quantity .int,
- 'ss_wholesale_cost .decimal(7,2),
- 'ss_list_price .decimal(7,2),
- 'ss_sales_price .decimal(7,2),
- 'ss_ext_discount_amt .decimal(7,2),
- 'ss_ext_sales_price .decimal(7,2),
- 'ss_ext_wholesale_cost.decimal(7,2),
- 'ss_ext_list_price .decimal(7,2),
- 'ss_ext_tax .decimal(7,2),
- 'ss_coupon_amt .decimal(7,2),
- 'ss_net_paid .decimal(7,2),
- 'ss_net_paid_inc_tax .decimal(7,2),
- 'ss_net_profit .decimal(7,2)),
- Table("store_returns",
- partitionColumns = "sr_returned_date_sk" ::Nil,
- 'sr_returned_date_sk .int,
- 'sr_return_time_sk .int,
- 'sr_item_sk .int,
- 'sr_customer_sk .int,
- 'sr_cdemo_sk .int,
- 'sr_hdemo_sk .int,
- 'sr_addr_sk .int,
- 'sr_store_sk .int,
- 'sr_reason_sk .int,
- 'sr_ticket_number .long,
- 'sr_return_quantity .int,
- 'sr_return_amt .decimal(7,2),
- 'sr_return_tax .decimal(7,2),
- 'sr_return_amt_inc_tax.decimal(7,2),
- 'sr_fee .decimal(7,2),
- 'sr_return_ship_cost .decimal(7,2),
- 'sr_refunded_cash .decimal(7,2),
- 'sr_reversed_charge .decimal(7,2),
- 'sr_store_credit .decimal(7,2),
- 'sr_net_loss .decimal(7,2)),
- Table("web_sales",
+ 'ss_sold_date_sk.int,
+ 'ss_sold_time_sk.int,
+ 'ss_item_sk.int,
+ 'ss_customer_sk.int,
+ 'ss_cdemo_sk.int,
+ 'ss_hdemo_sk.int,
+ 'ss_addr_sk.int,
+ 'ss_store_sk.int,
+ 'ss_promo_sk.int,
+ 'ss_ticket_number.long,
+ 'ss_quantity.int,
+ 'ss_wholesale_cost.decimal(7, 2),
+ 'ss_list_price.decimal(7, 2),
+ 'ss_sales_price.decimal(7, 2),
+ 'ss_ext_discount_amt.decimal(7, 2),
+ 'ss_ext_sales_price.decimal(7, 2),
+ 'ss_ext_wholesale_cost.decimal(7, 2),
+ 'ss_ext_list_price.decimal(7, 2),
+ 'ss_ext_tax.decimal(7, 2),
+ 'ss_coupon_amt.decimal(7, 2),
+ 'ss_net_paid.decimal(7, 2),
+ 'ss_net_paid_inc_tax.decimal(7, 2),
+ 'ss_net_profit.decimal(7, 2)
+ ),
+ Table(
+ "store_returns",
+ partitionColumns = "sr_returned_date_sk" :: Nil,
+ 'sr_returned_date_sk.int,
+ 'sr_return_time_sk.int,
+ 'sr_item_sk.int,
+ 'sr_customer_sk.int,
+ 'sr_cdemo_sk.int,
+ 'sr_hdemo_sk.int,
+ 'sr_addr_sk.int,
+ 'sr_store_sk.int,
+ 'sr_reason_sk.int,
+ 'sr_ticket_number.long,
+ 'sr_return_quantity.int,
+ 'sr_return_amt.decimal(7, 2),
+ 'sr_return_tax.decimal(7, 2),
+ 'sr_return_amt_inc_tax.decimal(7, 2),
+ 'sr_fee.decimal(7, 2),
+ 'sr_return_ship_cost.decimal(7, 2),
+ 'sr_refunded_cash.decimal(7, 2),
+ 'sr_reversed_charge.decimal(7, 2),
+ 'sr_store_credit.decimal(7, 2),
+ 'sr_net_loss.decimal(7, 2)
+ ),
+ Table(
+ "web_sales",
partitionColumns = "ws_sold_date_sk" :: Nil,
- 'ws_sold_date_sk .int,
- 'ws_sold_time_sk .int,
- 'ws_ship_date_sk .int,
- 'ws_item_sk .int,
- 'ws_bill_customer_sk .int,
- 'ws_bill_cdemo_sk .int,
- 'ws_bill_hdemo_sk .int,
- 'ws_bill_addr_sk .int,
- 'ws_ship_customer_sk .int,
- 'ws_ship_cdemo_sk .int,
- 'ws_ship_hdemo_sk .int,
- 'ws_ship_addr_sk .int,
- 'ws_web_page_sk .int,
- 'ws_web_site_sk .int,
- 'ws_ship_mode_sk .int,
- 'ws_warehouse_sk .int,
- 'ws_promo_sk .int,
- 'ws_order_number .long,
- 'ws_quantity .int,
- 'ws_wholesale_cost .decimal(7,2),
- 'ws_list_price .decimal(7,2),
- 'ws_sales_price .decimal(7,2),
- 'ws_ext_discount_amt .decimal(7,2),
- 'ws_ext_sales_price .decimal(7,2),
- 'ws_ext_wholesale_cost .decimal(7,2),
- 'ws_ext_list_price .decimal(7,2),
- 'ws_ext_tax .decimal(7,2),
- 'ws_coupon_amt .decimal(7,2),
- 'ws_ext_ship_cost .decimal(7,2),
- 'ws_net_paid .decimal(7,2),
- 'ws_net_paid_inc_tax .decimal(7,2),
- 'ws_net_paid_inc_ship .decimal(7,2),
- 'ws_net_paid_inc_ship_tax .decimal(7,2),
- 'ws_net_profit .decimal(7,2)),
- Table("web_returns",
- partitionColumns = "wr_returned_date_sk" ::Nil,
- 'wr_returned_date_sk .int,
- 'wr_returned_time_sk .int,
- 'wr_item_sk .int,
- 'wr_refunded_customer_sk .int,
- 'wr_refunded_cdemo_sk .int,
- 'wr_refunded_hdemo_sk .int,
- 'wr_refunded_addr_sk .int,
- 'wr_returning_customer_sk .int,
- 'wr_returning_cdemo_sk .int,
- 'wr_returning_hdemo_sk .int,
- 'wr_returning_addr_sk .int,
- 'wr_web_page_sk .int,
- 'wr_reason_sk .int,
- 'wr_order_number .long,
- 'wr_return_quantity .int,
- 'wr_return_amt .decimal(7,2),
- 'wr_return_tax .decimal(7,2),
- 'wr_return_amt_inc_tax .decimal(7,2),
- 'wr_fee .decimal(7,2),
- 'wr_return_ship_cost .decimal(7,2),
- 'wr_refunded_cash .decimal(7,2),
- 'wr_reversed_charge .decimal(7,2),
- 'wr_account_credit .decimal(7,2),
- 'wr_net_loss .decimal(7,2)),
- Table("call_center",
+ 'ws_sold_date_sk.int,
+ 'ws_sold_time_sk.int,
+ 'ws_ship_date_sk.int,
+ 'ws_item_sk.int,
+ 'ws_bill_customer_sk.int,
+ 'ws_bill_cdemo_sk.int,
+ 'ws_bill_hdemo_sk.int,
+ 'ws_bill_addr_sk.int,
+ 'ws_ship_customer_sk.int,
+ 'ws_ship_cdemo_sk.int,
+ 'ws_ship_hdemo_sk.int,
+ 'ws_ship_addr_sk.int,
+ 'ws_web_page_sk.int,
+ 'ws_web_site_sk.int,
+ 'ws_ship_mode_sk.int,
+ 'ws_warehouse_sk.int,
+ 'ws_promo_sk.int,
+ 'ws_order_number.long,
+ 'ws_quantity.int,
+ 'ws_wholesale_cost.decimal(7, 2),
+ 'ws_list_price.decimal(7, 2),
+ 'ws_sales_price.decimal(7, 2),
+ 'ws_ext_discount_amt.decimal(7, 2),
+ 'ws_ext_sales_price.decimal(7, 2),
+ 'ws_ext_wholesale_cost.decimal(7, 2),
+ 'ws_ext_list_price.decimal(7, 2),
+ 'ws_ext_tax.decimal(7, 2),
+ 'ws_coupon_amt.decimal(7, 2),
+ 'ws_ext_ship_cost.decimal(7, 2),
+ 'ws_net_paid.decimal(7, 2),
+ 'ws_net_paid_inc_tax.decimal(7, 2),
+ 'ws_net_paid_inc_ship.decimal(7, 2),
+ 'ws_net_paid_inc_ship_tax.decimal(7, 2),
+ 'ws_net_profit.decimal(7, 2)
+ ),
+ Table(
+ "web_returns",
+ partitionColumns = "wr_returned_date_sk" :: Nil,
+ 'wr_returned_date_sk.int,
+ 'wr_returned_time_sk.int,
+ 'wr_item_sk.int,
+ 'wr_refunded_customer_sk.int,
+ 'wr_refunded_cdemo_sk.int,
+ 'wr_refunded_hdemo_sk.int,
+ 'wr_refunded_addr_sk.int,
+ 'wr_returning_customer_sk.int,
+ 'wr_returning_cdemo_sk.int,
+ 'wr_returning_hdemo_sk.int,
+ 'wr_returning_addr_sk.int,
+ 'wr_web_page_sk.int,
+ 'wr_reason_sk.int,
+ 'wr_order_number.long,
+ 'wr_return_quantity.int,
+ 'wr_return_amt.decimal(7, 2),
+ 'wr_return_tax.decimal(7, 2),
+ 'wr_return_amt_inc_tax.decimal(7, 2),
+ 'wr_fee.decimal(7, 2),
+ 'wr_return_ship_cost.decimal(7, 2),
+ 'wr_refunded_cash.decimal(7, 2),
+ 'wr_reversed_charge.decimal(7, 2),
+ 'wr_account_credit.decimal(7, 2),
+ 'wr_net_loss.decimal(7, 2)
+ ),
+ Table(
+ "call_center",
partitionColumns = Nil,
- 'cc_call_center_sk .int,
- 'cc_call_center_id .string,
- 'cc_rec_start_date .date,
- 'cc_rec_end_date .date,
- 'cc_closed_date_sk .int,
- 'cc_open_date_sk .int,
- 'cc_name .string,
- 'cc_class .string,
- 'cc_employees .int,
- 'cc_sq_ft .int,
- 'cc_hours .string,
- 'cc_manager .string,
- 'cc_mkt_id .int,
- 'cc_mkt_class .string,
- 'cc_mkt_desc .string,
- 'cc_market_manager .string,
- 'cc_division .int,
- 'cc_division_name .string,
- 'cc_company .int,
- 'cc_company_name .string,
- 'cc_street_number .string,
- 'cc_street_name .string,
- 'cc_street_type .string,
- 'cc_suite_number .string,
- 'cc_city .string,
- 'cc_county .string,
- 'cc_state .string,
- 'cc_zip .string,
- 'cc_country .string,
- 'cc_gmt_offset .decimal(5,2),
- 'cc_tax_percentage .decimal(5,2)),
- Table("catalog_page",
+ 'cc_call_center_sk.int,
+ 'cc_call_center_id.string,
+ 'cc_rec_start_date.date,
+ 'cc_rec_end_date.date,
+ 'cc_closed_date_sk.int,
+ 'cc_open_date_sk.int,
+ 'cc_name.string,
+ 'cc_class.string,
+ 'cc_employees.int,
+ 'cc_sq_ft.int,
+ 'cc_hours.string,
+ 'cc_manager.string,
+ 'cc_mkt_id.int,
+ 'cc_mkt_class.string,
+ 'cc_mkt_desc.string,
+ 'cc_market_manager.string,
+ 'cc_division.int,
+ 'cc_division_name.string,
+ 'cc_company.int,
+ 'cc_company_name.string,
+ 'cc_street_number.string,
+ 'cc_street_name.string,
+ 'cc_street_type.string,
+ 'cc_suite_number.string,
+ 'cc_city.string,
+ 'cc_county.string,
+ 'cc_state.string,
+ 'cc_zip.string,
+ 'cc_country.string,
+ 'cc_gmt_offset.decimal(5, 2),
+ 'cc_tax_percentage.decimal(5, 2)
+ ),
+ Table(
+ "catalog_page",
partitionColumns = Nil,
- 'cp_catalog_page_sk .int,
- 'cp_catalog_page_id .string,
- 'cp_start_date_sk .int,
- 'cp_end_date_sk .int,
- 'cp_department .string,
- 'cp_catalog_number .int,
- 'cp_catalog_page_number .int,
- 'cp_description .string,
- 'cp_type .string),
- Table("customer",
+ 'cp_catalog_page_sk.int,
+ 'cp_catalog_page_id.string,
+ 'cp_start_date_sk.int,
+ 'cp_end_date_sk.int,
+ 'cp_department.string,
+ 'cp_catalog_number.int,
+ 'cp_catalog_page_number.int,
+ 'cp_description.string,
+ 'cp_type.string
+ ),
+ Table(
+ "customer",
partitionColumns = Nil,
- 'c_customer_sk .int,
- 'c_customer_id .string,
- 'c_current_cdemo_sk .int,
- 'c_current_hdemo_sk .int,
- 'c_current_addr_sk .int,
- 'c_first_shipto_date_sk .int,
- 'c_first_sales_date_sk .int,
- 'c_salutation .string,
- 'c_first_name .string,
- 'c_last_name .string,
- 'c_preferred_cust_flag .string,
- 'c_birth_day .int,
- 'c_birth_month .int,
- 'c_birth_year .int,
- 'c_birth_country .string,
- 'c_login .string,
- 'c_email_address .string,
- 'c_last_review_date .string),
- Table("customer_address",
+ 'c_customer_sk.int,
+ 'c_customer_id.string,
+ 'c_current_cdemo_sk.int,
+ 'c_current_hdemo_sk.int,
+ 'c_current_addr_sk.int,
+ 'c_first_shipto_date_sk.int,
+ 'c_first_sales_date_sk.int,
+ 'c_salutation.string,
+ 'c_first_name.string,
+ 'c_last_name.string,
+ 'c_preferred_cust_flag.string,
+ 'c_birth_day.int,
+ 'c_birth_month.int,
+ 'c_birth_year.int,
+ 'c_birth_country.string,
+ 'c_login.string,
+ 'c_email_address.string,
+ 'c_last_review_date.string
+ ),
+ Table(
+ "customer_address",
partitionColumns = Nil,
- 'ca_address_sk .int,
- 'ca_address_id .string,
- 'ca_street_number .string,
- 'ca_street_name .string,
- 'ca_street_type .string,
- 'ca_suite_number .string,
- 'ca_city .string,
- 'ca_county .string,
- 'ca_state .string,
- 'ca_zip .string,
- 'ca_country .string,
- 'ca_gmt_offset .decimal(5,2),
- 'ca_location_type .string),
- Table("customer_demographics",
+ 'ca_address_sk.int,
+ 'ca_address_id.string,
+ 'ca_street_number.string,
+ 'ca_street_name.string,
+ 'ca_street_type.string,
+ 'ca_suite_number.string,
+ 'ca_city.string,
+ 'ca_county.string,
+ 'ca_state.string,
+ 'ca_zip.string,
+ 'ca_country.string,
+ 'ca_gmt_offset.decimal(5, 2),
+ 'ca_location_type.string
+ ),
+ Table(
+ "customer_demographics",
partitionColumns = Nil,
- 'cd_demo_sk .int,
- 'cd_gender .string,
- 'cd_marital_status .string,
- 'cd_education_status .string,
- 'cd_purchase_estimate .int,
- 'cd_credit_rating .string,
- 'cd_dep_count .int,
- 'cd_dep_employed_count .int,
- 'cd_dep_college_count .int),
- Table("date_dim",
+ 'cd_demo_sk.int,
+ 'cd_gender.string,
+ 'cd_marital_status.string,
+ 'cd_education_status.string,
+ 'cd_purchase_estimate.int,
+ 'cd_credit_rating.string,
+ 'cd_dep_count.int,
+ 'cd_dep_employed_count.int,
+ 'cd_dep_college_count.int
+ ),
+ Table(
+ "date_dim",
partitionColumns = Nil,
- 'd_date_sk .int,
- 'd_date_id .string,
- 'd_date .date,
- 'd_month_seq .int,
- 'd_week_seq .int,
- 'd_quarter_seq .int,
- 'd_year .int,
- 'd_dow .int,
- 'd_moy .int,
- 'd_dom .int,
- 'd_qoy .int,
- 'd_fy_year .int,
- 'd_fy_quarter_seq .int,
- 'd_fy_week_seq .int,
- 'd_day_name .string,
- 'd_quarter_name .string,
- 'd_holiday .string,
- 'd_weekend .string,
- 'd_following_holiday .string,
- 'd_first_dom .int,
- 'd_last_dom .int,
- 'd_same_day_ly .int,
- 'd_same_day_lq .int,
- 'd_current_day .string,
- 'd_current_week .string,
- 'd_current_month .string,
- 'd_current_quarter .string,
- 'd_current_year .string),
- Table("household_demographics",
+ 'd_date_sk.int,
+ 'd_date_id.string,
+ 'd_date.date,
+ 'd_month_seq.int,
+ 'd_week_seq.int,
+ 'd_quarter_seq.int,
+ 'd_year.int,
+ 'd_dow.int,
+ 'd_moy.int,
+ 'd_dom.int,
+ 'd_qoy.int,
+ 'd_fy_year.int,
+ 'd_fy_quarter_seq.int,
+ 'd_fy_week_seq.int,
+ 'd_day_name.string,
+ 'd_quarter_name.string,
+ 'd_holiday.string,
+ 'd_weekend.string,
+ 'd_following_holiday.string,
+ 'd_first_dom.int,
+ 'd_last_dom.int,
+ 'd_same_day_ly.int,
+ 'd_same_day_lq.int,
+ 'd_current_day.string,
+ 'd_current_week.string,
+ 'd_current_month.string,
+ 'd_current_quarter.string,
+ 'd_current_year.string
+ ),
+ Table(
+ "household_demographics",
partitionColumns = Nil,
- 'hd_demo_sk .int,
- 'hd_income_band_sk .int,
- 'hd_buy_potential .string,
- 'hd_dep_count .int,
- 'hd_vehicle_count .int),
- Table("income_band",
+ 'hd_demo_sk.int,
+ 'hd_income_band_sk.int,
+ 'hd_buy_potential.string,
+ 'hd_dep_count.int,
+ 'hd_vehicle_count.int
+ ),
+ Table(
+ "income_band",
partitionColumns = Nil,
- 'ib_income_band_sk .int,
- 'ib_lower_bound .int,
- 'ib_upper_bound .int),
- Table("item",
+ 'ib_income_band_sk.int,
+ 'ib_lower_bound.int,
+ 'ib_upper_bound.int
+ ),
+ Table(
+ "item",
partitionColumns = Nil,
- 'i_item_sk .int,
- 'i_item_id .string,
- 'i_rec_start_date .date,
- 'i_rec_end_date .date,
- 'i_item_desc .string,
- 'i_current_price .decimal(7,2),
- 'i_wholesale_cost .decimal(7,2),
- 'i_brand_id .int,
- 'i_brand .string,
- 'i_class_id .int,
- 'i_class .string,
- 'i_category_id .int,
- 'i_category .string,
- 'i_manufact_id .int,
- 'i_manufact .string,
- 'i_size .string,
- 'i_formulation .string,
- 'i_color .string,
- 'i_units .string,
- 'i_container .string,
- 'i_manager_id .int,
- 'i_product_name .string),
- Table("promotion",
+ 'i_item_sk.int,
+ 'i_item_id.string,
+ 'i_rec_start_date.date,
+ 'i_rec_end_date.date,
+ 'i_item_desc.string,
+ 'i_current_price.decimal(7, 2),
+ 'i_wholesale_cost.decimal(7, 2),
+ 'i_brand_id.int,
+ 'i_brand.string,
+ 'i_class_id.int,
+ 'i_class.string,
+ 'i_category_id.int,
+ 'i_category.string,
+ 'i_manufact_id.int,
+ 'i_manufact.string,
+ 'i_size.string,
+ 'i_formulation.string,
+ 'i_color.string,
+ 'i_units.string,
+ 'i_container.string,
+ 'i_manager_id.int,
+ 'i_product_name.string
+ ),
+ Table(
+ "promotion",
partitionColumns = Nil,
- 'p_promo_sk .int,
- 'p_promo_id .string,
- 'p_start_date_sk .int,
- 'p_end_date_sk .int,
- 'p_item_sk .int,
- 'p_cost .decimal(15,2),
- 'p_response_target .int,
- 'p_promo_name .string,
- 'p_channel_dmail .string,
- 'p_channel_email .string,
- 'p_channel_catalog .string,
- 'p_channel_tv .string,
- 'p_channel_radio .string,
- 'p_channel_press .string,
- 'p_channel_event .string,
- 'p_channel_demo .string,
- 'p_channel_details .string,
- 'p_purpose .string,
- 'p_discount_active .string),
- Table("reason",
+ 'p_promo_sk.int,
+ 'p_promo_id.string,
+ 'p_start_date_sk.int,
+ 'p_end_date_sk.int,
+ 'p_item_sk.int,
+ 'p_cost.decimal(15, 2),
+ 'p_response_target.int,
+ 'p_promo_name.string,
+ 'p_channel_dmail.string,
+ 'p_channel_email.string,
+ 'p_channel_catalog.string,
+ 'p_channel_tv.string,
+ 'p_channel_radio.string,
+ 'p_channel_press.string,
+ 'p_channel_event.string,
+ 'p_channel_demo.string,
+ 'p_channel_details.string,
+ 'p_purpose.string,
+ 'p_discount_active.string
+ ),
+ Table(
+ "reason",
partitionColumns = Nil,
- 'r_reason_sk .int,
- 'r_reason_id .string,
- 'r_reason_desc .string),
- Table("ship_mode",
+ 'r_reason_sk.int,
+ 'r_reason_id.string,
+ 'r_reason_desc.string
+ ),
+ Table(
+ "ship_mode",
partitionColumns = Nil,
- 'sm_ship_mode_sk .int,
- 'sm_ship_mode_id .string,
- 'sm_type .string,
- 'sm_code .string,
- 'sm_carrier .string,
- 'sm_contract .string),
- Table("store",
+ 'sm_ship_mode_sk.int,
+ 'sm_ship_mode_id.string,
+ 'sm_type.string,
+ 'sm_code.string,
+ 'sm_carrier.string,
+ 'sm_contract.string
+ ),
+ Table(
+ "store",
partitionColumns = Nil,
- 's_store_sk .int,
- 's_store_id .string,
- 's_rec_start_date .date,
- 's_rec_end_date .date,
- 's_closed_date_sk .int,
- 's_store_name .string,
- 's_number_employees .int,
- 's_floor_space .int,
- 's_hours .string,
- 's_manager .string,
- 's_market_id .int,
- 's_geography_class .string,
- 's_market_desc .string,
- 's_market_manager .string,
- 's_division_id .int,
- 's_division_name .string,
- 's_company_id .int,
- 's_company_name .string,
- 's_street_number .string,
- 's_street_name .string,
- 's_street_type .string,
- 's_suite_number .string,
- 's_city .string,
- 's_county .string,
- 's_state .string,
- 's_zip .string,
- 's_country .string,
- 's_gmt_offset .decimal(5,2),
- 's_tax_precentage .decimal(5,2)),
- Table("time_dim",
+ 's_store_sk.int,
+ 's_store_id.string,
+ 's_rec_start_date.date,
+ 's_rec_end_date.date,
+ 's_closed_date_sk.int,
+ 's_store_name.string,
+ 's_number_employees.int,
+ 's_floor_space.int,
+ 's_hours.string,
+ 's_manager.string,
+ 's_market_id.int,
+ 's_geography_class.string,
+ 's_market_desc.string,
+ 's_market_manager.string,
+ 's_division_id.int,
+ 's_division_name.string,
+ 's_company_id.int,
+ 's_company_name.string,
+ 's_street_number.string,
+ 's_street_name.string,
+ 's_street_type.string,
+ 's_suite_number.string,
+ 's_city.string,
+ 's_county.string,
+ 's_state.string,
+ 's_zip.string,
+ 's_country.string,
+ 's_gmt_offset.decimal(5, 2),
+ 's_tax_precentage.decimal(5, 2)
+ ),
+ Table(
+ "time_dim",
partitionColumns = Nil,
- 't_time_sk .int,
- 't_time_id .string,
- 't_time .int,
- 't_hour .int,
- 't_minute .int,
- 't_second .int,
- 't_am_pm .string,
- 't_shift .string,
- 't_sub_shift .string,
- 't_meal_time .string),
- Table("warehouse",
+ 't_time_sk.int,
+ 't_time_id.string,
+ 't_time.int,
+ 't_hour.int,
+ 't_minute.int,
+ 't_second.int,
+ 't_am_pm.string,
+ 't_shift.string,
+ 't_sub_shift.string,
+ 't_meal_time.string
+ ),
+ Table(
+ "warehouse",
partitionColumns = Nil,
- 'w_warehouse_sk .int,
- 'w_warehouse_id .string,
- 'w_warehouse_name .string,
- 'w_warehouse_sq_ft .int,
- 'w_street_number .string,
- 'w_street_name .string,
- 'w_street_type .string,
- 'w_suite_number .string,
- 'w_city .string,
- 'w_county .string,
- 'w_state .string,
- 'w_zip .string,
- 'w_country .string,
- 'w_gmt_offset .decimal(5,2)),
- Table("web_page",
+ 'w_warehouse_sk.int,
+ 'w_warehouse_id.string,
+ 'w_warehouse_name.string,
+ 'w_warehouse_sq_ft.int,
+ 'w_street_number.string,
+ 'w_street_name.string,
+ 'w_street_type.string,
+ 'w_suite_number.string,
+ 'w_city.string,
+ 'w_county.string,
+ 'w_state.string,
+ 'w_zip.string,
+ 'w_country.string,
+ 'w_gmt_offset.decimal(5, 2)
+ ),
+ Table(
+ "web_page",
partitionColumns = Nil,
- 'wp_web_page_sk .int,
- 'wp_web_page_id .string,
- 'wp_rec_start_date .date,
- 'wp_rec_end_date .date,
- 'wp_creation_date_sk .int,
- 'wp_access_date_sk .int,
- 'wp_autogen_flag .string,
- 'wp_customer_sk .int,
- 'wp_url .string,
- 'wp_type .string,
- 'wp_char_count .int,
- 'wp_link_count .int,
- 'wp_image_count .int,
- 'wp_max_ad_count .int),
- Table("web_site",
+ 'wp_web_page_sk.int,
+ 'wp_web_page_id.string,
+ 'wp_rec_start_date.date,
+ 'wp_rec_end_date.date,
+ 'wp_creation_date_sk.int,
+ 'wp_access_date_sk.int,
+ 'wp_autogen_flag.string,
+ 'wp_customer_sk.int,
+ 'wp_url.string,
+ 'wp_type.string,
+ 'wp_char_count.int,
+ 'wp_link_count.int,
+ 'wp_image_count.int,
+ 'wp_max_ad_count.int
+ ),
+ Table(
+ "web_site",
partitionColumns = Nil,
- 'web_site_sk .int,
- 'web_site_id .string,
- 'web_rec_start_date .date,
- 'web_rec_end_date .date,
- 'web_name .string,
- 'web_open_date_sk .int,
- 'web_close_date_sk .int,
- 'web_class .string,
- 'web_manager .string,
- 'web_mkt_id .int,
- 'web_mkt_class .string,
- 'web_mkt_desc .string,
- 'web_market_manager .string,
- 'web_company_id .int,
- 'web_company_name .string,
- 'web_street_number .string,
- 'web_street_name .string,
- 'web_street_type .string,
- 'web_suite_number .string,
- 'web_city .string,
- 'web_county .string,
- 'web_state .string,
- 'web_zip .string,
- 'web_country .string,
- 'web_gmt_offset .decimal(5,2),
- 'web_tax_percentage .decimal(5,2))
+ 'web_site_sk.int,
+ 'web_site_id.string,
+ 'web_rec_start_date.date,
+ 'web_rec_end_date.date,
+ 'web_name.string,
+ 'web_open_date_sk.int,
+ 'web_close_date_sk.int,
+ 'web_class.string,
+ 'web_manager.string,
+ 'web_mkt_id.int,
+ 'web_mkt_class.string,
+ 'web_mkt_desc.string,
+ 'web_market_manager.string,
+ 'web_company_id.int,
+ 'web_company_name.string,
+ 'web_street_number.string,
+ 'web_street_name.string,
+ 'web_street_type.string,
+ 'web_suite_number.string,
+ 'web_city.string,
+ 'web_county.string,
+ 'web_state.string,
+ 'web_zip.string,
+ 'web_country.string,
+ 'web_gmt_offset.decimal(5, 2),
+ 'web_tax_percentage.decimal(5, 2)
+ )
).map(_.convertTypes())
}
diff --git a/src/main/scala/com/databricks/spark/sql/perf/tpcds/TPCDS_1_4_Queries.scala b/src/main/scala/com/databricks/spark/sql/perf/tpcds/TPCDS_1_4_Queries.scala
index 55196787..44ba25b6 100644
--- a/src/main/scala/com/databricks/spark/sql/perf/tpcds/TPCDS_1_4_Queries.scala
+++ b/src/main/scala/com/databricks/spark/sql/perf/tpcds/TPCDS_1_4_Queries.scala
@@ -18,11 +18,9 @@ package com.databricks.spark.sql.perf.tpcds
import com.databricks.spark.sql.perf.{Benchmark, ExecutionMode, Query}
-/**
- * This implements the official TPCDS v1.4 queries with only cosmetic modifications
- * (noted for each query).
- * Don't modify this except for these kind of modifications.
- */
+/** This implements the official TPCDS v1.4 queries with only cosmetic modifications (noted for each
+ * query). Don't modify this except for these kinds of modifications.
+ */
trait Tpcds_1_4_Queries extends Benchmark {
import ExecutionMode._
@@ -33,7 +31,9 @@ trait Tpcds_1_4_Queries extends Benchmark {
// Queries the TPCDS 1.4 queries using the qualifcations values in the templates.
val tpcds1_4Queries = Seq(
- ("q1", """
+ (
+ "q1",
+ """
| WITH customer_total_return AS
| (SELECT sr_customer_sk AS ctr_customer_sk, sr_store_sk AS ctr_store_sk,
| sum(sr_return_amt) AS ctr_total_return
@@ -50,8 +50,11 @@ trait Tpcds_1_4_Queries extends Benchmark {
| AND s_state = 'TN'
| AND ctr1.ctr_customer_sk = c_customer_sk
| ORDER BY c_customer_id LIMIT 100
- """.stripMargin),
- ("q2", """
+ """.stripMargin
+ ),
+ (
+ "q2",
+ """
| WITH wscs as
| (SELECT sold_date_sk, sales_price
| FROM (SELECT ws_sold_date_sk sold_date_sk, ws_ext_sales_price sales_price
@@ -102,8 +105,11 @@ trait Tpcds_1_4_Queries extends Benchmark {
| WHERE date_dim.d_week_seq = wswscs.d_week_seq AND d_year = 2001 + 1) z
| WHERE d_week_seq1=d_week_seq2-53
| ORDER BY d_week_seq1
- """.stripMargin),
- ("q3", """
+ """.stripMargin
+ ),
+ (
+ "q3",
+ """
| SELECT dt.d_year, item.i_brand_id brand_id, item.i_brand brand,SUM(ss_ext_sales_price) sum_agg
| FROM date_dim dt, store_sales, item
| WHERE dt.d_date_sk = store_sales.ss_sold_date_sk
@@ -113,8 +119,11 @@ trait Tpcds_1_4_Queries extends Benchmark {
| GROUP BY dt.d_year, item.i_brand, item.i_brand_id
| ORDER BY dt.d_year, sum_agg desc, brand_id
| LIMIT 100
- """.stripMargin),
- ("q4", """
+ """.stripMargin
+ ),
+ (
+ "q4",
+ """
|WITH year_total AS (
| SELECT c_customer_id customer_id,
| c_first_name customer_first_name,
@@ -221,10 +230,13 @@ trait Tpcds_1_4_Queries extends Benchmark {
| t_s_secyear.customer_login,
| t_s_secyear.customer_email_address
| LIMIT 100
- """.stripMargin),
+ """.stripMargin
+ ),
// Modifications: "+ days" -> date_add
// Modifications: "||" -> concat
- ("q5", """
+ (
+ "q5",
+ """
| WITH ssr AS
| (SELECT s_store_id,
| sum(sales_price) as sales,
@@ -342,8 +354,11 @@ trait Tpcds_1_4_Queries extends Benchmark {
| GROUP BY ROLLUP (channel, id)
| ORDER BY channel, id
| LIMIT 100
- """.stripMargin),
- ("q6", """
+ """.stripMargin
+ ),
+ (
+ "q6",
+ """
| SELECT a.ca_state state, count(*) cnt
| FROM
| customer_address a, customer c, store_sales s, date_dim d, item i
@@ -360,8 +375,11 @@ trait Tpcds_1_4_Queries extends Benchmark {
| GROUP BY a.ca_state
| HAVING count(*) >= 10
| ORDER BY cnt LIMIT 100
- """.stripMargin),
- ("q7", """
+ """.stripMargin
+ ),
+ (
+ "q7",
+ """
| SELECT i_item_id,
| avg(ss_quantity) agg1,
| avg(ss_list_price) agg2,
@@ -379,8 +397,11 @@ trait Tpcds_1_4_Queries extends Benchmark {
| d_year = 2000
| GROUP BY i_item_id
| ORDER BY i_item_id LIMIT 100
- """.stripMargin),
- ("q8", """
+ """.stripMargin
+ ),
+ (
+ "q8",
+ """
| select s_store_name, sum(ss_net_profit)
| from store_sales, date_dim, store,
| (SELECT ca_zip
@@ -462,8 +483,11 @@ trait Tpcds_1_4_Queries extends Benchmark {
| and (substr(s_zip,1,2) = substr(V1.ca_zip,1,2))
| group by s_store_name
| order by s_store_name LIMIT 100
- """.stripMargin),
- ("q9", s"""
+ """.stripMargin
+ ),
+ (
+ "q9",
+ s"""
|select case when (select count(*) from store_sales
| where ss_quantity between 1 and 20) > ${rc(0)}
| then (select avg(ss_ext_discount_amt) from store_sales
@@ -496,8 +520,11 @@ trait Tpcds_1_4_Queries extends Benchmark {
| where ss_quantity between 81 and 100) end bucket5
|from reason
|where r_reason_sk = 1
- """.stripMargin),
- ("q10", """
+ """.stripMargin
+ ),
+ (
+ "q10",
+ """
| select
| cd_gender, cd_marital_status, cd_education_status, count(*) cnt1,
| cd_purchase_estimate, count(*) cnt2, cd_credit_rating, count(*) cnt3,
@@ -542,8 +569,11 @@ trait Tpcds_1_4_Queries extends Benchmark {
| cd_dep_employed_count,
| cd_dep_college_count
|LIMIT 100
- """.stripMargin),
- ("q11", """
+ """.stripMargin
+ ),
+ (
+ "q11",
+ """
| with year_total as (
| select c_customer_id customer_id
| ,c_first_name customer_first_name
@@ -607,9 +637,12 @@ trait Tpcds_1_4_Queries extends Benchmark {
| > case when t_s_firstyear.year_total > 0 then t_s_secyear.year_total / t_s_firstyear.year_total else null end
| order by t_s_secyear.customer_preferred_cust_flag
| LIMIT 100
- """.stripMargin),
+ """.stripMargin
+ ),
// Modifications: "+ days" -> date_add
- ("q12", """
+ (
+ "q12",
+ """
| select
| i_item_desc, i_category, i_class, i_current_price,
| sum(ws_ext_sales_price) as itemrevenue,
@@ -628,8 +661,11 @@ trait Tpcds_1_4_Queries extends Benchmark {
| order by
| i_category, i_class, i_item_id, i_item_desc, revenueratio
| LIMIT 100
- """.stripMargin),
- ("q13", """
+ """.stripMargin
+ ),
+ (
+ "q13",
+ """
| select avg(ss_quantity)
| ,avg(ss_ext_sales_price)
| ,avg(ss_ext_wholesale_cost)
@@ -678,8 +714,11 @@ trait Tpcds_1_4_Queries extends Benchmark {
| and ca_state in ('VA', 'TX', 'MS')
| and ss_net_profit between 50 and 250
| ))
- """.stripMargin),
- ("q14a", """
+ """.stripMargin
+ ),
+ (
+ "q14a",
+ """
|with cross_items as
| (select i_item_sk ss_item_sk
| from item,
@@ -758,8 +797,11 @@ trait Tpcds_1_4_Queries extends Benchmark {
| group by rollup (channel, i_brand_id,i_class_id,i_category_id)
| order by channel,i_brand_id,i_class_id,i_category_id
| limit 100
- """.stripMargin),
- ("q14b", """
+ """.stripMargin
+ ),
+ (
+ "q14b",
+ """
| with cross_items as
| (select i_item_sk ss_item_sk
| from item,
@@ -823,8 +865,11 @@ trait Tpcds_1_4_Queries extends Benchmark {
| and this_year.i_category_id = last_year.i_category_id
| order by this_year.channel, this_year.i_brand_id, this_year.i_class_id, this_year.i_category_id
| limit 100
- """.stripMargin),
- ("q15", """
+ """.stripMargin
+ ),
+ (
+ "q15",
+ """
| select ca_zip, sum(cs_sales_price)
| from catalog_sales, customer, customer_address, date_dim
| where cs_bill_customer_sk = c_customer_sk
@@ -838,9 +883,12 @@ trait Tpcds_1_4_Queries extends Benchmark {
| group by ca_zip
| order by ca_zip
| limit 100
- """.stripMargin),
+ """.stripMargin
+ ),
// Modifications: " -> `
- ("q16", """
+ (
+ "q16",
+ """
| select
| count(distinct cs_order_number) as `order count`,
| sum(cs_ext_ship_cost) as `total shipping cost`,
@@ -863,8 +911,11 @@ trait Tpcds_1_4_Queries extends Benchmark {
| where cs1.cs_order_number = cr1.cr_order_number)
| order by count(distinct cs_order_number)
| limit 100
- """.stripMargin),
- ("q17", """
+ """.stripMargin
+ ),
+ (
+ "q17",
+ """
| select i_item_id
| ,i_item_desc
| ,s_state
@@ -896,9 +947,12 @@ trait Tpcds_1_4_Queries extends Benchmark {
| group by i_item_id, i_item_desc, s_state
| order by i_item_id, i_item_desc, s_state
| limit 100
- """.stripMargin),
+ """.stripMargin
+ ),
// Modifications: "numeric" -> "decimal"
- ("q18", """
+ (
+ "q18",
+ """
| select i_item_id,
| ca_country,
| ca_state,
@@ -926,8 +980,11 @@ trait Tpcds_1_4_Queries extends Benchmark {
| group by rollup (i_item_id, ca_country, ca_state, ca_county)
| order by ca_country, ca_state, ca_county, i_item_id
| LIMIT 100
- """.stripMargin),
- ("q19", """
+ """.stripMargin
+ ),
+ (
+ "q19",
+ """
| select i_brand_id brand_id, i_brand brand, i_manufact_id, i_manufact,
| sum(ss_ext_sales_price) ext_price
| from date_dim, store_sales, item,customer,customer_address,store
@@ -943,8 +1000,11 @@ trait Tpcds_1_4_Queries extends Benchmark {
| group by i_brand, i_brand_id, i_manufact_id, i_manufact
| order by ext_price desc, brand, brand_id, i_manufact_id, i_manufact
| limit 100
- """.stripMargin),
- ("q20", """
+ """.stripMargin
+ ),
+ (
+ "q20",
+ """
|select i_item_desc
| ,i_category
| ,i_class
@@ -961,9 +1021,12 @@ trait Tpcds_1_4_Queries extends Benchmark {
| group by i_item_id, i_item_desc, i_category, i_class, i_current_price
| order by i_category, i_class, i_item_id, i_item_desc, revenueratio
| limit 100
- """.stripMargin),
+ """.stripMargin
+ ),
// Modifications: "+ days" -> date_add
- ("q21", """
+ (
+ "q21",
+ """
| select * from(
| select w_warehouse_name, i_item_id,
| sum(case when (cast(d_date as date) < cast ('2000-03-11' as date))
@@ -986,8 +1049,11 @@ trait Tpcds_1_4_Queries extends Benchmark {
| end) between 2.0/3.0 and 3.0/2.0
| order by w_warehouse_name, i_item_id
| limit 100
- """.stripMargin),
- ("q22", """
+ """.stripMargin
+ ),
+ (
+ "q22",
+ """
| select i_product_name, i_brand, i_class, i_category, avg(inv_quantity_on_hand) qoh
| from inventory, date_dim, item, warehouse
| where inv_date_sk=d_date_sk
@@ -997,8 +1063,11 @@ trait Tpcds_1_4_Queries extends Benchmark {
| group by rollup(i_product_name, i_brand, i_class, i_category)
| order by qoh, i_product_name, i_brand, i_class, i_category
| limit 100
- """.stripMargin),
- ("q23a", """
+ """.stripMargin
+ ),
+ (
+ "q23a",
+ """
| with frequent_ss_items as
| (select substr(i_item_desc,1,30) itemdesc,i_item_sk item_sk,d_date solddate,count(*) cnt
| from store_sales, date_dim, item
@@ -1039,8 +1108,11 @@ trait Tpcds_1_4_Queries extends Benchmark {
| and ws_item_sk in (select item_sk from frequent_ss_items)
| and ws_bill_customer_sk in (select c_customer_sk from best_ss_customer))) y
| limit 100
- """.stripMargin),
- ("q23b", """
+ """.stripMargin
+ ),
+ (
+ "q23b",
+ """
|
| with frequent_ss_items as
| (select substr(i_item_desc,1,30) itemdesc,i_item_sk item_sk,d_date solddate,count(*) cnt
@@ -1088,8 +1160,11 @@ trait Tpcds_1_4_Queries extends Benchmark {
| group by c_last_name,c_first_name)) y
| order by c_last_name,c_first_name,sales
| limit 100
- """.stripMargin),
- ("q24a", """
+ """.stripMargin
+ ),
+ (
+ "q24a",
+ """
| with ssales as
| (select c_last_name, c_first_name, s_store_name, ca_state, s_state, i_color,
| i_current_price, i_manager_id, i_units, i_size, sum(ss_net_paid) netpaid
@@ -1109,8 +1184,11 @@ trait Tpcds_1_4_Queries extends Benchmark {
| where i_color = 'pale'
| group by c_last_name, c_first_name, s_store_name
| having sum(netpaid) > (select 0.05*avg(netpaid) from ssales)
- """.stripMargin),
- ("q24b", """
+ """.stripMargin
+ ),
+ (
+ "q24b",
+ """
| with ssales as
| (select c_last_name, c_first_name, s_store_name, ca_state, s_state, i_color,
| i_current_price, i_manager_id, i_units, i_size, sum(ss_net_paid) netpaid
@@ -1130,8 +1208,11 @@ trait Tpcds_1_4_Queries extends Benchmark {
| where i_color = 'chiffon'
| group by c_last_name, c_first_name, s_store_name
| having sum(netpaid) > (select 0.05*avg(netpaid) from ssales)
- """.stripMargin),
- ("q25", """
+ """.stripMargin
+ ),
+ (
+ "q25",
+ """
| select i_item_id, i_item_desc, s_store_id, s_store_name,
| sum(ss_net_profit) as store_sales_profit,
| sum(sr_net_loss) as store_returns_loss,
@@ -1161,8 +1242,11 @@ trait Tpcds_1_4_Queries extends Benchmark {
| order by
| i_item_id, i_item_desc, s_store_id, s_store_name
| limit 100
- """.stripMargin),
- ("q26", """
+ """.stripMargin
+ ),
+ (
+ "q26",
+ """
| select i_item_id,
| avg(cs_quantity) agg1,
| avg(cs_list_price) agg2,
@@ -1181,8 +1265,11 @@ trait Tpcds_1_4_Queries extends Benchmark {
| group by i_item_id
| order by i_item_id
| limit 100
- """.stripMargin),
- ("q27", """
+ """.stripMargin
+ ),
+ (
+ "q27",
+ """
| select i_item_id,
| s_state, grouping(s_state) g_state,
| avg(ss_quantity) agg1,
@@ -1202,8 +1289,11 @@ trait Tpcds_1_4_Queries extends Benchmark {
| group by rollup (i_item_id, s_state)
| order by i_item_id, s_state
| limit 100
- """.stripMargin),
- ("q28", """
+ """.stripMargin
+ ),
+ (
+ "q28",
+ """
| select *
| from (select avg(ss_list_price) B1_LP
| ,count(ss_list_price) B1_CNT
@@ -1254,8 +1344,11 @@ trait Tpcds_1_4_Queries extends Benchmark {
| or ss_coupon_amt between 7326 and 7326+1000
| or ss_wholesale_cost between 7 and 7+20)) B6
| limit 100
- """.stripMargin),
- ("q29", """
+ """.stripMargin
+ ),
+ (
+ "q29",
+ """
| select
| i_item_id
| ,i_item_desc
@@ -1288,8 +1381,11 @@ trait Tpcds_1_4_Queries extends Benchmark {
| order by
| i_item_id, i_item_desc, s_store_id, s_store_name
| limit 100
- """.stripMargin),
- ("q30", """
+ """.stripMargin
+ ),
+ (
+ "q30",
+ """
| with customer_total_return as
| (select wr_returning_customer_sk as ctr_customer_sk
| ,ca_state as ctr_state,
@@ -1313,8 +1409,11 @@ trait Tpcds_1_4_Queries extends Benchmark {
| ,c_birth_day,c_birth_month,c_birth_year,c_birth_country,c_login,c_email_address
| ,c_last_review_date,ctr_total_return
| limit 100
- """.stripMargin),
- ("q31", """
+ """.stripMargin
+ ),
+ (
+ "q31",
+ """
| with ss as
| (select ca_county,d_qoy, d_year,sum(ss_ext_sales_price) as store_sales
| from store_sales,date_dim,customer_address
@@ -1359,9 +1458,12 @@ trait Tpcds_1_4_Queries extends Benchmark {
| and case when ws2.web_sales > 0 then ws3.web_sales/ws2.web_sales else null end
| > case when ss2.store_sales > 0 then ss3.store_sales/ss2.store_sales else null end
| order by ss1.ca_county
- """.stripMargin),
+ """.stripMargin
+ ),
// Modifications: " -> `
- ("q32", """
+ (
+ "q32",
+ """
| select sum(cs_ext_discount_amt) as `excess discount amount`
| from
| catalog_sales, item, date_dim
@@ -1377,8 +1479,11 @@ trait Tpcds_1_4_Queries extends Benchmark {
| and d_date between '2000-01-27]' and (cast('2000-01-27' as date) + interval 90 days)
| and d_date_sk = cs_sold_date_sk)
|limit 100
- """.stripMargin),
- ("q33", """
+ """.stripMargin
+ ),
+ (
+ "q33",
+ """
| with ss as (
| select
| i_manufact_id,sum(ss_ext_sales_price) total_sales
@@ -1432,8 +1537,11 @@ trait Tpcds_1_4_Queries extends Benchmark {
| group by i_manufact_id
| order by total_sales
|limit 100
- """.stripMargin),
- ("q34", """
+ """.stripMargin
+ ),
+ (
+ "q34",
+ """
| select c_last_name, c_first_name, c_salutation, c_preferred_cust_flag, ss_ticket_number,
| cnt
| FROM
@@ -1457,8 +1565,11 @@ trait Tpcds_1_4_Queries extends Benchmark {
| where ss_customer_sk = c_customer_sk
| and cnt between 15 and 20
| order by c_last_name,c_first_name,c_salutation,c_preferred_cust_flag desc
- """.stripMargin),
- ("q35", """
+ """.stripMargin
+ ),
+ (
+ "q35",
+ """
| select
| ca_state,
| cd_gender,
@@ -1502,8 +1613,11 @@ trait Tpcds_1_4_Queries extends Benchmark {
| order by ca_state, cd_gender, cd_marital_status, cd_dep_count,
| cd_dep_employed_count, cd_dep_college_count
| limit 100
- """.stripMargin),
- ("q36", """
+ """.stripMargin
+ ),
+ (
+ "q36",
+ """
| select
| sum(ss_net_profit)/sum(ss_ext_sales_price) as gross_margin
| ,i_category
@@ -1527,9 +1641,12 @@ trait Tpcds_1_4_Queries extends Benchmark {
| ,case when lochierarchy = 0 then i_category end
| ,rank_within_parent
| limit 100
- """.stripMargin),
+ """.stripMargin
+ ),
// Modifications: "+ days" -> date_add
- ("q37", """
+ (
+ "q37",
+ """
| select i_item_id, i_item_desc, i_current_price
| from item, inventory, date_dim, catalog_sales
| where i_current_price between 68 and 68 + 30
@@ -1542,8 +1659,11 @@ trait Tpcds_1_4_Queries extends Benchmark {
| group by i_item_id,i_item_desc,i_current_price
| order by i_item_id
| limit 100
- """.stripMargin),
- ("q38", """
+ """.stripMargin
+ ),
+ (
+ "q38",
+ """
| select count(*) from (
| select distinct c_last_name, c_first_name, d_date
| from store_sales, date_dim, customer
@@ -1564,8 +1684,11 @@ trait Tpcds_1_4_Queries extends Benchmark {
| and d_month_seq between 1200 and 1200 + 11
| ) hot_cust
| limit 100
- """.stripMargin),
- ("q39a", """
+ """.stripMargin
+ ),
+ (
+ "q39a",
+ """
| with inv as
| (select w_warehouse_name,w_warehouse_sk,i_item_sk,d_moy
| ,stdev,mean, case mean when 0 then null else stdev/mean end cov
@@ -1587,8 +1710,11 @@ trait Tpcds_1_4_Queries extends Benchmark {
| and inv2.d_moy=1+1
| order by inv1.w_warehouse_sk,inv1.i_item_sk,inv1.d_moy,inv1.mean,inv1.cov
| ,inv2.d_moy,inv2.mean, inv2.cov
- """.stripMargin),
- ("q39b", """
+ """.stripMargin
+ ),
+ (
+ "q39b",
+ """
| with inv as
| (select w_warehouse_name,w_warehouse_sk,i_item_sk,d_moy
| ,stdev,mean, case mean when 0 then null else stdev/mean end cov
@@ -1611,9 +1737,12 @@ trait Tpcds_1_4_Queries extends Benchmark {
| and inv1.cov > 1.5
| order by inv1.w_warehouse_sk,inv1.i_item_sk,inv1.d_moy,inv1.mean,inv1.cov
| ,inv2.d_moy,inv2.mean, inv2.cov
- """.stripMargin),
+ """.stripMargin
+ ),
// Modifications: "+ days" -> date_add
- ("q40", """
+ (
+ "q40",
+ """
| select
| w_state
| ,i_item_id
@@ -1636,8 +1765,11 @@ trait Tpcds_1_4_Queries extends Benchmark {
| group by w_state,i_item_id
| order by w_state,i_item_id
| limit 100
- """.stripMargin),
- ("q41", """
+ """.stripMargin
+ ),
+ (
+ "q41",
+ """
| select distinct(i_product_name)
| from item i1
| where i_manufact_id between 738 and 738+40
@@ -1687,8 +1819,11 @@ trait Tpcds_1_4_Queries extends Benchmark {
| )))) > 0
| order by i_product_name
| limit 100
- """.stripMargin),
- ("q42", """
+ """.stripMargin
+ ),
+ (
+ "q42",
+ """
| select dt.d_year, item.i_category_id, item.i_category, sum(ss_ext_sales_price)
| from date_dim dt, store_sales, item
| where dt.d_date_sk = store_sales.ss_sold_date_sk
@@ -1703,8 +1838,11 @@ trait Tpcds_1_4_Queries extends Benchmark {
| ,item.i_category_id
| ,item.i_category
| limit 100
- """.stripMargin),
- ("q43", """
+ """.stripMargin
+ ),
+ (
+ "q43",
+ """
| select s_store_name, s_store_id,
| sum(case when (d_day_name='Sunday') then ss_sales_price else null end) sun_sales,
| sum(case when (d_day_name='Monday') then ss_sales_price else null end) mon_sales,
@@ -1722,8 +1860,11 @@ trait Tpcds_1_4_Queries extends Benchmark {
| order by s_store_name, s_store_id,sun_sales,mon_sales,tue_sales,wed_sales,
| thu_sales,fri_sales,sat_sales
| limit 100
- """.stripMargin),
- ("q44", """
+ """.stripMargin
+ ),
+ (
+ "q44",
+ """
| select asceding.rnk, i1.i_product_name best_performing, i2.i_product_name worst_performing
| from(select *
| from (select item_sk,rank() over (order by rank_col asc) rnk
@@ -1755,8 +1896,11 @@ trait Tpcds_1_4_Queries extends Benchmark {
| and i2.i_item_sk=descending.item_sk
| order by asceding.rnk
| limit 100
- """.stripMargin),
- ("q45", """
+ """.stripMargin
+ ),
+ (
+ "q45",
+ """
| select ca_zip, ca_city, sum(ws_sales_price)
| from web_sales, customer, customer_address, date_dim, item
| where ws_bill_customer_sk = c_customer_sk
@@ -1774,8 +1918,11 @@ trait Tpcds_1_4_Queries extends Benchmark {
| group by ca_zip, ca_city
| order by ca_zip, ca_city
| limit 100
- """.stripMargin),
- ("q46", """
+ """.stripMargin
+ ),
+ (
+ "q46",
+ """
| select c_last_name, c_first_name, ca_city, bought_city, ss_ticket_number, amt,profit
| from
| (select ss_ticket_number
@@ -1799,8 +1946,11 @@ trait Tpcds_1_4_Queries extends Benchmark {
| and current_addr.ca_city <> bought_city
| order by c_last_name, c_first_name, ca_city, bought_city, ss_ticket_number
| limit 100
- """.stripMargin),
- ("q47", """
+ """.stripMargin
+ ),
+ (
+ "q47",
+ """
| with v1 as(
| select i_category, i_brand,
| s_store_name, s_company_name,
@@ -1847,8 +1997,11 @@ trait Tpcds_1_4_Queries extends Benchmark {
| case when avg_monthly_sales > 0 then abs(sum_sales - avg_monthly_sales) / avg_monthly_sales else null end > 0.1
| order by sum_sales - avg_monthly_sales, 3
| limit 100
- """.stripMargin),
- ("q48", """
+ """.stripMargin
+ ),
+ (
+ "q48",
+ """
| select sum (ss_quantity)
| from store_sales, store, customer_demographics, customer_address, date_dim
| where s_store_sk = ss_store_sk
@@ -1912,9 +2065,12 @@ trait Tpcds_1_4_Queries extends Benchmark {
| and ss_net_profit between 50 and 25000
| )
| )
- """.stripMargin),
+ """.stripMargin
+ ),
// Modifications: "dec" -> "decimal"
- ("q49", """
+ (
+ "q49",
+ """
| select 'web' as channel, web.item, web.return_ratio, web.return_rank, web.currency_rank
| from (
| select
@@ -2010,9 +2166,12 @@ trait Tpcds_1_4_Queries extends Benchmark {
| where (store.return_rank <= 10 or store.currency_rank <= 10)
| order by 1,4,5
| limit 100
- """.stripMargin),
+ """.stripMargin
+ ),
// Modifications: " -> `
- ("q50", """
+ (
+ "q50",
+ """
| select
| s_store_name, s_company_id, s_street_number, s_street_name, s_street_type,
| s_suite_number, s_city, s_county, s_state, s_zip
@@ -2042,8 +2201,11 @@ trait Tpcds_1_4_Queries extends Benchmark {
| s_store_name, s_company_id, s_street_number, s_street_name, s_street_type,
| s_suite_number, s_city, s_county, s_state, s_zip
| limit 100
- """.stripMargin),
- ("q51", """
+ """.stripMargin
+ ),
+ (
+ "q51",
+ """
| WITH web_v1 as (
| select
| ws_item_sk item_sk, d_date,
@@ -2080,8 +2242,11 @@ trait Tpcds_1_4_Queries extends Benchmark {
| where web_cumulative > store_cumulative
| order by item_sk, d_date
| limit 100
- """.stripMargin),
- ("q52", """
+ """.stripMargin
+ ),
+ (
+ "q52",
+ """
| select dt.d_year
| ,item.i_brand_id brand_id
| ,item.i_brand brand
@@ -2095,8 +2260,11 @@ trait Tpcds_1_4_Queries extends Benchmark {
| group by dt.d_year, item.i_brand, item.i_brand_id
| order by dt.d_year, ext_price desc, brand_id
|limit 100
- """.stripMargin),
- ("q53", """
+ """.stripMargin
+ ),
+ (
+ "q53",
+ """
| select * from
| (select i_manufact_id,
| sum(ss_sales_price) sum_sales,
@@ -2124,8 +2292,11 @@ trait Tpcds_1_4_Queries extends Benchmark {
| sum_sales,
| i_manufact_id
| limit 100
- """.stripMargin),
- ("q54", """
+ """.stripMargin
+ ),
+ (
+ "q54",
+ """
| with my_customers as (
| select distinct c_customer_sk
| , c_current_addr_sk
@@ -2177,8 +2348,11 @@ trait Tpcds_1_4_Queries extends Benchmark {
| group by segment
| order by segment, num_customers
| limit 100
- """.stripMargin),
- ("q55", """
+ """.stripMargin
+ ),
+ (
+ "q55",
+ """
|select i_brand_id brand_id, i_brand brand,
| sum(ss_ext_sales_price) ext_price
| from date_dim, store_sales, item
@@ -2190,8 +2364,11 @@ trait Tpcds_1_4_Queries extends Benchmark {
| group by i_brand, i_brand_id
| order by ext_price desc, brand_id
| limit 100
- """.stripMargin),
- ("q56", """
+ """.stripMargin
+ ),
+ (
+ "q56",
+ """
| with ss as (
| select i_item_id,sum(ss_ext_sales_price) total_sales
| from
@@ -2240,8 +2417,11 @@ trait Tpcds_1_4_Queries extends Benchmark {
| group by i_item_id
| order by total_sales
| limit 100
- """.stripMargin),
- ("q57", """
+ """.stripMargin
+ ),
+ (
+ "q57",
+ """
| with v1 as(
| select i_category, i_brand,
| cc_name,
@@ -2283,8 +2463,11 @@ trait Tpcds_1_4_Queries extends Benchmark {
| case when avg_monthly_sales > 0 then abs(sum_sales - avg_monthly_sales) / avg_monthly_sales else null end > 0.1
| order by sum_sales - avg_monthly_sales, 3
| limit 100
- """.stripMargin),
- ("q58", """
+ """.stripMargin
+ ),
+ (
+ "q58",
+ """
| with ss_items as
| (select i_item_id item_id, sum(ss_ext_sales_price) ss_item_rev
| from store_sales, item, date_dim
@@ -2338,8 +2521,11 @@ trait Tpcds_1_4_Queries extends Benchmark {
| and ws_item_rev between 0.9 * cs_item_rev and 1.1 * cs_item_rev
| order by item_id, ss_item_rev
| limit 100
- """.stripMargin),
- ("q59", """
+ """.stripMargin
+ ),
+ (
+ "q59",
+ """
| with wss as
| (select d_week_seq,
| ss_store_sk,
@@ -2381,8 +2567,11 @@ trait Tpcds_1_4_Queries extends Benchmark {
| and d_week_seq1=d_week_seq2-52
| order by s_store_name1,s_store_id1,d_week_seq1
| limit 100
- """.stripMargin),
- ("q60", """
+ """.stripMargin
+ ),
+ (
+ "q60",
+ """
| with ss as (
| select i_item_id,sum(ss_ext_sales_price) total_sales
| from store_sales, date_dim, customer_address, item
@@ -2428,8 +2617,11 @@ trait Tpcds_1_4_Queries extends Benchmark {
| group by i_item_id
| order by i_item_id, total_sales
| limit 100
- """.stripMargin),
- ("q61", s"""
+ """.stripMargin
+ ),
+ (
+ "q61",
+ s"""
| select promotions,total,cast(promotions as decimal(15,4))/cast(total as decimal(15,4))*100
| from
| (select sum(ss_ext_sales_price) promotions
@@ -2460,9 +2652,12 @@ trait Tpcds_1_4_Queries extends Benchmark {
| and d_moy = 11) all_sales
| order by promotions, total
| limit 100
- """.stripMargin),
+ """.stripMargin
+ ),
// Modifications: " -> `
- ("q62", """
+ (
+ "q62",
+ """
| select
| substr(w_warehouse_name,1,20)
| ,sm_type
@@ -2488,8 +2683,11 @@ trait Tpcds_1_4_Queries extends Benchmark {
| order by
| substr(w_warehouse_name,1,20), sm_type, web_name
| limit 100
- """.stripMargin),
- ("q63", """
+ """.stripMargin
+ ),
+ (
+ "q63",
+ """
| select *
| from (select i_manager_id
| ,sum(ss_sales_price) sum_sales
@@ -2517,8 +2715,11 @@ trait Tpcds_1_4_Queries extends Benchmark {
| ,avg_monthly_sales
| ,sum_sales
| limit 100
- """.stripMargin),
- ("q64", """
+ """.stripMargin
+ ),
+ (
+ "q64",
+ """
| with cs_ui as
| (select cs_item_sk
| ,sum(cs_ext_list_price) as sale,sum(cr_refunded_cash+cr_reversed_charge+cr_store_credit) as refund
@@ -2576,8 +2777,11 @@ trait Tpcds_1_4_Queries extends Benchmark {
| cs1.store_name = cs2.store_name and
| cs1.store_zip = cs2.store_zip
| order by cs1.product_name, cs1.store_name, cs2.cnt
- """.stripMargin),
- ("q65", """
+ """.stripMargin
+ ),
+ (
+ "q65",
+ """
| select
| s_store_name, i_item_desc, sc.revenue, i_current_price, i_wholesale_cost, i_brand
| from store, item,
@@ -2599,9 +2803,12 @@ trait Tpcds_1_4_Queries extends Benchmark {
| i_item_sk = sc.ss_item_sk
| order by s_store_name, i_item_desc
| limit 100
- """.stripMargin),
+ """.stripMargin
+ ),
// Modifications: "||" -> concat
- ("q66", """
+ (
+ "q66",
+ """
| select w_warehouse_name, w_warehouse_sq_ft, w_city, w_county, w_state, w_country,
| ship_carriers, year
| ,sum(jan_sales) as jan_sales
@@ -2728,8 +2935,11 @@ trait Tpcds_1_4_Queries extends Benchmark {
| ship_carriers, year
| order by w_warehouse_name
| limit 100
- """.stripMargin),
- ("q67", """
+ """.stripMargin
+ ),
+ (
+ "q67",
+ """
| select * from
| (select i_category, i_class, i_brand, i_product_name, d_year, d_qoy, d_moy, s_store_id,
| sumsales, rank() over (partition by i_category order by sumsales desc) rk
@@ -2748,8 +2958,11 @@ trait Tpcds_1_4_Queries extends Benchmark {
| i_category, i_class, i_brand, i_product_name, d_year,
| d_qoy, d_moy, s_store_id, sumsales, rk
| limit 100
- """.stripMargin),
- ("q68", """
+ """.stripMargin
+ ),
+ (
+ "q68",
+ """
| select
| c_last_name, c_first_name, ca_city, bought_city, ss_ticket_number, extended_price,
| extended_tax, list_price
@@ -2776,8 +2989,11 @@ trait Tpcds_1_4_Queries extends Benchmark {
| and current_addr.ca_city <> bought_city
| order by c_last_name, ss_ticket_number
| limit 100
- """.stripMargin),
- ("q69", """
+ """.stripMargin
+ ),
+ (
+ "q69",
+ """
| select
| cd_gender, cd_marital_status, cd_education_status, count(*) cnt1,
| cd_purchase_estimate, count(*) cnt2, cd_credit_rating, count(*) cnt3
@@ -2807,8 +3023,11 @@ trait Tpcds_1_4_Queries extends Benchmark {
| order by cd_gender, cd_marital_status, cd_education_status,
| cd_purchase_estimate, cd_credit_rating
| limit 100
- """.stripMargin),
- ("q70", """
+ """.stripMargin
+ ),
+ (
+ "q70",
+ """
| select
| sum(ss_net_profit) as total_sum, s_state, s_county
| ,grouping(s_state)+grouping(s_county) as lochierarchy
@@ -2838,8 +3057,11 @@ trait Tpcds_1_4_Queries extends Benchmark {
| ,case when lochierarchy = 0 then s_state end
| ,rank_within_parent
| limit 100
- """.stripMargin),
- ("q71", """
+ """.stripMargin
+ ),
+ (
+ "q71",
+ """
| select i_brand_id brand_id, i_brand brand,t_hour,t_minute,
| sum(ext_price) ext_price
| from item,
@@ -2880,9 +3102,12 @@ trait Tpcds_1_4_Queries extends Benchmark {
| and (t_meal_time = 'breakfast' or t_meal_time = 'dinner')
| group by i_brand, i_brand_id,t_hour,t_minute
| order by ext_price desc, brand_id
- """.stripMargin),
+ """.stripMargin
+ ),
// Modifications: "+ days" -> date_add
- ("q72", """
+ (
+ "q72",
+ """
| select i_item_desc
| ,w_warehouse_name
| ,d1.d_week_seq
@@ -2911,8 +3136,11 @@ trait Tpcds_1_4_Queries extends Benchmark {
| group by i_item_desc,w_warehouse_name,d1.d_week_seq
| order by total_cnt desc, i_item_desc, w_warehouse_name, d_week_seq
| limit 100
- """.stripMargin),
- ("q73", """
+ """.stripMargin
+ ),
+ (
+ "q73",
+ """
| select
| c_last_name, c_first_name, c_salutation, c_preferred_cust_flag,
| ss_ticket_number, cnt from
@@ -2933,8 +3161,11 @@ trait Tpcds_1_4_Queries extends Benchmark {
| where ss_customer_sk = c_customer_sk
| and cnt between 1 and 5
| order by cnt desc
- """.stripMargin),
- ("q74", """
+ """.stripMargin
+ ),
+ (
+ "q74",
+ """
| with year_total as (
| select
| c_customer_id customer_id, c_first_name customer_first_name,
@@ -2981,8 +3212,11 @@ trait Tpcds_1_4_Queries extends Benchmark {
| > case when t_s_firstyear.year_total > 0 then t_s_secyear.year_total / t_s_firstyear.year_total else null end
| order by 1, 1, 1
| limit 100
- """.stripMargin),
- ("q75", """
+ """.stripMargin
+ ),
+ (
+ "q75",
+ """
| WITH all_sales AS (
| SELECT
| d_year, i_brand_id, i_class_id, i_category_id, i_manufact_id,
@@ -3037,8 +3271,11 @@ trait Tpcds_1_4_Queries extends Benchmark {
| AND CAST(curr_yr.sales_cnt AS DECIMAL(17,2))/CAST(prev_yr.sales_cnt AS DECIMAL(17,2))<0.9
| ORDER BY sales_cnt_diff
| LIMIT 100
- """.stripMargin),
- ("q76", """
+ """.stripMargin
+ ),
+ (
+ "q76",
+ """
| SELECT
| channel, col_name, d_year, d_qoy, i_category, COUNT(*) sales_cnt,
| SUM(ext_sales_price) sales_amt
@@ -3069,9 +3306,12 @@ trait Tpcds_1_4_Queries extends Benchmark {
| GROUP BY channel, col_name, d_year, d_qoy, i_category
| ORDER BY channel, col_name, d_year, d_qoy, i_category
| limit 100
- """.stripMargin),
+ """.stripMargin
+ ),
// Modifications: "+ days" -> date_add
- ("q77", """
+ (
+ "q77",
+ """
| with ss as
| (select s_store_sk, sum(ss_ext_sales_price) as sales, sum(ss_net_profit) as profit
| from store_sales, date_dim, store
@@ -3139,8 +3379,11 @@ trait Tpcds_1_4_Queries extends Benchmark {
| group by rollup(channel, id)
| order by channel, id
| limit 100
- """.stripMargin),
- ("q78", """
+ """.stripMargin
+ ),
+ (
+ "q78",
+ """
| with ws as
| (select d_year AS ws_sold_year, ws_item_sk,
| ws_bill_customer_sk ws_customer_sk,
@@ -3195,8 +3438,11 @@ trait Tpcds_1_4_Queries extends Benchmark {
| other_chan_sales_price,
| round(ss_qty/(coalesce(ws_qty+cs_qty,1)),2)
| limit 100
- """.stripMargin),
- ("q79", """
+ """.stripMargin
+ ),
+ (
+ "q79",
+ """
| select
| c_last_name,c_first_name,substr(s_city,1,30),ss_ticket_number,amt,profit
| from
@@ -3218,10 +3464,13 @@ trait Tpcds_1_4_Queries extends Benchmark {
| where ss_customer_sk = c_customer_sk
| order by c_last_name,c_first_name,substr(s_city,1,30), profit
| limit 100
- """.stripMargin),
+ """.stripMargin
+ ),
// Modifications: "+ days" -> date_add
// Modifications: "||" -> "concat"
- ("q80", """
+ (
+ "q80",
+ """
| with ssr as
| (select s_store_id as store_id,
| sum(ss_ext_sales_price) as sales,
@@ -3289,8 +3538,11 @@ trait Tpcds_1_4_Queries extends Benchmark {
| group by rollup (channel, id)
| order by channel, id
| limit 100
- """.stripMargin),
- ("q81", """
+ """.stripMargin
+ ),
+ (
+ "q81",
+ """
| with customer_total_return as
| (select
| cr_returning_customer_sk as ctr_customer_sk, ca_state as ctr_state,
@@ -3315,8 +3567,11 @@ trait Tpcds_1_4_Queries extends Benchmark {
| ,ca_street_type,ca_suite_number,ca_city,ca_county,ca_state,ca_zip,ca_country,ca_gmt_offset
| ,ca_location_type,ctr_total_return
| limit 100
- """.stripMargin),
- ("q82", """
+ """.stripMargin
+ ),
+ (
+ "q82",
+ """
| select i_item_id, i_item_desc, i_current_price
| from item, inventory, date_dim, store_sales
| where i_current_price between 62 and 62+30
@@ -3329,8 +3584,11 @@ trait Tpcds_1_4_Queries extends Benchmark {
| group by i_item_id,i_item_desc,i_current_price
| order by i_item_id
| limit 100
- """.stripMargin),
- ("q83", """
+ """.stripMargin
+ ),
+ (
+ "q83",
+ """
| with sr_items as
| (select i_item_id item_id, sum(sr_return_quantity) sr_item_qty
| from store_returns, item, date_dim
@@ -3368,9 +3626,12 @@ trait Tpcds_1_4_Queries extends Benchmark {
| and sr_items.item_id=wr_items.item_id
| order by sr_items.item_id, sr_item_qty
| limit 100
- """.stripMargin),
+ """.stripMargin
+ ),
// Modifications: "||" -> concat
- ("q84", """
+ (
+ "q84",
+ """
| select c_customer_id as customer_id
| ,concat(c_last_name, ', ', c_first_name) as customername
| from customer
@@ -3389,8 +3650,11 @@ trait Tpcds_1_4_Queries extends Benchmark {
| and sr_cdemo_sk = cd_demo_sk
| order by c_customer_id
| limit 100
- """.stripMargin),
- ("q85", """
+ """.stripMargin
+ ),
+ (
+ "q85",
+ """
| select
| substr(r_reason_desc,1,20), avg(ws_quantity), avg(wr_refunded_cash), avg(wr_fee)
| from web_sales, web_returns, web_page, customer_demographics cd1,
@@ -3470,8 +3734,11 @@ trait Tpcds_1_4_Queries extends Benchmark {
| ,avg(wr_refunded_cash)
| ,avg(wr_fee)
| limit 100
- """.stripMargin),
- ("q86", """
+ """.stripMargin
+ ),
+ (
+ "q86",
+ """
| select sum(ws_net_paid) as total_sum, i_category, i_class,
| grouping(i_category)+grouping(i_class) as lochierarchy,
| rank() over (
@@ -3490,8 +3757,11 @@ trait Tpcds_1_4_Queries extends Benchmark {
| case when lochierarchy = 0 then i_category end,
| rank_within_parent
| limit 100
- """.stripMargin),
- ("q87", """
+ """.stripMargin
+ ),
+ (
+ "q87",
+ """
| select count(*)
| from ((select distinct c_last_name, c_first_name, d_date
| from store_sales, date_dim, customer
@@ -3511,8 +3781,11 @@ trait Tpcds_1_4_Queries extends Benchmark {
| and web_sales.ws_bill_customer_sk = customer.c_customer_sk
| and d_month_seq between 1200 and 1200+11)
|) cool_cust
- """.stripMargin),
- ("q88", """
+ """.stripMargin
+ ),
+ (
+ "q88",
+ """
| select *
| from
| (select count(*) h8_30_to_9
@@ -3603,8 +3876,11 @@ trait Tpcds_1_4_Queries extends Benchmark {
| (household_demographics.hd_dep_count = 2 and household_demographics.hd_vehicle_count<=2+2) or
| (household_demographics.hd_dep_count = 0 and household_demographics.hd_vehicle_count<=0+2))
| and store.s_store_name = 'ese') s8
- """.stripMargin),
- ("q89", """
+ """.stripMargin
+ ),
+ (
+ "q89",
+ """
| select *
| from(
| select i_category, i_class, i_brand,
@@ -3628,8 +3904,11 @@ trait Tpcds_1_4_Queries extends Benchmark {
| where case when (avg_monthly_sales <> 0) then (abs(sum_sales - avg_monthly_sales) / avg_monthly_sales) else null end > 0.1
| order by sum_sales - avg_monthly_sales, s_store_name
| limit 100
- """.stripMargin),
- ("q90", """
+ """.stripMargin
+ ),
+ (
+ "q90",
+ """
| select cast(amc as decimal(15,4))/cast(pmc as decimal(15,4)) am_pm_ratio
| from ( select count(*) amc
| from web_sales, household_demographics , time_dim, web_page
@@ -3649,8 +3928,11 @@ trait Tpcds_1_4_Queries extends Benchmark {
| and web_page.wp_char_count between 5000 and 5200) pt
| order by am_pm_ratio
| limit 100
- """.stripMargin),
- ("q91", """
+ """.stripMargin
+ ),
+ (
+ "q91",
+ """
| select
| cc_call_center_id Call_Center, cc_name Call_Center_Name, cc_manager Manager,
| sum(cr_net_loss) Returns_Loss
@@ -3672,10 +3954,13 @@ trait Tpcds_1_4_Queries extends Benchmark {
| and ca_gmt_offset = -7
| group by cc_call_center_id,cc_name,cc_manager,cd_marital_status,cd_education_status
| order by sum(cr_net_loss) desc
- """.stripMargin),
+ """.stripMargin
+ ),
// Modifications: "+ days" -> date_add
// Modifications: " -> `
- ("q92", """
+ (
+ "q92",
+ """
| select sum(ws_ext_discount_amt) as `Excess Discount Amount`
| from web_sales, item, date_dim
| where i_manufact_id = 350
@@ -3692,8 +3977,11 @@ trait Tpcds_1_4_Queries extends Benchmark {
| )
| order by sum(ws_ext_discount_amt)
| limit 100
- """.stripMargin),
- ("q93", """
+ """.stripMargin
+ ),
+ (
+ "q93",
+ """
| select ss_customer_sk, sum(act_sales) sumsales
| from (select
| ss_item_sk, ss_ticket_number, ss_customer_sk,
@@ -3707,10 +3995,13 @@ trait Tpcds_1_4_Queries extends Benchmark {
| group by ss_customer_sk
| order by sumsales, ss_customer_sk
| limit 100
- """.stripMargin),
+ """.stripMargin
+ ),
// Modifications: "+ days" -> date_add
// Modifications: " -> `
- ("q94", """
+ (
+ "q94",
+ """
| select
| count(distinct ws_order_number) as `order count`
| ,sum(ws_ext_ship_cost) as `total shipping cost`
@@ -3734,9 +4025,12 @@ trait Tpcds_1_4_Queries extends Benchmark {
| where ws1.ws_order_number = wr1.wr_order_number)
| order by count(distinct ws_order_number)
| limit 100
- """.stripMargin),
+ """.stripMargin
+ ),
// Modifications: "+ days" -> date_add
- ("q95", """
+ (
+ "q95",
+ """
| with ws_wh as
| (select ws1.ws_order_number,ws1.ws_warehouse_sk wh1,ws2.ws_warehouse_sk wh2
| from web_sales ws1,web_sales ws2
@@ -3763,8 +4057,11 @@ trait Tpcds_1_4_Queries extends Benchmark {
| where wr_order_number = ws_wh.ws_order_number)
| order by count(distinct ws_order_number)
| limit 100
- """.stripMargin),
- ("q96", """
+ """.stripMargin
+ ),
+ (
+ "q96",
+ """
| select count(*)
| from store_sales, household_demographics, time_dim, store
| where ss_sold_time_sk = time_dim.t_time_sk
@@ -3776,8 +4073,11 @@ trait Tpcds_1_4_Queries extends Benchmark {
| and store.s_store_name = 'ese'
| order by count(*)
| limit 100
- """.stripMargin),
- ("q97", """
+ """.stripMargin
+ ),
+ (
+ "q97",
+ """
| with ssci as (
| select ss_customer_sk customer_sk, ss_item_sk item_sk
| from store_sales,date_dim
@@ -3796,9 +4096,12 @@ trait Tpcds_1_4_Queries extends Benchmark {
| from ssci full outer join csci on (ssci.customer_sk=csci.customer_sk
| and ssci.item_sk = csci.item_sk)
| limit 100
- """.stripMargin),
+ """.stripMargin
+ ),
// Modifications: "+ days" -> date_add
- ("q98", """
+ (
+ "q98",
+ """
|select i_item_desc, i_category, i_class, i_current_price
| ,sum(ss_ext_sales_price) as itemrevenue
| ,sum(ss_ext_sales_price)*100/sum(sum(ss_ext_sales_price)) over
@@ -3815,9 +4118,12 @@ trait Tpcds_1_4_Queries extends Benchmark {
| i_item_id, i_item_desc, i_category, i_class, i_current_price
|order by
| i_category, i_class, i_item_id, i_item_desc, revenueratio
- """.stripMargin),
+ """.stripMargin
+ ),
// Modifications: " -> `
- ("q99", """
+ (
+ "q99",
+ """
| select
| substr(w_warehouse_name,1,20), sm_type, cc_name
| ,sum(case when (cs_ship_date_sk - cs_sold_date_sk <= 30 ) then 1 else 0 end) as `30 days`
@@ -3840,9 +4146,11 @@ trait Tpcds_1_4_Queries extends Benchmark {
| substr(w_warehouse_name,1,20), sm_type, cc_name
| order by substr(w_warehouse_name,1,20), sm_type, cc_name
| limit 100
- """.stripMargin),
- ("qSsMax",
- """
+ """.stripMargin
+ ),
+ (
+ "qSsMax",
+ """
|select
| count(*) as total,
| count(ss_sold_date_sk) as not_null_total,
@@ -3857,23 +4165,100 @@ trait Tpcds_1_4_Queries extends Benchmark {
| max(ss_store_sk) as max_ss_store_sk,
| max(ss_promo_sk) as max_ss_promo_sk
|from store_sales
- """.stripMargin)
- ).map { case (name, sqlText) =>
- Query(name + "-v1.4", sqlText, description = "TPCDS 1.4 Query", executionMode = CollectResults)
+ """.stripMargin
+ )
+ ).map {
+ case (name, sqlText) =>
+ Query(
+ name + "-v1.4",
+ sqlText,
+ description = "TPCDS 1.4 Query",
+ executionMode = CollectResults
+ )
}
val tpcds1_4QueriesMap = tpcds1_4Queries.map(q => q.name.split("-").get(0) -> q).toMap
val runnable: Seq[Query] = Seq(
- "q1", "q2", "q3", "q4", "q5", "q7", "q8", "q9",
- "q11", "q12", "q13", "q15", "q17", "q18", "q19",
- "q20", "q21", "q22", "q25", "q26", "q27", "q28", "q29",
- "q31", "q34", "q36", "q37", "q38", "q39a", "q39b",
- "q40", "q42", "q43", "q44", "q46", "q47", "q48", "q49",
- "q50", "q51", "q52", "q53", "q54", "q55", "q57", "q59",
- "q61", "q62", "q63", "q64", "q65", "q66", "q67", "q68",
- "q71", "q72", "q73", "q74", "q75", "q76", "q77", "q78", "q79",
- "q80", "q82", "q84", "q85", "q86", "q87", "q88", "q89",
- "q90", "q91", "q93", "q96", "q97", "q98", "q99", "qSsMax").map(tpcds1_4QueriesMap)
+ "q1",
+ "q2",
+ "q3",
+ "q4",
+ "q5",
+ "q7",
+ "q8",
+ "q9",
+ "q11",
+ "q12",
+ "q13",
+ "q15",
+ "q17",
+ "q18",
+ "q19",
+ "q20",
+ "q21",
+ "q22",
+ "q25",
+ "q26",
+ "q27",
+ "q28",
+ "q29",
+ "q31",
+ "q34",
+ "q36",
+ "q37",
+ "q38",
+ "q39a",
+ "q39b",
+ "q40",
+ "q42",
+ "q43",
+ "q44",
+ "q46",
+ "q47",
+ "q48",
+ "q49",
+ "q50",
+ "q51",
+ "q52",
+ "q53",
+ "q54",
+ "q55",
+ "q57",
+ "q59",
+ "q61",
+ "q62",
+ "q63",
+ "q64",
+ "q65",
+ "q66",
+ "q67",
+ "q68",
+ "q71",
+ "q72",
+ "q73",
+ "q74",
+ "q75",
+ "q76",
+ "q77",
+ "q78",
+ "q79",
+ "q80",
+ "q82",
+ "q84",
+ "q85",
+ "q86",
+ "q87",
+ "q88",
+ "q89",
+ "q90",
+ "q91",
+ "q93",
+ "q96",
+ "q97",
+ "q98",
+ "q99",
+ "qSsMax"
+ ).map(tpcds1_4QueriesMap)
val all: Seq[Query] = tpcds1_4QueriesMap.values.toSeq
}
diff --git a/src/main/scala/com/databricks/spark/sql/perf/tpcds/TPCDS_2_4_Queries.scala b/src/main/scala/com/databricks/spark/sql/perf/tpcds/TPCDS_2_4_Queries.scala
index f78dfe04..e2f997f4 100644
--- a/src/main/scala/com/databricks/spark/sql/perf/tpcds/TPCDS_2_4_Queries.scala
+++ b/src/main/scala/com/databricks/spark/sql/perf/tpcds/TPCDS_2_4_Queries.scala
@@ -20,32 +20,128 @@ import org.apache.commons.io.IOUtils
import com.databricks.spark.sql.perf.{Benchmark, ExecutionMode, Query}
-/**
- * This implements the official TPCDS v2.4 queries with only cosmetic modifications.
- */
+/** This implements the official TPCDS v2.4 queries with only cosmetic modifications.
+ */
trait Tpcds_2_4_Queries extends Benchmark {
import ExecutionMode._
val queryNames = Seq(
- "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10",
- "q11", "q12", "q13", "q14a", "q14b", "q15", "q16", "q17", "q18", "q19",
- "q20", "q21", "q22", "q23a", "q23b", "q24a", "q24b", "q25", "q26", "q27",
- "q28", "q29", "q30", "q31", "q32", "q33", "q34", "q35", "q36", "q37",
- "q38", "q39a", "q39b", "q40", "q41", "q42", "q43", "q44", "q45", "q46", "q47",
- "q48", "q49", "q50", "q51", "q52", "q53", "q54", "q55", "q56", "q57", "q58",
- "q59", "q60", "q61", "q62", "q63", "q64", "q65", "q66", "q67", "q68", "q69",
- "q70", "q71", "q72", "q73", "q74", "q75", "q76", "q77", "q78", "q79",
- "q80", "q81", "q82", "q83", "q84", "q85", "q86", "q87", "q88", "q89",
- "q90", "q91", "q92", "q93", "q94", "q95", "q96", "q97", "q98", "q99",
+ "q1",
+ "q2",
+ "q3",
+ "q4",
+ "q5",
+ "q6",
+ "q7",
+ "q8",
+ "q9",
+ "q10",
+ "q11",
+ "q12",
+ "q13",
+ "q14a",
+ "q14b",
+ "q15",
+ "q16",
+ "q17",
+ "q18",
+ "q19",
+ "q20",
+ "q21",
+ "q22",
+ "q23a",
+ "q23b",
+ "q24a",
+ "q24b",
+ "q25",
+ "q26",
+ "q27",
+ "q28",
+ "q29",
+ "q30",
+ "q31",
+ "q32",
+ "q33",
+ "q34",
+ "q35",
+ "q36",
+ "q37",
+ "q38",
+ "q39a",
+ "q39b",
+ "q40",
+ "q41",
+ "q42",
+ "q43",
+ "q44",
+ "q45",
+ "q46",
+ "q47",
+ "q48",
+ "q49",
+ "q50",
+ "q51",
+ "q52",
+ "q53",
+ "q54",
+ "q55",
+ "q56",
+ "q57",
+ "q58",
+ "q59",
+ "q60",
+ "q61",
+ "q62",
+ "q63",
+ "q64",
+ "q65",
+ "q66",
+ "q67",
+ "q68",
+ "q69",
+ "q70",
+ "q71",
+ "q72",
+ "q73",
+ "q74",
+ "q75",
+ "q76",
+ "q77",
+ "q78",
+ "q79",
+ "q80",
+ "q81",
+ "q82",
+ "q83",
+ "q84",
+ "q85",
+ "q86",
+ "q87",
+ "q88",
+ "q89",
+ "q90",
+ "q91",
+ "q92",
+ "q93",
+ "q94",
+ "q95",
+ "q96",
+ "q97",
+ "q98",
+ "q99",
"ss_max"
)
val tpcds2_4Queries = queryNames.map { queryName =>
- val queryContent: String = IOUtils.toString(
- getClass().getClassLoader().getResourceAsStream(s"tpcds_2_4/$queryName.sql"))
- Query(queryName + "-v2.4", queryContent, description = "TPCDS 2.4 Query",
- executionMode = CollectResults)
+ val queryContent: String =
+ IOUtils.toString(getClass().getClassLoader().getResourceAsStream(s"tpcds_2_4/$queryName.sql"))
+ Query(
+ queryName + "-v2.4",
+ queryContent,
+ description = "TPCDS 2.4 Query",
+ executionMode = CollectResults
+ )
}
val tpcds2_4QueriesMap = tpcds2_4Queries.map(q => q.name.split("-").get(0) -> q).toMap
diff --git a/src/main/scala/com/databricks/spark/sql/perf/tpch/TPCH.scala b/src/main/scala/com/databricks/spark/sql/perf/tpch/TPCH.scala
index 5a23edf9..6206729c 100644
--- a/src/main/scala/com/databricks/spark/sql/perf/tpch/TPCH.scala
+++ b/src/main/scala/com/databricks/spark/sql/perf/tpch/TPCH.scala
@@ -26,37 +26,41 @@ import org.apache.spark.sql.SQLContext
class DBGEN(dbgenDir: String, params: Seq[String]) extends DataGenerator {
val dbgen = s"$dbgenDir/dbgen"
- def generate(sparkContext: SparkContext,name: String, partitions: Int, scaleFactor: String) = {
- val smallTables = Seq("nation", "region")
+ def generate(sparkContext: SparkContext, name: String, partitions: Int, scaleFactor: String) = {
+ val smallTables = Seq("nation", "region")
val numPartitions = if (partitions > 1 && !smallTables.contains(name)) partitions else 1
- val generatedData = {
- sparkContext.parallelize(1 to numPartitions, numPartitions).flatMap { i =>
- val localToolsDir = if (new java.io.File(dbgen).exists) {
- dbgenDir
- } else if (new java.io.File(s"/$dbgenDir").exists) {
- s"/$dbgenDir"
- } else {
- sys.error(s"Could not find dbgen at $dbgen or /$dbgenDir. Run install")
+ val generatedData =
+ sparkContext
+ .parallelize(1 to numPartitions, numPartitions)
+ .flatMap { i =>
+ val localToolsDir = if (new java.io.File(dbgen).exists) {
+ dbgenDir
+ } else if (new java.io.File(s"/$dbgenDir").exists) {
+ s"/$dbgenDir"
+ } else {
+ sys.error(s"Could not find dbgen at $dbgen or /$dbgenDir. Run install")
+ }
+ val parallel = if (numPartitions > 1) s"-C $partitions -S $i" else ""
+ val shortTableNames = Map(
+ "customer" -> "c",
+ "lineitem" -> "L",
+ "nation" -> "n",
+ "orders" -> "O",
+ "part" -> "P",
+ "region" -> "r",
+ "supplier" -> "s",
+ "partsupp" -> "S"
+ )
+ val paramsString = params.mkString(" ")
+ val commands = Seq(
+ "bash",
+ "-c",
+ s"cd $localToolsDir && ./dbgen -q $paramsString -T ${shortTableNames(name)} -s $scaleFactor $parallel"
+ )
+ println(commands)
+ BlockingLineStream(commands)
}
- val parallel = if (numPartitions > 1) s"-C $partitions -S $i" else ""
- val shortTableNames = Map(
- "customer" -> "c",
- "lineitem" -> "L",
- "nation" -> "n",
- "orders" -> "O",
- "part" -> "P",
- "region" -> "r",
- "supplier" -> "s",
- "partsupp" -> "S"
- )
- val paramsString = params.mkString(" ")
- val commands = Seq(
- "bash", "-c",
- s"cd $localToolsDir && ./dbgen -q $paramsString -T ${shortTableNames(name)} -s $scaleFactor $parallel")
- println(commands)
- BlockingLineStream(commands)
- }.repartition(numPartitions)
- }
+ .repartition(numPartitions)
generatedData.setName(s"$name, sf=$scaleFactor, strings")
generatedData
@@ -69,14 +73,15 @@ class TPCHTables(
scaleFactor: String,
useDoubleForDecimal: Boolean = false,
useStringForDate: Boolean = false,
- generatorParams: Seq[String] = Nil)
- extends Tables(sqlContext, scaleFactor, useDoubleForDecimal, useStringForDate) {
- import sqlContext.implicits._
+ generatorParams: Seq[String] = Nil
+) extends Tables(sqlContext, scaleFactor, useDoubleForDecimal, useStringForDate) {
+ import spark.implicits._
val dataGenerator = new DBGEN(dbgenDir, generatorParams)
val tables = Seq(
- Table("part",
+ Table(
+ "part",
partitionColumns = "p_brand" :: Nil,
'p_partkey.long,
'p_name.string,
@@ -88,7 +93,8 @@ class TPCHTables(
'p_retailprice.decimal(12, 2),
'p_comment.string
),
- Table("supplier",
+ Table(
+ "supplier",
partitionColumns = Nil,
's_suppkey.long,
's_name.string,
@@ -98,7 +104,8 @@ class TPCHTables(
's_acctbal.decimal(12, 2),
's_comment.string
),
- Table("partsupp",
+ Table(
+ "partsupp",
partitionColumns = Nil,
'ps_partkey.long,
'ps_suppkey.long,
@@ -106,7 +113,8 @@ class TPCHTables(
'ps_supplycost.decimal(12, 2),
'ps_comment.string
),
- Table("customer",
+ Table(
+ "customer",
partitionColumns = "c_mktsegment" :: Nil,
'c_custkey.long,
'c_name.string,
@@ -117,7 +125,8 @@ class TPCHTables(
'c_mktsegment.string,
'c_comment.string
),
- Table("orders",
+ Table(
+ "orders",
partitionColumns = "o_orderdate" :: Nil,
'o_orderkey.long,
'o_custkey.long,
@@ -129,7 +138,8 @@ class TPCHTables(
'o_shippriority.int,
'o_comment.string
),
- Table("lineitem",
+ Table(
+ "lineitem",
partitionColumns = "l_shipdate" :: Nil,
'l_orderkey.long,
'l_partkey.long,
@@ -148,30 +158,24 @@ class TPCHTables(
'l_shipmode.string,
'l_comment.string
),
- Table("nation",
+ Table(
+ "nation",
partitionColumns = Nil,
'n_nationkey.long,
'n_name.string,
'n_regionkey.long,
'n_comment.string
),
- Table("region",
- partitionColumns = Nil,
- 'r_regionkey.long,
- 'r_name.string,
- 'r_comment.string
- )
+ Table("region", partitionColumns = Nil, 'r_regionkey.long, 'r_name.string, 'r_comment.string)
).map(_.convertTypes())
}
-class TPCH(@transient sqlContext: SQLContext)
- extends Benchmark(sqlContext) {
+class TPCH(@transient sqlContext: SQLContext) extends Benchmark(sqlContext) {
val queries = (1 to 22).map { q =>
- val queryContent: String = IOUtils.toString(
- getClass().getClassLoader().getResourceAsStream(s"tpch/queries/$q.sql"))
- Query(s"Q$q", queryContent, description = "TPCH Query",
- executionMode = CollectResults)
+ val queryContent: String =
+ IOUtils.toString(getClass().getClassLoader().getResourceAsStream(s"tpch/queries/$q.sql"))
+ Query(s"Q$q", queryContent, description = "TPCH Query", executionMode = CollectResults)
}
val queriesMap = queries.map(q => q.name.split("-").get(0) -> q).toMap
}
diff --git a/src/main/scala/org/apache/spark/ml/ModelBuilderSSP.scala b/src/main/scala/org/apache/spark/ml/ModelBuilderSSP.scala
index fa66e005..2abca96c 100644
--- a/src/main/scala/org/apache/spark/ml/ModelBuilderSSP.scala
+++ b/src/main/scala/org/apache/spark/ml/ModelBuilderSSP.scala
@@ -1,29 +1,31 @@
package org.apache.spark.ml
-import org.apache.spark.ml.classification.{ClassificationModelBuilder, DecisionTreeClassificationModel, LinearSVCModel, LogisticRegressionModel, NaiveBayesModel}
+import org.apache.spark.ml.classification.{
+ ClassificationModelBuilder,
+ DecisionTreeClassificationModel,
+ LinearSVCModel,
+ LogisticRegressionModel,
+ NaiveBayesModel
+}
import org.apache.spark.ml.linalg.{Matrix, Vector}
-import org.apache.spark.ml.regression.{DecisionTreeRegressionModel, GeneralizedLinearRegressionModel, LinearRegressionModel}
+import org.apache.spark.ml.regression.{
+ DecisionTreeRegressionModel,
+ GeneralizedLinearRegressionModel,
+ LinearRegressionModel
+}
import org.apache.spark.ml.tree._
import org.apache.spark.mllib.random.RandomDataGenerator
import org.apache.spark.mllib.tree.impurity.ImpurityCalculator
-
-/**
- * Helper for creating MLlib models which have private constructors.
- */
+/** Helper for creating MLlib models which have private constructors.
+ */
object ModelBuilderSSP {
- def newLogisticRegressionModel(
- coefficients: Vector,
- intercept: Double): LogisticRegressionModel = {
+ def newLogisticRegressionModel(coefficients: Vector, intercept: Double): LogisticRegressionModel =
new LogisticRegressionModel("lr", coefficients, intercept)
.setThreshold(.5)
- }
-
- def newLinearRegressionModel(
- coefficients: Vector,
- intercept: Double): LinearRegressionModel = {
+ def newLinearRegressionModel(coefficients: Vector, intercept: Double): LinearRegressionModel = {
val model = new LinearRegressionModel("linr", coefficients, intercept)
if (model.hasParam("loss")) {
model.set(model.getParam("loss"), "squaredError")
@@ -31,30 +33,44 @@ object ModelBuilderSSP {
model
}
- def newGLR(
- coefficients: Vector,
- intercept: Double): GeneralizedLinearRegressionModel =
+ def newGLR(coefficients: Vector, intercept: Double): GeneralizedLinearRegressionModel =
new GeneralizedLinearRegressionModel("glr-uid", coefficients, intercept)
def newDecisionTreeClassificationModel(
depth: Int,
numClasses: Int,
featureArity: Array[Int],
- seed: Long): DecisionTreeClassificationModel = {
- require(numClasses >= 2, s"DecisionTreeClassificationModel requires numClasses >= 2," +
- s" but was given $numClasses")
- val rootNode = TreeBuilder.randomBalancedDecisionTree(depth = depth, labelType = numClasses,
- featureArity = featureArity, seed = seed)
- new DecisionTreeClassificationModel(rootNode, numFeatures = featureArity.length,
- numClasses = numClasses)
+ seed: Long
+ ): DecisionTreeClassificationModel = {
+ require(
+ numClasses >= 2,
+ s"DecisionTreeClassificationModel requires numClasses >= 2," +
+ s" but was given $numClasses"
+ )
+ val rootNode = TreeBuilder.randomBalancedDecisionTree(
+ depth = depth,
+ labelType = numClasses,
+ featureArity = featureArity,
+ seed = seed
+ )
+ new DecisionTreeClassificationModel(
+ rootNode,
+ numFeatures = featureArity.length,
+ numClasses = numClasses
+ )
}
def newDecisionTreeRegressionModel(
depth: Int,
featureArity: Array[Int],
- seed: Long): DecisionTreeRegressionModel = {
- val rootNode = TreeBuilder.randomBalancedDecisionTree(depth = depth, labelType = 0,
- featureArity = featureArity, seed = seed)
+ seed: Long
+ ): DecisionTreeRegressionModel = {
+ val rootNode = TreeBuilder.randomBalancedDecisionTree(
+ depth = depth,
+ labelType = 0,
+ featureArity = featureArity,
+ seed = seed
+ )
new DecisionTreeRegressionModel(rootNode, numFeatures = featureArity.length)
}
@@ -63,52 +79,46 @@ object ModelBuilderSSP {
model.set(model.modelType, "multinomial")
}
- def newLinearSVCModel(
- coefficients: Vector,
- intercept: Double): LinearSVCModel = {
+ def newLinearSVCModel(coefficients: Vector, intercept: Double): LinearSVCModel =
ClassificationModelBuilder.newLinearSVCModel(coefficients, intercept)
- }
}
-/**
- * Helpers for creating random decision trees.
- */
+/** Helpers for creating random decision trees.
+ */
object TreeBuilder {
- /**
- * Generator for a pair of distinct class labels from the set {0,...,numClasses-1}.
- * Pairs are useful for trees to make sure sibling leaf nodes make different predictions.
- * @param numClasses Number of classes.
- */
+ /** Generator for a pair of distinct class labels from the set {0,...,numClasses-1}. Pairs are
+ * useful for trees to make sure sibling leaf nodes make different predictions.
+ * @param numClasses
+ * Number of classes.
+ */
private class ClassLabelPairGenerator(val numClasses: Int)
- extends RandomDataGenerator[Pair[Double, Double]] {
+ extends RandomDataGenerator[Pair[Double, Double]] {
- require(numClasses >= 2,
- s"ClassLabelPairGenerator given label numClasses = $numClasses, but numClasses should be >= 2.")
+ require(
+ numClasses >= 2,
+ s"ClassLabelPairGenerator given label numClasses = $numClasses, but numClasses should be >= 2."
+ )
private val rng = new java.util.Random()
override def nextValue(): Pair[Double, Double] = {
- val left = rng.nextInt(numClasses)
+ val left = rng.nextInt(numClasses)
var right = rng.nextInt(numClasses)
- while (right == left) {
+ while (right == left)
right = rng.nextInt(numClasses)
- }
new Pair[Double, Double](left, right)
}
- override def setSeed(seed: Long): Unit = {
+ override def setSeed(seed: Long): Unit =
rng.setSeed(seed)
- }
override def copy(): ClassLabelPairGenerator = new ClassLabelPairGenerator(numClasses)
}
-
- /**
- * Generator for a pair of real-valued labels.
- * Pairs are useful for trees to make sure sibling leaf nodes make different predictions.
- */
+ /** Generator for a pair of real-valued labels. Pairs are useful for trees to make sure sibling
+ * leaf nodes make different predictions.
+ */
private class RealLabelPairGenerator() extends RandomDataGenerator[Pair[Double, Double]] {
private val rng = new java.util.Random()
@@ -116,34 +126,37 @@ object TreeBuilder {
override def nextValue(): Pair[Double, Double] =
new Pair[Double, Double](rng.nextDouble(), rng.nextDouble())
- override def setSeed(seed: Long): Unit = {
+ override def setSeed(seed: Long): Unit =
rng.setSeed(seed)
- }
override def copy(): RealLabelPairGenerator = new RealLabelPairGenerator()
}
- /**
- * Creates a random decision tree structure.
- * @param depth Depth of tree to build. Must be <= numFeatures.
- * @param labelType Value 0 indicates regression. Integers >= 2 indicate numClasses for
- * classification.
- * @param featureArity Array of length numFeatures indicating feature type.
- * Value 0 indicates continuous feature.
- * Other values >= 2 indicate a categorical feature,
- * where the value is the number of categories.
- * @return root node of tree
- */
+ /** Creates a random decision tree structure.
+ * @param depth
+ * Depth of tree to build. Must be <= numFeatures.
+ * @param labelType
+ * Value 0 indicates regression. Integers >= 2 indicate numClasses for classification.
+ * @param featureArity
+ * Array of length numFeatures indicating feature type. Value 0 indicates continuous feature.
+ * Other values >= 2 indicate a categorical feature, where the value is the number of
+ * categories.
+ * @return
+ * root node of tree
+ */
def randomBalancedDecisionTree(
depth: Int,
labelType: Int,
featureArity: Array[Int],
- seed: Long): Node = {
+ seed: Long
+ ): Node = {
require(depth >= 0, s"randomBalancedDecisionTree given depth < 0.")
val numFeatures = featureArity.length
- require(depth <= numFeatures,
+ require(
+ depth <= numFeatures,
s"randomBalancedDecisionTree requires depth <= featureArity.size," +
- s" but depth = $depth and featureArity.size = $numFeatures")
+ s" but depth = $depth and featureArity.size = $numFeatures"
+ )
val isRegression = labelType == 0
if (!isRegression) {
require(labelType >= 2, s"labelType must be >= 2 for classification. 0 indicates regression.")
@@ -165,29 +178,38 @@ object TreeBuilder {
ImpurityCalculator.getCalculator("gini", Array.fill[Double](labelType)(0.0), 0L)
}
- randomBalancedDecisionTreeHelper(depth, featureArity, impurityCalculator,
- labelGenerator, Set.empty, rng)
+ randomBalancedDecisionTreeHelper(
+ depth,
+ featureArity,
+ impurityCalculator,
+ labelGenerator,
+ Set.empty,
+ rng
+ )
}
- /**
- * Create an internal node. Either create the leaf nodes beneath it, or recurse as needed.
- * @param subtreeDepth Depth of subtree to build. Depth 0 means this is a leaf node.
- * @param featureArity Indicates feature type. Value 0 indicates continuous feature.
- * Other values >= 2 indicate a categorical feature,
- * where the value is the number of categories.
- * @param impurityCalculator Dummy impurity calculator to use at all tree nodes
- * @param usedFeatures Features appearing in the path from the tree root to the node
- * being constructed.
- * @param labelGenerator Generates pairs of distinct labels.
- * @return
- */
+ /** Create an internal node. Either create the leaf nodes beneath it, or recurse as needed.
+ * @param subtreeDepth
+ * Depth of subtree to build. Depth 0 means this is a leaf node.
+ * @param featureArity
+ * Indicates feature type. Value 0 indicates continuous feature. Other values >= 2 indicate a
+ * categorical feature, where the value is the number of categories.
+ * @param impurityCalculator
+ * Dummy impurity calculator to use at all tree nodes
+ * @param usedFeatures
+ * Features appearing in the path from the tree root to the node being constructed.
+ * @param labelGenerator
+ * Generates pairs of distinct labels.
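+ * @param rng
+ * Random number generator used to choose split features, thresholds and category subsets.
+ * @return
+ * Root node of the generated subtree.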
+ */
private def randomBalancedDecisionTreeHelper(
subtreeDepth: Int,
featureArity: Array[Int],
impurityCalculator: ImpurityCalculator,
labelGenerator: RandomDataGenerator[Pair[Double, Double]],
usedFeatures: Set[Int],
- rng: scala.util.Random): Node = {
+ rng: scala.util.Random
+ ): Node = {
if (subtreeDepth == 0) {
// This case only happens for a depth 0 tree.
@@ -196,14 +218,16 @@ object TreeBuilder {
val numFeatures = featureArity.length
// Should not happen.
- assert(usedFeatures.size < numFeatures, s"randomBalancedDecisionTreeSplitNode ran out of " +
- s"features for splits.")
+ assert(
+ usedFeatures.size < numFeatures,
+ s"randomBalancedDecisionTreeSplitNode ran out of " +
+ s"features for splits."
+ )
// Make node internal.
var feature: Int = rng.nextInt(numFeatures)
- while (usedFeatures.contains(feature)) {
+ while (usedFeatures.contains(feature))
feature = rng.nextInt(numFeatures)
- }
val split: Split = if (featureArity(feature) == 0) {
// continuous feature
new ContinuousSplit(featureIndex = feature, threshold = rng.nextDouble())
@@ -213,27 +237,55 @@ object TreeBuilder {
// nCatsSplit is in {1,...,arity-1}.
val nCatsSplit = rng.nextInt(featureArity(feature) - 1) + 1
val splitCategories: Array[Double] =
- rng.shuffle(Range(0,featureArity(feature)).toList).toArray.map(_.toDouble).take(nCatsSplit)
- new CategoricalSplit(featureIndex = feature,
- _leftCategories = splitCategories, numCategories = featureArity(feature))
+ rng.shuffle(Range(0, featureArity(feature)).toList).toArray.map(_.toDouble).take(nCatsSplit)
+ new CategoricalSplit(
+ featureIndex = feature,
+ _leftCategories = splitCategories,
+ numCategories = featureArity(feature)
+ )
}
val (leftChild: Node, rightChild: Node) = if (subtreeDepth == 1) {
// Add leaf nodes. Assign these jointly so they make different predictions.
val predictions = labelGenerator.nextValue()
- val leftChild = new LeafNode(prediction = predictions._1, impurity = 0.0,
- impurityStats = impurityCalculator)
- val rightChild = new LeafNode(prediction = predictions._2, impurity = 0.0,
- impurityStats = impurityCalculator)
+ val leftChild = new LeafNode(
+ prediction = predictions._1,
+ impurity = 0.0,
+ impurityStats = impurityCalculator
+ )
+ val rightChild = new LeafNode(
+ prediction = predictions._2,
+ impurity = 0.0,
+ impurityStats = impurityCalculator
+ )
(leftChild, rightChild)
} else {
- val leftChild = randomBalancedDecisionTreeHelper(subtreeDepth - 1, featureArity,
- impurityCalculator, labelGenerator, usedFeatures + feature, rng)
- val rightChild = randomBalancedDecisionTreeHelper(subtreeDepth - 1, featureArity,
- impurityCalculator, labelGenerator, usedFeatures + feature, rng)
+ val leftChild = randomBalancedDecisionTreeHelper(
+ subtreeDepth - 1,
+ featureArity,
+ impurityCalculator,
+ labelGenerator,
+ usedFeatures + feature,
+ rng
+ )
+ val rightChild = randomBalancedDecisionTreeHelper(
+ subtreeDepth - 1,
+ featureArity,
+ impurityCalculator,
+ labelGenerator,
+ usedFeatures + feature,
+ rng
+ )
(leftChild, rightChild)
}
- new InternalNode(prediction = 0.0, impurity = 0.0, gain = 0.0, leftChild = leftChild,
- rightChild = rightChild, split = split, impurityStats = impurityCalculator)
+ new InternalNode(
+ prediction = 0.0,
+ impurity = 0.0,
+ gain = 0.0,
+ leftChild = leftChild,
+ rightChild = rightChild,
+ split = split,
+ impurityStats = impurityCalculator
+ )
}
}
diff --git a/src/main/scala/org/apache/spark/ml/TreeUtils.scala b/src/main/scala/org/apache/spark/ml/TreeUtils.scala
index badef4fd..2dc1dcb5 100644
--- a/src/main/scala/org/apache/spark/ml/TreeUtils.scala
+++ b/src/main/scala/org/apache/spark/ml/TreeUtils.scala
@@ -4,26 +4,28 @@ import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute, NumericA
import org.apache.spark.sql.DataFrame
object TreeUtils {
- /**
- * Set label metadata (particularly the number of classes) on a DataFrame.
- *
- * @param data Dataset. Categorical features and labels must already have 0-based indices.
- * This must be non-empty.
- * @param featuresColName Name of the features column
- * @param featureArity Array of length numFeatures, where 0 indicates continuous feature and
- * value > 0 indicates a categorical feature of that arity.
- * @return DataFrame with metadata
- */
- def setMetadata(
- data: DataFrame,
- featuresColName: String,
- featureArity: Array[Int]): DataFrame = {
- val featuresAttributes = featureArity.zipWithIndex.map { case (arity: Int, feature: Int) =>
- if (arity > 0) {
- NominalAttribute.defaultAttr.withIndex(feature).withNumValues(arity)
- } else {
- NumericAttribute.defaultAttr.withIndex(feature)
- }
+
+ /** Set feature metadata (continuous vs. categorical, with arity) on a DataFrame's features column.
+ *
+ * @param data
+ * Dataset. Categorical features and labels must already have 0-based indices. This must be
+ * non-empty.
+ * @param featuresColName
+ * Name of the features column
+ * @param featureArity
+ * Array of length numFeatures, where 0 indicates continuous feature and value > 0 indicates a
+ * categorical feature of that arity.
+ * @return
+ * DataFrame with metadata
+ */
+ def setMetadata(data: DataFrame, featuresColName: String, featureArity: Array[Int]): DataFrame = {
+ val featuresAttributes = featureArity.zipWithIndex.map {
+ case (arity: Int, feature: Int) =>
+ if (arity > 0) {
+ NominalAttribute.defaultAttr.withIndex(feature).withNumValues(arity)
+ } else {
+ NumericAttribute.defaultAttr.withIndex(feature)
+ }
}
val featuresMetadata = new AttributeGroup("features", featuresAttributes).toMetadata()
data.select(data(featuresColName).as(featuresColName, featuresMetadata))
diff --git a/src/main/scala/org/apache/spark/ml/classification/ClassificationModelBuilder.scala b/src/main/scala/org/apache/spark/ml/classification/ClassificationModelBuilder.scala
index 485b883f..00fe5155 100644
--- a/src/main/scala/org/apache/spark/ml/classification/ClassificationModelBuilder.scala
+++ b/src/main/scala/org/apache/spark/ml/classification/ClassificationModelBuilder.scala
@@ -2,12 +2,8 @@ package org.apache.spark.ml.classification
import org.apache.spark.ml.linalg.{Matrix, Vector}
-
object ClassificationModelBuilder {
- def newLinearSVCModel(
- coefficients: Vector,
- intercept: Double): LinearSVCModel = {
+ def newLinearSVCModel(coefficients: Vector, intercept: Double): LinearSVCModel =
new LinearSVCModel("linearSVC", coefficients, intercept)
- }
}
diff --git a/version.sbt b/version.sbt
index 7338ce76..f13c2095 100644
--- a/version.sbt
+++ b/version.sbt
@@ -1 +1 @@
-version in ThisBuild := "0.5.1-SNAPSHOT"
+ThisBuild / version := "0.5.2-SNAPSHOT"