Spark Data Input: DStream Generation

2022-05-30 00:00:00

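The earlier article "Spark Data Input: RDD Generation" explained that Spark has three main kinds of data input. This article looks at the second kind: loading streaming data to produce a DStream.

The example is Spark reading from Kafka, which is common in production. Here is the code first: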

        Map<String, Object> kafkaParameters = new HashMap<>();
        kafkaParameters.put("bootstrap.servers", bootstrapServers);
        kafkaParameters.put("group.id", group);
        kafkaParameters.put("auto.offset.reset", reset);
        kafkaParameters.put("key.deserializer", "org.apache.kafka.common.serialization.StringDeserializer");
        kafkaParameters.put("value.deserializer", "org.apache.kafka.common.serialization.StringDeserializer");
        kafkaParameters.put("enable.auto.commit", "" + autoCommit);
        EngineParamUtils.setDynamicParamByKafka(kafkaParameters,dynamicParam);
        Collection<String> topicsSet = Arrays.asList(topics.split(","));

        JavaInputDStream<ConsumerRecord<String, String>> directStream = KafkaUtils.createDirectStream(
                jsc,
                LocationStrategies.PreferConsistent(),
                ConsumerStrategies.Subscribe(topicsSet, kafkaParameters)
        );

The source of KafkaUtils.createDirectStream, which the code above calls, is:

  /**
   * :: Experimental ::
   * Scala constructor for a DStream where
   * each given Kafka topic/partition corresponds to an RDD partition.
   * The spark configuration spark.streaming.kafka.maxRatePerPartition gives the maximum number
   *  of messages
   * per second that each '''partition''' will accept.
   * @param locationStrategy In most cases, pass in [[LocationStrategies.PreferConsistent]],
   *   see [[LocationStrategies]] for more details.
   * @param consumerStrategy In most cases, pass in [[ConsumerStrategies.Subscribe]],
   *   see [[ConsumerStrategies]] for more details
   * @tparam K type of Kafka message key
   * @tparam V type of Kafka message value
   */
  @Experimental
  def createDirectStream[K, V](
      ssc: StreamingContext,
      locationStrategy: LocationStrategy,
      consumerStrategy: ConsumerStrategy[K, V]
    ): InputDStream[ConsumerRecord[K, V]] = {
    val ppc = new DefaultPerPartitionConfig(ssc.sparkContext.getConf)
    createDirectStream[K, V](ssc, locationStrategy, consumerStrategy, ppc)
  }

  /**
   * :: Experimental ::
   * Scala constructor for a DStream where
   * each given Kafka topic/partition corresponds to an RDD partition.
   * @param locationStrategy In most cases, pass in [[LocationStrategies.PreferConsistent]],
   *   see [[LocationStrategies]] for more details.
   * @param consumerStrategy In most cases, pass in [[ConsumerStrategies.Subscribe]],
   *   see [[ConsumerStrategies]] for more details.
   * @param perPartitionConfig configuration of settings such as max rate on a per-partition basis.
   *   see [[PerPartitionConfig]] for more details.
   * @tparam K type of Kafka message key
   * @tparam V type of Kafka message value
   */
  @Experimental
  def createDirectStream[K, V](
      ssc: StreamingContext,
      locationStrategy: LocationStrategy,
      consumerStrategy: ConsumerStrategy[K, V],
      perPartitionConfig: PerPartitionConfig
    ): InputDStream[ConsumerRecord[K, V]] = {
    new DirectKafkaInputDStream[K, V](ssc, locationStrategy, consumerStrategy, perPartitionConfig)
  }


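The code mainly uses the createDirectStream method of the KafkaUtils utility class to create the DStream. First, the parameters of this method.

(1) ssc, a StreamingContext: the Spark Streaming context class. It provides many methods for creating DStreams, along with related transformation methods.

(2) locationStrategy, a LocationStrategy: the location strategy that Spark Streaming uses when reading Kafka topics. The options are 1) PreferBrokers, which prefers the brokers and reduces network transfer when the executors run on the same servers as the Kafka brokers; 2) PreferConsistent, which distributes partitions consistently across the executors; 3) PreferFixed, which pins topic partitions to fixed hosts.

(3) consumerStrategy, a ConsumerStrategy: the consumption strategy. There are three modes: subscribe to specified topics (more than one is allowed), subscribe to topics matching a pattern, and assign a specified list of topic partitions.

(4) perPartitionConfig, a PerPartitionConfig: the per-partition configuration. It mainly covers two settings: spark.streaming.kafka.maxRatePerPartition, the maximum consumption rate per partition, and spark.streaming.kafka.minRatePerPartition, the minimum consumption rate per partition.

All four parameters are used to build the DirectKafkaInputDStream object and become its fields.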

As the source above shows, createDirectStream only constructs the DStream object. Like loading data to produce an RDD, this is lazy: nothing has been read from Kafka at this point.
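
To make the lazy behaviour concrete, here is a minimal plain-Java sketch. It is not Spark code: the class `LazyStreamSketch` and its `fetcher` field are hypothetical names invented for illustration, and the supplier merely stands in for "read from Kafka". The point it shows is that constructing the object only stores what it was given, and the read happens when a batch is computed.

```java
import java.util.Arrays;
import java.util.List;
import java.util.function.Supplier;

// Hypothetical stand-in for an input stream object; not part of Spark.
class LazyStreamSketch {
    private final Supplier<List<String>> fetcher; // stands in for "read from Kafka"
    private int fetchCount = 0;

    // Like createDirectStream: only stores its arguments, performs no I/O.
    LazyStreamSketch(Supplier<List<String>> fetcher) {
        this.fetcher = fetcher;
    }

    // The read happens only when a batch is actually computed.
    List<String> computeBatch() {
        fetchCount++;
        return fetcher.get();
    }

    int fetchCount() {
        return fetchCount;
    }

    public static void main(String[] args) {
        LazyStreamSketch stream = new LazyStreamSketch(() -> Arrays.asList("a", "b"));
        System.out.println("fetches after construction: " + stream.fetchCount());
        System.out.println("first batch: " + stream.computeBatch());
        System.out.println("fetches after one batch: " + stream.fetchCount());
    }
}
```

In the sketch, the fetch counter stays at zero until `computeBatch` is called, which is the same separation the rest of this article traces through the Spark source.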

So when is the data actually read from the Kafka topic? Let us look at the DirectKafkaInputDStream class.

/**
 *  A DStream where
 * each given Kafka topic/partition corresponds to an RDD partition.
 * The spark configuration spark.streaming.kafka.maxRatePerPartition gives the maximum number
 *  of messages
 * per second that each '''partition''' will accept.
 * @param locationStrategy In most cases, pass in [[LocationStrategies.PreferConsistent]],
 *   see [[LocationStrategy]] for more details.
 * @param consumerStrategy In most cases, pass in [[ConsumerStrategies.Subscribe]],
 *   see [[ConsumerStrategy]] for more details
 * @param ppc configuration of settings such as max rate on a per-partition basis.
 *   see [[PerPartitionConfig]] for more details.
 * @tparam K type of Kafka message key
 * @tparam V type of Kafka message value
 */
private[spark] class DirectKafkaInputDStream[K, V](
    _ssc: StreamingContext,
    locationStrategy: LocationStrategy,
    consumerStrategy: ConsumerStrategy[K, V],
    ppc: PerPartitionConfig
  ) extends InputDStream[ConsumerRecord[K, V]](_ssc) with Logging with CanCommitOffsets

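The class comment says that this class maps each Kafka topic partition directly to one RDD partition.

The class extends InputDStream, so we look at InputDStream first.

InputDStream is the abstract base class of all input streams. It declares two methods that subclasses must implement, start and stop. According to the source comments, start begins reading data and stop stops reading data. That suggests the actual read happens in start, so we go back to the start method of DirectKafkaInputDStream.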

  override def start(): Unit = {
    val c = consumer
    paranoidPoll(c)
    if (currentOffsets.isEmpty) {
      currentOffsets = c.assignment().asScala.map { tp =>
        tp -> c.position(tp)
      }.toMap
    }
  }

  override def stop(): Unit = this.synchronized {
    if (kc != null) {
      kc.close()
    }
  }

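As this shows, the start method of DirectKafkaInputDStream only fetches, once at startup, the current offset positions of the consumer group on the Kafka topic partitions. The stop method closes the Kafka consumer object and releases its resources. Tracing the callers shows that start is invoked once, from the start method of DStreamGraph; stop is invoked correspondingly from the stop method of DStreamGraph.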

org.apache.spark.streaming.DStreamGraph

  def start(time: Time) {
    this.synchronized {
      require(zeroTime == null, "DStream graph computation already started")
      zeroTime = time
      startTime = time
      outputStreams.foreach(_.initialize(zeroTime))
      outputStreams.foreach(_.remember(rememberDuration))
      outputStreams.foreach(_.validateAtStart())
      numReceivers = inputStreams.count(_.isInstanceOf[ReceiverInputDStream[_]])
      inputStreamNameAndID = inputStreams.map(is => (is.name, is.id))
      inputStreams.par.foreach(_.start())
    }
  }

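So the start method is only used to initialize some state at startup. For DirectKafkaInputDStream, that means fetching the offset positions at which the consumer group will resume consuming. The stop method is used to release resources.

We therefore search from the entry point that starts the streaming job, the start method of StreamingContext.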

/**
   * Start the execution of the streams.
   *
   * @throws IllegalStateException if the StreamingContext is already stopped.
   */
  def start(): Unit = synchronized {
    state match {
      case INITIALIZED =>
        startSite.set(DStream.getCreationSite())
        StreamingContext.ACTIVATION_LOCK.synchronized {
          StreamingContext.assertNoOtherContextIsActive()
          try {
            validate()

            // Start the streaming scheduler in a new thread, so that thread local properties
            // like call sites and job groups can be reset without affecting those of the
            // current thread.
            ThreadUtils.runInNewThread("streaming-start") {
              sparkContext.setCallSite(startSite.get)
              sparkContext.clearJobGroup()
              sparkContext.setLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL, "false")
              savedProperties.set(SerializationUtils.clone(sparkContext.localProperties.get()))
              scheduler.start()
            }
            state = StreamingContextState.ACTIVE
            scheduler.listenerBus.post(
              StreamingListenerStreamingStarted(System.currentTimeMillis()))
          } catch {
            case NonFatal(e) =>
              logError("Error starting the context, marking it as stopped", e)
              scheduler.stop(false)
              state = StreamingContextState.STOPPED
              throw e
          }
          StreamingContext.setActiveContext(this)
        }
        logDebug("Adding shutdown hook") // force eager creation of logger
        shutdownHookRef = ShutdownHookManager.addShutdownHook(
          StreamingContext.SHUTDOWN_HOOK_PRIORITY)(() => stopOnShutdown())
        // Registering Streaming Metrics at the start of the StreamingContext
        assert(env.metricsSystem != null)
        env.metricsSystem.registerSource(streamingSource)
        uiTab.foreach(_.attach())
        logInfo("StreamingContext started")
      case ACTIVE =>
        logWarning("StreamingContext has already been started")
      case STOPPED =>
        throw new IllegalStateException("StreamingContext has already been stopped")
    }
  }

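The key call here is scheduler.start(). Tracing the call chain: JobScheduler.start calls JobGenerator.start, which on a first start calls startFirstTime. That method calls graph.start(startTime - graph.batchDuration) and then timer.start(startTime.milliseconds).

The call graph.start(startTime - graph.batchDuration) invokes the start method of DStreamGraph, which, as analyzed above, calls the start method of each InputDStream. The core part is timer.start(startTime.milliseconds), so let us first see what kind of object this timer is.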

  private val timer = new RecurringTimer(clock, ssc.graph.batchDuration.milliseconds,
    longTime => eventLoop.post(GenerateJobs(new Time(longTime))), "JobGenerator")

This is a timer that runs the function passed to it at the specified interval.

private[streaming]
class RecurringTimer(clock: Clock, period: Long, callback: (Long) => Unit, name: String)
  extends Logging

The callback here is longTime => eventLoop.post(GenerateJobs(new Time(longTime))). Once per interval, it posts a newly created GenerateJobs event to eventLoop, the event loop processor.
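
The interaction between the timer and the event loop can be sketched in plain Java. This is a simplified, single-threaded model and not Spark's implementation: `JobGeneratorSketch`, `runTimer`, and `drainEvents` are hypothetical names, the real timer runs on its own thread against a clock, and the rule that the first tick falls on the next multiple of the period is an assumption of this sketch. What it illustrates is the division of labour: the timer callback only posts an event, and a separate loop dispatches it.

```java
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.List;
import java.util.Queue;

// Simplified, single-threaded model of a recurring timer feeding an event loop.
class JobGeneratorSketch {
    // The only event kind modelled here.
    static final class GenerateJobs {
        final long timeMs;
        GenerateJobs(long timeMs) { this.timeMs = timeMs; }
    }

    private final long periodMs;
    private final Queue<GenerateJobs> eventQueue = new ArrayDeque<>();
    final List<Long> generatedBatchTimes = new ArrayList<>();

    JobGeneratorSketch(long periodMs) {
        this.periodMs = periodMs;
    }

    // Assumed alignment rule: the first tick is the next multiple of the period.
    long firstTickAfter(long nowMs) {
        return (nowMs / periodMs + 1) * periodMs;
    }

    // Timer side: one callback per period; each callback only posts an event.
    void runTimer(long nowMs, int ticks) {
        long t = firstTickAfter(nowMs);
        for (int i = 0; i < ticks; i++) {
            eventQueue.add(new GenerateJobs(t));
            t += periodMs;
        }
    }

    // Event loop side: take events off the queue and handle them one by one.
    void drainEvents() {
        GenerateJobs event;
        while ((event = eventQueue.poll()) != null) {
            generatedBatchTimes.add(event.timeMs); // stands in for generateJobs(time)
        }
    }

    public static void main(String[] args) {
        JobGeneratorSketch generator = new JobGeneratorSketch(5000L);
        generator.runTimer(12345L, 3);
        generator.drainEvents();
        System.out.println("batch times: " + generator.generatedBatchTimes);
    }
}
```

Keeping the callback this small means the timer is never blocked by job generation; all the work happens on the consumer side of the queue.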

Back in the start method of JobGenerator, an EventLoop object is initialized:

    eventLoop = new EventLoop[JobGeneratorEvent]("JobGenerator") {
      override protected def onReceive(event: JobGeneratorEvent): Unit = processEvent(event)

      override protected def onError(e: Throwable): Unit = {
        jobScheduler.reportError("Error in job generator", e)
      }
    }

The method to look at is processEvent:

  /** Processes all events */
  private def processEvent(event: JobGeneratorEvent) {
    logDebug("Got event " + event)
    event match {
      case GenerateJobs(time) => generateJobs(time)
      case ClearMetadata(time) => clearMetadata(time)
      case DoCheckpoint(time, clearCheckpointDataLater) =>
        doCheckpoint(time, clearCheckpointDataLater)
      case ClearCheckpointData(time) => clearCheckpointData(time)
    }
  }

For a GenerateJobs event, it calls the generateJobs method:

/** Generate jobs and perform checkpointing for the given `time`.  */
  private def generateJobs(time: Time) {
    // Checkpoint all RDDs marked for checkpointing to ensure their lineages are
    // truncated periodically. Otherwise, we may run into stack overflows (SPARK-6847).
    ssc.sparkContext.setLocalProperty(RDD.CHECKPOINT_ALL_MARKED_ANCESTORS, "true")
    Try {
      jobScheduler.receiverTracker.allocateBlocksToBatch(time) // allocate received blocks to batch
      graph.generateJobs(time) // generate jobs using allocated block
    } match {
      case Success(jobs) =>
        val streamIdToInputInfos = jobScheduler.inputInfoTracker.getInfo(time)
        jobScheduler.submitJobSet(JobSet(time, jobs, streamIdToInputInfos))
      case Failure(e) =>
        jobScheduler.reportError("Error generating jobs for time " + time, e)
        PythonDStream.stopStreamingContextIfPythonProcessIsDead(e)
    }
    eventLoop.post(DoCheckpoint(time, clearCheckpointDataLater = false))
  }

The call graph.generateJobs(time) invokes the generateJobs method of DStreamGraph to generate the jobs.

org.apache.spark.streaming.DStreamGraph

  def generateJobs(time: Time): Seq[Job] = {
    logDebug("Generating jobs for time " + time)
    val jobs = this.synchronized {
      outputStreams.flatMap { outputStream =>
        val jobOption = outputStream.generateJob(time)
        jobOption.foreach(_.setCallSite(outputStream.creationSite))
        jobOption
      }
    }
    logDebug("Generated " + jobs.length + " jobs for time " + time)
    jobs
  }

As this shows, job generation starts from outputStreams and works inward, through the call outputStream.generateJob(time):

  /**
   * Generate a SparkStreaming job for the given time. This is an internal method that
   * should not be called directly. This default implementation creates a job
   * that materializes the corresponding RDD. Subclasses of DStream may override this
   * to generate their own jobs.
   */
  private[streaming] def generateJob(time: Time): Option[Job] = {
    getOrCompute(time) match {
      case Some(rdd) =>
        val jobFunc = () => {
          val emptyFunc = { (iterator: Iterator[T]) => {} }
          context.sparkContext.runJob(rdd, emptyFunc)
        }
        Some(new Job(time, jobFunc))
      case None => None
    }
  }

This ends up calling the getOrCompute method of DStream:

/**
   * Get the RDD corresponding to the given time; either retrieve it from cache
   * or compute-and-cache it.
   */
  private[streaming] final def getOrCompute(time: Time): Option[RDD[T]] = {
    // If RDD was already generated, then retrieve it from HashMap,
    // or else compute the RDD
    generatedRDDs.get(time).orElse {
      // Compute the RDD if time is valid (e.g. correct time in a sliding window)
      // of RDD generation, else generate nothing.
      if (isTimeValid(time)) {

        val rddOption = createRDDWithLocalProperties(time, displayInnerRDDOps = false) {
          // Disable checks for existing output directories in jobs launched by the streaming
          // scheduler, since we may need to write output to an existing directory during checkpoint
          // recovery; see SPARK-4835 for more details. We need to have this call here because
          // compute() might cause Spark jobs to be launched.
          SparkHadoopWriterUtils.disableOutputSpecValidation.withValue(true) {
            compute(time)
          }
        }

        rddOption.foreach { case newRDD =>
          // Register the generated RDD for caching and checkpointing
          if (storageLevel != StorageLevel.NONE) {
            newRDD.persist(storageLevel)
            logDebug(s"Persisting RDD ${newRDD.id} for time $time to $storageLevel")
          }
          if (checkpointDuration != null && (time - zeroTime).isMultipleOf(checkpointDuration)) {
            newRDD.checkpoint()
            logInfo(s"Marking RDD ${newRDD.id} for time $time for checkpointing")
          }
          generatedRDDs.put(time, newRDD)
        }
        rddOption
      } else {
        None
      }
    }
  }

Inside getOrCompute, the RDD is ultimately produced by calling the compute method of the DStream.
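
The get-or-compute pattern used here can be sketched in plain Java. This is a simplified model under stated assumptions, not Spark code: `GetOrComputeSketch` is a hypothetical class, batches are plain strings instead of RDDs, and persistence and checkpointing are left out. It keeps the two checks visible in the source: return the cached result for a time if one exists, and otherwise compute only when the time is valid, meaning later than the zero time and a whole multiple of the slide interval away from it.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import java.util.function.LongFunction;

// Simplified model of the get-or-compute pattern; batches are strings, not RDDs.
class GetOrComputeSketch {
    private final long zeroTimeMs;
    private final long slideMs;
    private final LongFunction<Optional<String>> compute; // stands in for compute(time)
    private final Map<Long, String> generated = new HashMap<>();
    int computeCalls = 0; // exposed only so the demo can show the caching effect

    GetOrComputeSketch(long zeroTimeMs, long slideMs,
                       LongFunction<Optional<String>> compute) {
        this.zeroTimeMs = zeroTimeMs;
        this.slideMs = slideMs;
        this.compute = compute;
    }

    // A time is valid if it is after zero time and on a slide boundary.
    boolean isTimeValid(long timeMs) {
        return timeMs > zeroTimeMs && (timeMs - zeroTimeMs) % slideMs == 0;
    }

    Optional<String> getOrCompute(long timeMs) {
        String cached = generated.get(timeMs);
        if (cached != null) {
            return Optional.of(cached); // already generated for this time
        }
        if (!isTimeValid(timeMs)) {
            return Optional.empty(); // generate nothing for invalid times
        }
        computeCalls++;
        Optional<String> result = compute.apply(timeMs);
        result.ifPresent(batch -> generated.put(timeMs, batch));
        return result;
    }

    public static void main(String[] args) {
        GetOrComputeSketch stream =
            new GetOrComputeSketch(0L, 1000L, t -> Optional.of("batch@" + t));
        System.out.println(stream.getOrCompute(1000L));
        System.out.println(stream.getOrCompute(1000L));
        System.out.println(stream.getOrCompute(1500L));
        System.out.println("compute calls: " + stream.computeCalls);
    }
}
```

Asking twice for the same time triggers a single computation, which is why several output streams that share an upstream DStream do not recompute its batch.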

So we go back to the compute method of DirectKafkaInputDStream:

  override def compute(validTime: Time): Option[KafkaRDD[K, V]] = {
    val untilOffsets = clamp(latestOffsets())
    val offsetRanges = untilOffsets.map { case (tp, uo) =>
      val fo = currentOffsets(tp)
      OffsetRange(tp.topic, tp.partition, fo, uo)
    }
    val useConsumerCache = context.conf.getBoolean("spark.streaming.kafka.consumer.cache.enabled",
      true)
    val rdd = new KafkaRDD[K, V](context.sparkContext, executorKafkaParams, offsetRanges.toArray,
      getPreferredHosts, useConsumerCache)

    // Report the record number and metadata of this batch interval to InputInfoTracker.
    val description = offsetRanges.filter { offsetRange =>
      // Don't display empty ranges.
      offsetRange.fromOffset != offsetRange.untilOffset
    }.map { offsetRange =>
      s"topic: ${offsetRange.topic}\tpartition: ${offsetRange.partition}\t" +
        s"offsets: ${offsetRange.fromOffset} to ${offsetRange.untilOffset}"
    }.mkString("\n")
    // Copy offsetRanges to immutable.List to prevent from being modified by the user
    val metadata = Map(
      "offsets" -> offsetRanges.toList,
      StreamInputInfo.METADATA_KEY_DESCRIPTION -> description)
    val inputInfo = StreamInputInfo(id, rdd.count, metadata)
    ssc.scheduler.inputInfoTracker.reportInfo(validTime, inputInfo)

    currentOffsets = untilOffsets
    commitAll()
    Some(rdd)
  }
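
The offset arithmetic in this compute method can be illustrated with a small plain-Java sketch. It is a simplified model and not Spark's implementation: `OffsetRangeSketch` and its methods are hypothetical, partitions are identified by plain strings, every partition is assumed to already have a current offset, and the real clamp logic also takes backpressure and the minimum rate into account. The sketch only shows the core idea: each batch covers the range from the current offset up to the latest offset, capped by the per-partition rate limit multiplied by the batch duration.

```java
import java.util.LinkedHashMap;
import java.util.Map;

// Simplified model of how one batch's offset ranges are derived; not Spark code.
class OffsetRangeSketch {
    // Upper bound on the records one partition may contribute to one batch.
    static long maxMessagesPerPartition(long maxRatePerSecond, long batchMillis) {
        return maxRatePerSecond * batchMillis / 1000;
    }

    // For each partition: from = current offset, until = latest offset capped by the limit.
    // Assumes currentOffsets contains every key of latestOffsets.
    static Map<String, long[]> offsetRanges(Map<String, Long> currentOffsets,
                                            Map<String, Long> latestOffsets,
                                            long maxMessages) {
        Map<String, long[]> ranges = new LinkedHashMap<>();
        for (Map.Entry<String, Long> entry : latestOffsets.entrySet()) {
            long from = currentOffsets.get(entry.getKey());
            long until = Math.min(from + maxMessages, entry.getValue());
            ranges.put(entry.getKey(), new long[] {from, until});
        }
        return ranges;
    }

    public static void main(String[] args) {
        Map<String, Long> current = new LinkedHashMap<>();
        current.put("topic-0", 100L);
        current.put("topic-1", 50L);
        Map<String, Long> latest = new LinkedHashMap<>();
        latest.put("topic-0", 5000L);
        latest.put("topic-1", 60L);

        long limit = maxMessagesPerPartition(100L, 5000L);
        offsetRanges(current, latest, limit).forEach((tp, range) ->
            System.out.println(tp + ": " + range[0] + " -> " + range[1]));
    }
}
```

A partition with a large backlog is cut off at the limit, while a partition with little new data simply ends at its latest offset; the next batch then starts from where this one ended.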

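This is where the KafkaRDD is created. The rdd.count call above does not contradict the lazy behaviour: for non-compacted topics, KafkaRDD overrides count() to sum the sizes of the offset ranges, so no records are fetched. Combined with the earlier analysis of RDD generation, the place where the data is finally loaded is the compute method of KafkaRDD: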

override def compute(thePart: Partition, context: TaskContext): Iterator[ConsumerRecord[K, V]] = {
    val part = thePart.asInstanceOf[KafkaRDDPartition]
    require(part.fromOffset <= part.untilOffset, errBeginAfterEnd(part))
    if (part.fromOffset == part.untilOffset) {
      logInfo(s"Beginning offset ${part.fromOffset} is the same as ending offset " +
        s"skipping ${part.topic} ${part.partition}")
      Iterator.empty
    } else {
      logInfo(s"Computing topic ${part.topic}, partition ${part.partition} " +
        s"offsets ${part.fromOffset} -> ${part.untilOffset}")
      if (compacted) {
        new CompactedKafkaRDDIterator[K, V](
          part,
          context,
          kafkaParams,
          useConsumerCache,
          pollTimeout,
          cacheInitialCapacity,
          cacheMaxCapacity,
          cacheLoadFactor
        )
      } else {
        new KafkaRDDIterator[K, V](
          part,
          context,
          kafkaParams,
          useConsumerCache,
          pollTimeout,
          cacheInitialCapacity,
          cacheMaxCapacity,
          cacheLoadFactor
        )
      }
    }
  }

Finally, the data is fetched from Kafka inside KafkaRDDIterator or CompactedKafkaRDDIterator.
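
What such an iterator does can be sketched in plain Java. This is a simplified model, not the Spark class: `OffsetRangeIteratorSketch` is a hypothetical name and an in-memory list stands in for one Kafka topic partition, whereas the real iterators fetch records through a Kafka consumer with a poll timeout. The sketch shows that records are fetched one offset at a time, only as the task consumes the iterator, and only within the range from the start offset (inclusive) to the end offset (exclusive).

```java
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;

// Simplified model of an iterator over one partition's offset range; not Spark code.
class OffsetRangeIteratorSketch implements Iterator<String> {
    private final List<String> log;  // stands in for one Kafka topic partition
    private final long untilOffset;  // exclusive end of the range
    private long requestOffset;      // next offset to fetch
    int fetches = 0;                 // exposed only so the demo can count reads

    OffsetRangeIteratorSketch(List<String> log, long fromOffset, long untilOffset) {
        this.log = log;
        this.requestOffset = fromOffset;
        this.untilOffset = untilOffset;
    }

    @Override
    public boolean hasNext() {
        return requestOffset < untilOffset;
    }

    @Override
    public String next() {
        if (!hasNext()) {
            throw new NoSuchElementException("offset range exhausted");
        }
        fetches++;
        String record = log.get((int) requestOffset); // stands in for a consumer fetch
        requestOffset++;
        return record;
    }

    public static void main(String[] args) {
        List<String> log = Arrays.asList("r0", "r1", "r2", "r3", "r4");
        OffsetRangeIteratorSketch it = new OffsetRangeIteratorSketch(log, 1L, 4L);
        while (it.hasNext()) {
            System.out.println(it.next());
        }
        System.out.println("fetches: " + it.fetches);
    }
}
```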

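The whole flow, in summary:

StreamingContext.start -> JobScheduler.start -> JobGenerator.start -> the RecurringTimer posts a GenerateJobs event every batch interval -> JobGenerator.generateJobs -> DStreamGraph.generateJobs -> DStream.generateJob -> DStream.getOrCompute -> DirectKafkaInputDStream.compute (builds the KafkaRDD) -> KafkaRDD.compute -> KafkaRDDIterator or CompactedKafkaRDDIterator fetches the records from Kafka.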
