Source Code Analysis of SparkSQL Joins (How Spark SQL Works)

Featured article | moguli2020 | 2025-05-02 18:32:07

To understand how SparkSQL implements joins more deeply, it helps to read the source code. The following is a source-level analysis of SparkSQL joins:

1. Join Strategy Selection

SparkSQL implements its join strategies in the org.apache.spark.sql.execution.joins package. In the Join class's doExecute method, a suitable strategy is chosen based on statistics and configuration; the simplified code below illustrates the selection logic.

def doExecute(): RDD[InternalRow] = {
  val leftKeys = leftKeysArray
  val rightKeys = rightKeysArray
  if (joinType == JoinType.CROSS) {
    CrossHashJoin.doJoin(left, right, leftKeys, rightKeys, joinType, condition, leftFilters, rightFilters)
  } else {
    if (left.output.nonEmpty && right.output.nonEmpty) {
      leftKeys.length match {
        case 0 =>
          // Cartesian product
          CartesianProduct.doJoin(left, right, joinType, condition, leftFilters, rightFilters)
        case 1 =>
          // Single key, use hash join
          if (joinType == JoinType.INNER || joinType == JoinType.CROSS) {
            HashJoin.doJoin(left, right, leftKeys, rightKeys, joinType, condition, leftFilters, rightFilters)
          } else {
            // For outer joins, use sort merge join to preserve the order
            SortMergeJoin.doJoin(left, right, leftKeys, rightKeys, joinType, condition, leftFilters, rightFilters)
          }
        case _ =>
          // Multiple keys, use sort merge join
          SortMergeJoin.doJoin(left, right, leftKeys, rightKeys, joinType, condition, leftFilters, rightFilters)
      }
    } else {
      // One of the children has no output, return empty
      sparkContext.emptyRDD[InternalRow]
    }
  }
}
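The branching above can be mirrored in a small, self-contained function (plain Python rather than Scala, just to make the decision tree concrete; the function and string names are illustrative, not Spark APIs):

```python
def choose_join_strategy(join_type: str, num_keys: int,
                         left_has_output: bool, right_has_output: bool) -> str:
    """Mirror of the simplified selection logic: cross joins and empty inputs
    are handled first, then the number of join keys picks the strategy."""
    if join_type == "CROSS":
        return "cross-hash-join"
    if not (left_has_output and right_has_output):
        return "empty"            # one child produces no output
    if num_keys == 0:
        return "cartesian-product"
    if num_keys == 1:
        # Single equi-join key: hash join for inner joins,
        # sort merge join for outer joins (preserves ordering)
        return "hash-join" if join_type in ("INNER", "CROSS") else "sort-merge-join"
    return "sort-merge-join"      # multiple keys
```

For example, a single-key inner join maps to a hash join, while the same query with a left outer join type falls through to a sort merge join.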

2. Hash Join Implementation

The hash join is implemented mainly in the HashJoin class. Its main steps are:

  1. Choose the build and probe sides: based on statistics, the smaller table becomes the build side
  2. Build the hash table: hash the build-side rows by their join keys
  3. Probe phase: look up each probe-side row in the hash table by its join key
  4. Join: combine matching rows according to the join type (inner, outer, and so on)

object HashJoin {
  def doJoin(
      left: RDD[InternalRow],
      right: RDD[InternalRow],
      leftKeys: Array[Expression],
      rightKeys: Array[Expression],
      joinType: JoinType,
      condition: Option[Expression],
      leftFilters: Option[Expression],
      rightFilters: Option[Expression]): RDD[InternalRow] = {
    // 1. Choose the build and probe sides: the smaller input becomes the build side
    val (buildSide, _) = chooseSides(left, right)
    val (buildRDD, probeRDD, buildKeys, probeKeys) =
      if (buildSide == BuildSide.LEFT) (left, right, leftKeys, rightKeys)
      else (right, left, rightKeys, leftKeys)

    // 2. Key each side by its evaluated join keys. The keys are materialized
    //    as Seq so they compare by value; Array keys compare by reference and
    //    would never match across rows.
    val keyedBuild = buildRDD.map(row => (buildKeys.map(_.eval(row)).toSeq, row))
    val keyedProbe = probeRDD.map(row => (probeKeys.map(_.eval(row)).toSeq, row))

    // The optional extra join condition is evaluated against the combined row
    def conditionHolds(probeRow: InternalRow, buildRow: InternalRow): Boolean =
      condition.forall(_.eval(new JoinedRow(probeRow, buildRow)) == true)

    // 3. Probe phase and 4. join: hash-join the keyed RDDs on their keys, then
    //    shape each matched pair according to the join type (for simplicity the
    //    output puts the probe-side columns first)
    keyedProbe.join(keyedBuild).flatMap { case (_, (probeRow, buildRow)) =>
      joinType match {
        case JoinType.INNER =>
          if (conditionHolds(probeRow, buildRow)) {
            Seq(InternalRow.fromSeq(probeRow ++ buildRow))
          } else Seq.empty
        case JoinType.LEFT =>
          if (conditionHolds(probeRow, buildRow)) {
            Seq(InternalRow.fromSeq(probeRow ++ buildRow))
          } else {
            Seq(InternalRow.fromSeq(probeRow ++ Seq.fill(buildRow.length)(null)))
          }
        case JoinType.RIGHT =>
          if (conditionHolds(probeRow, buildRow)) {
            Seq(InternalRow.fromSeq(probeRow ++ buildRow))
          } else {
            Seq(InternalRow.fromSeq(Seq.fill(probeRow.length)(null) ++ buildRow))
          }
        case JoinType.FULL =>
          if (conditionHolds(probeRow, buildRow)) {
            Seq(InternalRow.fromSeq(probeRow ++ buildRow))
          } else {
            // A non-matching pair contributes a row padded on each side
            Seq(
              InternalRow.fromSeq(probeRow ++ Seq.fill(buildRow.length)(null)),
              InternalRow.fromSeq(Seq.fill(probeRow.length)(null) ++ buildRow))
          }
      }
    }
  }
}
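Stripped of Spark's RDD and expression machinery, the build/probe flow reduces to a few lines. The following standalone sketch (plain Python, illustrative names, inner join only) shows the core of steps 2 and 3:

```python
from collections import defaultdict

def hash_join(build_rows, probe_rows, build_key, probe_key):
    """Minimal inner hash join: build a hash table over the (smaller) build
    side, then probe it once with each row of the other side."""
    # Build phase: group build-side rows by their join key
    table = defaultdict(list)
    for row in build_rows:
        table[build_key(row)].append(row)
    # Probe phase: look up each probe-side row and emit matching pairs
    out = []
    for row in probe_rows:
        for match in table.get(probe_key(row), []):
            out.append((row, match))
    return out
```

The hash table is built once and each probe-side row costs a single lookup, which is why the smaller table is chosen as the build side: the table must fit in memory, while the probe side can stream through.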

3. Sort Merge Join Implementation

The sort merge join is implemented mainly in the SortMergeJoin class. Its main steps are:

  1. Sort: sort both tables by their join keys
  2. Merge: merge the two sorted datasets with a two-pointer pass
  3. Join: combine matching rows according to the join type

object SortMergeJoin {
  def doJoin(
      left: RDD[InternalRow],
      right: RDD[InternalRow],
      leftKeys: Array[Expression],
      rightKeys: Array[Expression],
      joinType: JoinType,
      condition: Option[Expression],
      leftFilters: Option[Expression],
      rightFilters: Option[Expression]): RDD[InternalRow] = {
    // Keys are materialized as Seq so they compare by value
    def lKey(row: InternalRow): Seq[Any] = leftKeys.map(_.eval(row)).toSeq
    def rKey(row: InternalRow): Seq[Any] = rightKeys.map(_.eval(row)).toSeq

    // Compare evaluated keys field by field (assumes each field is Comparable;
    // the real implementation derives orderings from the key data types)
    def compareKeys(a: Seq[Any], b: Seq[Any]): Int =
      a.zip(b).iterator
        .map { case (x, y) => x.asInstanceOf[Comparable[Any]].compareTo(y) }
        .find(_ != 0)
        .getOrElse(0)
    implicit val keyOrdering: Ordering[Seq[Any]] = (a, b) => compareKeys(a, b)

    // 1. Sort both inputs by their join keys
    val sortedLeft = left.sortBy(lKey)
    val sortedRight = right.sortBy(rKey)

    // 2. Merge with two "pointers" (buffered iterators). zipPartitions assumes
    //    both sides are range-partitioned identically, which Spark guarantees
    //    by shuffling before a sort merge join.
    sortedLeft.zipPartitions(sortedRight) { (leftIn, rightIn) =>
      val leftIter = leftIn.buffered
      val rightIter = rightIn.buffered
      val out = new mutable.ArrayBuffer[InternalRow]()

      def conditionHolds(l: InternalRow, r: InternalRow): Boolean =
        condition.forall(_.eval(new JoinedRow(l, r)) == true)

      while (leftIter.hasNext && rightIter.hasNext) {
        val cmp = compareKeys(lKey(leftIter.head), rKey(rightIter.head))
        if (cmp < 0) {
          // Left key is smaller: this left row has no match
          val l = leftIter.next()
          if (joinType == JoinType.LEFT || joinType == JoinType.FULL) out += new JoinedRow(l, null)
        } else if (cmp > 0) {
          // Right key is smaller: this right row has no match
          val r = rightIter.next()
          if (joinType == JoinType.RIGHT || joinType == JoinType.FULL) out += new JoinedRow(null, r)
        } else {
          // 3. Equal keys: collect the full run of duplicate keys on both
          //    sides and emit every combination that passes the condition
          val key = lKey(leftIter.head)
          val leftRun = new mutable.ArrayBuffer[InternalRow]()
          val rightRun = new mutable.ArrayBuffer[InternalRow]()
          while (leftIter.hasNext && lKey(leftIter.head) == key) leftRun += leftIter.next()
          while (rightIter.hasNext && rKey(rightIter.head) == key) rightRun += rightIter.next()
          for (l <- leftRun; r <- rightRun if conditionHolds(l, r)) out += new JoinedRow(l, r)
        }
      }
      // Remaining rows on either side are unmatched; emit them null-padded
      // according to the join type
      if (joinType == JoinType.LEFT || joinType == JoinType.FULL)
        leftIter.foreach(l => if (leftFilters.forall(_.eval(l) == true)) out += new JoinedRow(l, null))
      if (joinType == JoinType.RIGHT || joinType == JoinType.FULL)
        rightIter.foreach(r => if (rightFilters.forall(_.eval(r) == true)) out += new JoinedRow(null, r))
      out.iterator
    }
  }
}
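The two-pointer merge in step 2, including the expansion of duplicate-key runs into their cross product, can be shown in a standalone sketch (plain Python, illustrative names, inner join only):

```python
def sort_merge_join(left, right, lkey, rkey):
    """Minimal inner sort-merge join: sort both sides by key, then advance two
    pointers; equal keys expand runs of duplicates into a cross product."""
    left = sorted(left, key=lkey)
    right = sorted(right, key=rkey)
    i = j = 0
    out = []
    while i < len(left) and j < len(right):
        lk, rk = lkey(left[i]), rkey(right[j])
        if lk < rk:
            i += 1          # left key smaller: advance the left pointer
        elif lk > rk:
            j += 1          # right key smaller: advance the right pointer
        else:
            # Equal keys: find the end of the duplicate run on each side
            i2 = i
            while i2 < len(left) and lkey(left[i2]) == lk:
                i2 += 1
            j2 = j
            while j2 < len(right) and rkey(right[j2]) == rk:
                j2 += 1
            # Emit every pairing within the two runs
            for l in left[i:i2]:
                for r in right[j:j2]:
                    out.append((l, r))
            i, j = i2, j2
    return out
```

After sorting, the merge itself is a single linear pass over both sides, which is why sort merge join scales to inputs too large for an in-memory hash table.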

Summary

This article has walked through how joins are implemented in SparkSQL, covering Broadcast Join, Hash Join (including Shuffle Hash Join), and Sort Merge Join. Analyzing their implementation principles, workflows, and applicable scenarios gives a clearer picture of the internal mechanics of SparkSQL join operations.

In practice, choosing the right join strategy is critical to SparkSQL query performance: a strategy that fits the table sizes, data distribution, and available memory can improve join performance significantly. A solid understanding of how SparkSQL joins are implemented makes it easier to optimize queries and raise the efficiency of large-scale data processing.

莫古技术网 (Mogu Tech) © All Rights Reserved. 滇ICP备2024046894号-2