Spark Data Input -- DStream Generation
In an earlier article, Spark Data Input -- RDD Generation, we noted that Spark has three main kinds of data input. Here we look at the second kind: loading streaming data to produce a DStream.
We take Spark reading from Kafka, a common production setup, as the example. First, a snippet of user code:
Map<String, Object> kafkaParameters = new HashMap<>();
kafkaParameters.put("bootstrap.servers", bootstrapServers);
kafkaParameters.put("group.id", group);
kafkaParameters.put("auto.offset.reset", reset);
kafkaParameters.put("key.deserializer", "org.apache.kafka.common.serialization.StringDeserializer");
kafkaParameters.put("value.deserializer", "org.apache.kafka.common.serialization.StringDeserializer");
kafkaParameters.put("enable.auto.commit", "" + autoCommit);
// Project-specific helper (not part of Spark) that merges additional/dynamic settings into the Kafka config.
EngineParamUtils.setDynamicParamByKafka(kafkaParameters, dynamicParam);
Collection<String> topicsSet = Arrays.asList(topics.split(","));
JavaInputDStream<ConsumerRecord<String, String>> directStream = KafkaUtils.createDirectStream(
        jsc,
        LocationStrategies.PreferConsistent(),
        ConsumerStrategies.Subscribe(topicsSet, kafkaParameters)
);
/**
 * :: Experimental ::
 * Scala constructor for a DStream where
 * each given Kafka topic/partition corresponds to an RDD partition.
 * The spark configuration spark.streaming.kafka.maxRatePerPartition gives the maximum number
 * of messages per second that each '''partition''' will accept.
 * @param locationStrategy In most cases, pass in [[LocationStrategies.PreferConsistent]],
 *   see [[LocationStrategies]] for more details.
 * @param consumerStrategy In most cases, pass in [[ConsumerStrategies.Subscribe]],
 *   see [[ConsumerStrategies]] for more details
 * @tparam K type of Kafka message key
 * @tparam V type of Kafka message value
 */
@Experimental
def createDirectStream[K, V](
    ssc: StreamingContext,
    locationStrategy: LocationStrategy,
    consumerStrategy: ConsumerStrategy[K, V]
  ): InputDStream[ConsumerRecord[K, V]] = {
  val ppc = new DefaultPerPartitionConfig(ssc.sparkContext.getConf)
  createDirectStream[K, V](ssc, locationStrategy, consumerStrategy, ppc)
}
/**
 * :: Experimental ::
 * Scala constructor for a DStream where
 * each given Kafka topic/partition corresponds to an RDD partition.
 * @param locationStrategy In most cases, pass in [[LocationStrategies.PreferConsistent]],
 *   see [[LocationStrategies]] for more details.
 * @param consumerStrategy In most cases, pass in [[ConsumerStrategies.Subscribe]],
 *   see [[ConsumerStrategies]] for more details.
 * @param perPartitionConfig configuration of settings such as max rate on a per-partition basis.
 *   see [[PerPartitionConfig]] for more details.
 * @tparam K type of Kafka message key
 * @tparam V type of Kafka message value
 */
@Experimental
def createDirectStream[K, V](
    ssc: StreamingContext,
    locationStrategy: LocationStrategy,
    consumerStrategy: ConsumerStrategy[K, V],
    perPartitionConfig: PerPartitionConfig
  ): InputDStream[ConsumerRecord[K, V]] = {
  new DirectKafkaInputDStream[K, V](ssc, locationStrategy, consumerStrategy, perPartitionConfig)
}
The key call is KafkaUtils.createDirectStream, which creates the DStream. Let's first go through its parameters.
(1) ssc is the StreamingContext, Spark Streaming's context class; it offers a number of methods for creating DStreams as well as the related transformations.
(2) locationStrategy is a LocationStrategy, which decides where Spark Streaming schedules the consumption of Kafka topic partitions: 1) PreferBrokers prefers the brokers themselves, which can cut network transfer when the executors run on the same machines as the Kafka brokers; 2) PreferConsistent distributes the partitions evenly and consistently across executors; 3) PreferFixed pins specified topic partitions to fixed hosts.
(3) consumerStrategy is a ConsumerStrategy, which determines how topics are consumed: Subscribe to a given list of topics (one or more), SubscribePattern to subscribe to every topic matching a pattern, or Assign to consume an explicit list of topic partitions.
(4) perPartitionConfig is a PerPartitionConfig, holding per-partition settings, chiefly spark.streaming.kafka.maxRatePerPartition (the maximum per-partition consumption rate) and spark.streaming.kafka.minRatePerPartition (the minimum per-partition consumption rate).
All four parameters are used to construct the DirectKafkaInputDStream object and are kept as its fields.
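As an aside, here is a minimal Scala sketch of the alternative strategies and of the per-partition rate setting described above; the broker address, group id, topic names, and host name are made-up placeholders rather than values from the original example:

import java.{util => ju}
import org.apache.kafka.common.TopicPartition
import org.apache.spark.SparkConf
import org.apache.spark.streaming.kafka010.{ConsumerStrategies, LocationStrategies}

// Read by DefaultPerPartitionConfig: cap each partition at 1000 records per second.
val conf = new SparkConf().set("spark.streaming.kafka.maxRatePerPartition", "1000")

val kafkaParams = Map[String, Object](
  "bootstrap.servers" -> "broker1:9092",
  "group.id" -> "demo-group",
  "key.deserializer" -> "org.apache.kafka.common.serialization.StringDeserializer",
  "value.deserializer" -> "org.apache.kafka.common.serialization.StringDeserializer")

// Subscribe: a fixed list of topic names (what the example above uses).
val subscribe = ConsumerStrategies.Subscribe[String, String](Seq("topicA", "topicB"), kafkaParams)

// SubscribePattern: every topic whose name matches a regex.
val byPattern = ConsumerStrategies.SubscribePattern[String, String](
  ju.regex.Pattern.compile("log-.*"), kafkaParams)

// Assign: an explicit list of topic partitions, bypassing consumer-group rebalancing.
val byAssign = ConsumerStrategies.Assign[String, String](
  Seq(new TopicPartition("topicA", 0), new TopicPartition("topicA", 1)), kafkaParams)

// PreferFixed: pin the given partitions to specific executor hosts.
val preferFixed = LocationStrategies.PreferFixed(
  Map(new TopicPartition("topicA", 0) -> "executor-host-1"))

PreferBrokers takes no arguments and is only worth using when the executors actually run on the Kafka broker machines.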
As the KafkaUtils source shows, createDirectStream really only builds the DStream object. Like loading data to create an RDD, this is lazy: nothing here actually triggers reading data from Kafka.
So when does the read of the Kafka topic data actually happen? Let's look at the DirectKafkaInputDStream class.
/**
 * A DStream where
 * each given Kafka topic/partition corresponds to an RDD partition.
 * The spark configuration spark.streaming.kafka.maxRatePerPartition gives the maximum number
 * of messages per second that each '''partition''' will accept.
 * @param locationStrategy In most cases, pass in [[LocationStrategies.PreferConsistent]],
 *   see [[LocationStrategy]] for more details.
 * @param consumerStrategy In most cases, pass in [[ConsumerStrategies.Subscribe]],
 *   see [[ConsumerStrategy]] for more details
 * @param ppc configuration of settings such as max rate on a per-partition basis.
 *   see [[PerPartitionConfig]] for more details.
 * @tparam K type of Kafka message key
 * @tparam V type of Kafka message value
 */
private[spark] class DirectKafkaInputDStream[K, V](
    _ssc: StreamingContext,
    locationStrategy: LocationStrategy,
    consumerStrategy: ConsumerStrategy[K, V],
    ppc: PerPartitionConfig
  ) extends InputDStream[ConsumerRecord[K, V]](_ssc) with Logging with CanCommitOffsets
As the class comment says, this class maps each partition of each given Kafka topic directly to one RDD partition.
The class extends InputDStream, so let's look at InputDStream first.
InputDStream is the abstract base class of all input streams. It defines two methods that subclasses must implement, start and stop; according to the source comments, start is where reading data begins and stop is where it stops. It seems the real reading should happen in start, so let's return to DirectKafkaInputDStream's start method.
override def start(): Unit = {
  val c = consumer
  paranoidPoll(c)
  if (currentOffsets.isEmpty) {
    currentOffsets = c.assignment().asScala.map { tp =>
      tp -> c.position(tp)
    }.toMap
  }
}

override def stop(): Unit = this.synchronized {
  if (kc != null) {
    kc.close()
  }
}
As we can see, DirectKafkaInputDStream's start method only does an initial fetch of the consumer group's current offset positions for the Kafka topic, and its stop method closes the Kafka consumer and releases its resources. Tracing the callers shows that both methods are invoked from the DStreamGraph class, once each, when the graph itself starts and stops.
org.apache.spark.streaming.DStreamGraph
def start(time: Time) {
  this.synchronized {
    require(zeroTime == null, "DStream graph computation already started")
    zeroTime = time
    startTime = time
    outputStreams.foreach(_.initialize(zeroTime))
    outputStreams.foreach(_.remember(rememberDuration))
    outputStreams.foreach(_.validateAtStart())
    numReceivers = inputStreams.count(_.isInstanceOf[ReceiverInputDStream[_]])
    inputStreamNameAndID = inputStreams.map(is => (is.name, is.id))
    inputStreams.par.foreach(_.start())
  }
}
So start is only used to initialize a few things at startup; for DirectKafkaInputDStream that means fetching the offsets where the consumer group last stopped. stop is used to release resources.
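A related aside: if you want the stream to start from offsets you control rather than from the group's committed position, the Subscribe strategy also accepts an explicit offsets map, and those positions are what end up seeding currentOffsets. A small hedged sketch, with placeholder broker, group, topic, and offset values:

import org.apache.kafka.common.TopicPartition
import org.apache.spark.streaming.kafka010.ConsumerStrategies

val kafkaParams = Map[String, Object](
  "bootstrap.servers" -> "broker1:9092",
  "group.id" -> "demo-group",
  "key.deserializer" -> "org.apache.kafka.common.serialization.StringDeserializer",
  "value.deserializer" -> "org.apache.kafka.common.serialization.StringDeserializer")

// Begin topicA partition 0 at offset 42 instead of the committed group offset.
val fromOffsets = Map(new TopicPartition("topicA", 0) -> java.lang.Long.valueOf(42L))
val strategy = ConsumerStrategies.Subscribe[String, String](Seq("topicA"), kafkaParams, fromOffsets)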
Let's instead start from the entry point of a streaming job: StreamingContext's start method.
/**
 * Start the execution of the streams.
 *
 * @throws IllegalStateException if the StreamingContext is already stopped.
 */
def start(): Unit = synchronized {
  state match {
    case INITIALIZED =>
      startSite.set(DStream.getCreationSite())
      StreamingContext.ACTIVATION_LOCK.synchronized {
        StreamingContext.assertNoOtherContextIsActive()
        try {
          validate()
          // Start the streaming scheduler in a new thread, so that thread local properties
          // like call sites and job groups can be reset without affecting those of the
          // current thread.
          ThreadUtils.runInNewThread("streaming-start") {
            sparkContext.setCallSite(startSite.get)
            sparkContext.clearJobGroup()
            sparkContext.setLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL, "false")
            savedProperties.set(SerializationUtils.clone(sparkContext.localProperties.get()))
            scheduler.start()
          }
          state = StreamingContextState.ACTIVE
          scheduler.listenerBus.post(
            StreamingListenerStreamingStarted(System.currentTimeMillis()))
        } catch {
          case NonFatal(e) =>
            logError("Error starting the context, marking it as stopped", e)
            scheduler.stop(false)
            state = StreamingContextState.STOPPED
            throw e
        }
        StreamingContext.setActiveContext(this)
      }
      logDebug("Adding shutdown hook") // force eager creation of logger
      shutdownHookRef = ShutdownHookManager.addShutdownHook(
        StreamingContext.SHUTDOWN_HOOK_PRIORITY)(() => stopOnShutdown())
      // Registering Streaming Metrics at the start of the StreamingContext
      assert(env.metricsSystem != null)
      env.metricsSystem.registerSource(streamingSource)
      uiTab.foreach(_.attach())
      logInfo("StreamingContext started")
    case ACTIVE =>
      logWarning("StreamingContext has already been started")
    case STOPPED =>
      throw new IllegalStateException("StreamingContext has already been stopped")
  }
}
The key line in there is scheduler.start(). Tracing the call chain: JobScheduler.start() starts the JobGenerator, and JobGenerator.start() (on a fresh start, via startFirstTime()) calls graph.start(startTime - graph.batchDuration) followed by timer.start(startTime.milliseconds). The graph.start(...) call is the DStreamGraph.start method analyzed above, which in turn calls InputDStream.start. The real driver is timer.start(startTime.milliseconds), so let's first see what this timer object is.
private val timer = new RecurringTimer(clock, ssc.graph.batchDuration.milliseconds,
  longTime => eventLoop.post(GenerateJobs(new Time(longTime))), "JobGenerator")
This is a recurring timer: it invokes the supplied callback function once per configured interval.
private[streaming]
class RecurringTimer(clock: Clock, period: Long, callback: (Long) => Unit, name: String)
  extends Logging
Here the callback is longTime => eventLoop.post(GenerateJobs(new Time(longTime))): once per batch interval, it posts a GenerateJobs event to the eventLoop event-processing loop.
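Conceptually (a simplified sketch, not Spark's actual RecurringTimer implementation), the timer just loops like this:

// Fire the callback once per period; in JobGenerator the callback posts
// GenerateJobs(new Time(nextTime)) to the event loop.
def recurringLoop(startTime: Long, period: Long, callback: Long => Unit): Unit = {
  var nextTime = startTime
  while (!Thread.currentThread().isInterrupted) {
    val sleepMs = nextTime - System.currentTimeMillis()
    if (sleepMs > 0) Thread.sleep(sleepMs)
    callback(nextTime)
    nextTime += period
  }
}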
Back in JobGenerator's start method, an EventLoop object is initialized:
eventLoop = new EventLoop[JobGeneratorEvent]("JobGenerator") {
  override protected def onReceive(event: JobGeneratorEvent): Unit = processEvent(event)

  override protected def onError(e: Throwable): Unit = {
    jobScheduler.reportError("Error in job generator", e)
  }
}
The part to look at is the processEvent method:
/** Processes all events */
private def processEvent(event: JobGeneratorEvent) {
  logDebug("Got event " + event)
  event match {
    case GenerateJobs(time) => generateJobs(time)
    case ClearMetadata(time) => clearMetadata(time)
    case DoCheckpoint(time, clearCheckpointDataLater) =>
      doCheckpoint(time, clearCheckpointDataLater)
    case ClearCheckpointData(time) => clearCheckpointData(time)
  }
}
For a GenerateJobs event, it calls the generateJobs method:
/** Generate jobs and perform checkpointing for the given `time`. */
private def generateJobs(time: Time) {
  // Checkpoint all RDDs marked for checkpointing to ensure their lineages are
  // truncated periodically. Otherwise, we may run into stack overflows (SPARK-6847).
  ssc.sparkContext.setLocalProperty(RDD.CHECKPOINT_ALL_MARKED_ANCESTORS, "true")
  Try {
    jobScheduler.receiverTracker.allocateBlocksToBatch(time) // allocate received blocks to batch
    graph.generateJobs(time) // generate jobs using allocated block
  } match {
    case Success(jobs) =>
      val streamIdToInputInfos = jobScheduler.inputInfoTracker.getInfo(time)
      jobScheduler.submitJobSet(JobSet(time, jobs, streamIdToInputInfos))
    case Failure(e) =>
      jobScheduler.reportError("Error generating jobs for time " + time, e)
      PythonDStream.stopStreamingContextIfPythonProcessIsDead(e)
  }
  eventLoop.post(DoCheckpoint(time, clearCheckpointDataLater = false))
}
The graph.generateJobs(time) call there invokes DStreamGraph's generateJobs method to generate the jobs.
org.apache.spark.streaming.DStreamGraph
def generateJobs(time: Time): Seq[Job] = {
  logDebug("Generating jobs for time " + time)
  val jobs = this.synchronized {
    outputStreams.flatMap { outputStream =>
      val jobOption = outputStream.generateJob(time)
      jobOption.foreach(_.setCallSite(outputStream.creationSite))
      jobOption
    }
  }
  logDebug("Generated " + jobs.length + " jobs for time " + time)
  jobs
}
As this shows, job generation is driven by the outputStreams: each output stream is asked to produce a job for the batch via outputStream.generateJob(time).
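A quick aside on where outputStreams come from: they are registered when an output operation is called on a DStream. For example, foreachRDD wraps the stream in a ForEachDStream and registers it with the DStreamGraph; with no output operation at all, startup fails with "No output operations registered, so nothing to execute". A minimal sketch, assuming stream is the DStream created by createDirectStream earlier:

// foreachRDD registers a ForEachDStream as an output stream on the DStreamGraph.
stream.foreachRDD { rdd =>
  rdd.foreach(record => println(record.value()))
}

With that aside done, here is generateJob, whose default implementation in DStream follows: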
/**
 * Generate a SparkStreaming job for the given time. This is an internal method that
 * should not be called directly. This default implementation creates a job
 * that materializes the corresponding RDD. Subclasses of DStream may override this
 * to generate their own jobs.
 */
private[streaming] def generateJob(time: Time): Option[Job] = {
  getOrCompute(time) match {
    case Some(rdd) =>
      val jobFunc = () => {
        val emptyFunc = { (iterator: Iterator[T]) => {} }
        context.sparkContext.runJob(rdd, emptyFunc)
      }
      Some(new Job(time, jobFunc))
    case None => None
  }
}
generateJob ultimately calls DStream's getOrCompute method:
/**
 * Get the RDD corresponding to the given time; either retrieve it from cache
 * or compute-and-cache it.
 */
private[streaming] final def getOrCompute(time: Time): Option[RDD[T]] = {
  // If RDD was already generated, then retrieve it from HashMap,
  // or else compute the RDD
  generatedRDDs.get(time).orElse {
    // Compute the RDD if time is valid (e.g. correct time in a sliding window)
    // of RDD generation, else generate nothing.
    if (isTimeValid(time)) {
      val rddOption = createRDDWithLocalProperties(time, displayInnerRDDOps = false) {
        // Disable checks for existing output directories in jobs launched by the streaming
        // scheduler, since we may need to write output to an existing directory during checkpoint
        // recovery; see SPARK-4835 for more details. We need to have this call here because
        // compute() might cause Spark jobs to be launched.
        SparkHadoopWriterUtils.disableOutputSpecValidation.withValue(true) {
          compute(time)
        }
      }

      rddOption.foreach { case newRDD =>
        // Register the generated RDD for caching and checkpointing
        if (storageLevel != StorageLevel.NONE) {
          newRDD.persist(storageLevel)
          logDebug(s"Persisting RDD ${newRDD.id} for time $time to $storageLevel")
        }
        if (checkpointDuration != null && (time - zeroTime).isMultipleOf(checkpointDuration)) {
          newRDD.checkpoint()
          logInfo(s"Marking RDD ${newRDD.id} for time $time for checkpointing")
        }
        generatedRDDs.put(time, newRDD)
      }
      rddOption
    } else {
      None
    }
  }
}
In getOrCompute we find that, in the end, it is the DStream's compute method that produces the RDD.
So back to DirectKafkaInputDStream's compute method:
override def compute(validTime: Time): Option[KafkaRDD[K, V]] = {
  val untilOffsets = clamp(latestOffsets())
  val offsetRanges = untilOffsets.map { case (tp, uo) =>
    val fo = currentOffsets(tp)
    OffsetRange(tp.topic, tp.partition, fo, uo)
  }
  val useConsumerCache = context.conf.getBoolean("spark.streaming.kafka.consumer.cache.enabled",
    true)
  val rdd = new KafkaRDD[K, V](context.sparkContext, executorKafkaParams, offsetRanges.toArray,
    getPreferredHosts, useConsumerCache)

  // Report the record number and metadata of this batch interval to InputInfoTracker.
  val description = offsetRanges.filter { offsetRange =>
    // Don't display empty ranges.
    offsetRange.fromOffset != offsetRange.untilOffset
  }.map { offsetRange =>
    s"topic: ${offsetRange.topic}\tpartition: ${offsetRange.partition}\t" +
      s"offsets: ${offsetRange.fromOffset} to ${offsetRange.untilOffset}"
  }.mkString("\n")

  // Copy offsetRanges to immutable.List to prevent from being modified by the user
  val metadata = Map(
    "offsets" -> offsetRanges.toList,
    StreamInputInfo.METADATA_KEY_DESCRIPTION -> description)
  val inputInfo = StreamInputInfo(id, rdd.count, metadata)
  ssc.scheduler.inputInfoTracker.reportInfo(validTime, inputInfo)

  currentOffsets = untilOffsets
  commitAll()
  Some(rdd)
}
This is where the KafkaRDD is created. Combined with our earlier analysis of RDD generation, the place where the data is finally loaded is KafkaRDD's compute method:
override def compute(thePart: Partition, context: TaskContext): Iterator[ConsumerRecord[K, V]] = {
  val part = thePart.asInstanceOf[KafkaRDDPartition]
  require(part.fromOffset <= part.untilOffset, errBeginAfterEnd(part))
  if (part.fromOffset == part.untilOffset) {
    logInfo(s"Beginning offset ${part.fromOffset} is the same as ending offset " +
      s"skipping ${part.topic} ${part.partition}")
    Iterator.empty
  } else {
    logInfo(s"Computing topic ${part.topic}, partition ${part.partition} " +
      s"offsets ${part.fromOffset} -> ${part.untilOffset}")
    if (compacted) {
      new CompactedKafkaRDDIterator[K, V](
        part,
        context,
        kafkaParams,
        useConsumerCache,
        pollTimeout,
        cacheInitialCapacity,
        cacheMaxCapacity,
        cacheLoadFactor
      )
    } else {
      new KafkaRDDIterator[K, V](
        part,
        context,
        kafkaParams,
        useConsumerCache,
        pollTimeout,
        cacheInitialCapacity,
        cacheMaxCapacity,
        cacheLoadFactor
      )
    }
  }
}
In the end it is KafkaRDDIterator (or CompactedKafkaRDDIterator, for compacted topics) that actually pulls the records from Kafka.
To sum up, the whole flow is: StreamingContext.start() -> JobScheduler.start() -> JobGenerator.start() -> the RecurringTimer fires a GenerateJobs event once per batch interval -> JobGenerator.generateJobs -> DStreamGraph.generateJobs -> outputStream.generateJob -> DStream.getOrCompute -> DirectKafkaInputDStream.compute builds a KafkaRDD -> when the job runs, KafkaRDD.compute reads the records from Kafka through KafkaRDDIterator.
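To close, a hedged end-to-end sketch from the user's side, written in Scala; the broker address, group id, and topic name are placeholders, and committing offsets through CanCommitOffsets.commitAsync is one common pattern (those queued offsets are what the commitAll() call in compute() flushes):

import org.apache.kafka.clients.consumer.ConsumerRecord
import org.apache.spark.SparkConf
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.apache.spark.streaming.kafka010._

object KafkaDStreamDemo {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("kafka-dstream-demo")
    val ssc = new StreamingContext(conf, Seconds(5))

    val kafkaParams = Map[String, Object](
      "bootstrap.servers" -> "broker1:9092",
      "group.id" -> "demo-group",
      "auto.offset.reset" -> "latest",
      "key.deserializer" -> "org.apache.kafka.common.serialization.StringDeserializer",
      "value.deserializer" -> "org.apache.kafka.common.serialization.StringDeserializer",
      "enable.auto.commit" -> (false: java.lang.Boolean))

    // Lazily builds the DirectKafkaInputDStream; nothing is read from Kafka yet.
    val stream = KafkaUtils.createDirectStream[String, String](
      ssc,
      LocationStrategies.PreferConsistent,
      ConsumerStrategies.Subscribe[String, String](Seq("topicA"), kafkaParams))

    // Registers a ForEachDStream output; each batch turns into a job over a KafkaRDD.
    stream.foreachRDD { rdd =>
      val offsetRanges = rdd.asInstanceOf[HasOffsetRanges].offsetRanges
      rdd.foreach((record: ConsumerRecord[String, String]) => println(record.value()))
      // Queue this batch's offsets; DirectKafkaInputDStream commits them via commitAll().
      stream.asInstanceOf[CanCommitOffsets].commitAsync(offsetRanges)
    }

    // Starts JobScheduler/JobGenerator; the RecurringTimer now generates one job per batch.
    ssc.start()
    ssc.awaitTermination()
  }
}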