org.apache.spark.network.client.TransportResponseHandler.java Source code

Java tutorial

Introduction

Here is the source code for org.apache.spark.network.client.TransportResponseHandler.java

Source

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.network.client;

import java.io.IOException;
import java.util.Map;
import java.util.Queue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.atomic.AtomicLong;

import com.google.common.annotations.VisibleForTesting;
import io.netty.channel.Channel;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.apache.spark.network.protocol.ChunkFetchFailure;
import org.apache.spark.network.protocol.ChunkFetchSuccess;
import org.apache.spark.network.protocol.ResponseMessage;
import org.apache.spark.network.protocol.RpcFailure;
import org.apache.spark.network.protocol.RpcResponse;
import org.apache.spark.network.protocol.StreamChunkId;
import org.apache.spark.network.protocol.StreamFailure;
import org.apache.spark.network.protocol.StreamResponse;
import org.apache.spark.network.server.MessageHandler;
import static org.apache.spark.network.util.NettyUtils.getRemoteAddress;
import org.apache.spark.network.util.TransportFrameDecoder;

/**
 * Handler that processes server responses, in response to requests issued from a
 * [[TransportClient]]. It works by tracking the list of outstanding requests (and their callbacks).
 *
 * Concurrency: thread safe and can be called from multiple threads.
 */
public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
    private static final Logger logger = LoggerFactory.getLogger(TransportResponseHandler.class);

    private final Channel channel;

    private final Map<StreamChunkId, ChunkReceivedCallback> outstandingFetches;

    private final Map<Long, RpcResponseCallback> outstandingRpcs;

    private final Queue<Pair<String, StreamCallback>> streamCallbacks;
    private volatile boolean streamActive;

    /** Records the time (in system nanoseconds) that the last fetch or RPC request was sent. */
    private final AtomicLong timeOfLastRequestNs;

    public TransportResponseHandler(Channel channel) {
        this.channel = channel;
        this.outstandingFetches = new ConcurrentHashMap<>();
        this.outstandingRpcs = new ConcurrentHashMap<>();
        this.streamCallbacks = new ConcurrentLinkedQueue<>();
        this.timeOfLastRequestNs = new AtomicLong(0);
    }

    public void addFetchRequest(StreamChunkId streamChunkId, ChunkReceivedCallback callback) {
        updateTimeOfLastRequest();
        outstandingFetches.put(streamChunkId, callback);
    }

    public void removeFetchRequest(StreamChunkId streamChunkId) {
        outstandingFetches.remove(streamChunkId);
    }

    public void addRpcRequest(long requestId, RpcResponseCallback callback) {
        updateTimeOfLastRequest();
        outstandingRpcs.put(requestId, callback);
    }

    public void removeRpcRequest(long requestId) {
        outstandingRpcs.remove(requestId);
    }

    public void addStreamCallback(String streamId, StreamCallback callback) {
        timeOfLastRequestNs.set(System.nanoTime());
        streamCallbacks.offer(ImmutablePair.of(streamId, callback));
    }

    @VisibleForTesting
    public void deactivateStream() {
        streamActive = false;
    }

    /**
     * Fire the failure callback for all outstanding requests. This is called when we have an
     * uncaught exception or pre-mature connection termination.
     */
    private void failOutstandingRequests(Throwable cause) {
        for (Map.Entry<StreamChunkId, ChunkReceivedCallback> entry : outstandingFetches.entrySet()) {
            try {
                entry.getValue().onFailure(entry.getKey().chunkIndex, cause);
            } catch (Exception e) {
                logger.warn("ChunkReceivedCallback.onFailure throws exception", e);
            }
        }
        for (Map.Entry<Long, RpcResponseCallback> entry : outstandingRpcs.entrySet()) {
            try {
                entry.getValue().onFailure(cause);
            } catch (Exception e) {
                logger.warn("RpcResponseCallback.onFailure throws exception", e);
            }
        }
        for (Pair<String, StreamCallback> entry : streamCallbacks) {
            try {
                entry.getValue().onFailure(entry.getKey(), cause);
            } catch (Exception e) {
                logger.warn("StreamCallback.onFailure throws exception", e);
            }
        }

        // It's OK if new fetches appear, as they will fail immediately.
        outstandingFetches.clear();
        outstandingRpcs.clear();
        streamCallbacks.clear();
    }

    @Override
    public void channelActive() {
    }

    @Override
    public void channelInactive() {
        if (numOutstandingRequests() > 0) {
            String remoteAddress = getRemoteAddress(channel);
            logger.error("Still have {} requests outstanding when connection from {} is closed",
                    numOutstandingRequests(), remoteAddress);
            failOutstandingRequests(new IOException("Connection from " + remoteAddress + " closed"));
        }
    }

    @Override
    public void exceptionCaught(Throwable cause) {
        if (numOutstandingRequests() > 0) {
            String remoteAddress = getRemoteAddress(channel);
            logger.error("Still have {} requests outstanding when connection from {} is closed",
                    numOutstandingRequests(), remoteAddress);
            failOutstandingRequests(cause);
        }
    }

    @Override
    public void handle(ResponseMessage message) throws Exception {
        if (message instanceof ChunkFetchSuccess) {
            ChunkFetchSuccess resp = (ChunkFetchSuccess) message;
            ChunkReceivedCallback listener = outstandingFetches.get(resp.streamChunkId);
            if (listener == null) {
                logger.warn("Ignoring response for block {} from {} since it is not outstanding",
                        resp.streamChunkId, getRemoteAddress(channel));
                resp.body().release();
            } else {
                outstandingFetches.remove(resp.streamChunkId);
                listener.onSuccess(resp.streamChunkId.chunkIndex, resp.body());
                resp.body().release();
            }
        } else if (message instanceof ChunkFetchFailure) {
            ChunkFetchFailure resp = (ChunkFetchFailure) message;
            ChunkReceivedCallback listener = outstandingFetches.get(resp.streamChunkId);
            if (listener == null) {
                logger.warn("Ignoring response for block {} from {} ({}) since it is not outstanding",
                        resp.streamChunkId, getRemoteAddress(channel), resp.errorString);
            } else {
                outstandingFetches.remove(resp.streamChunkId);
                listener.onFailure(resp.streamChunkId.chunkIndex, new ChunkFetchFailureException(
                        "Failure while fetching " + resp.streamChunkId + ": " + resp.errorString));
            }
        } else if (message instanceof RpcResponse) {
            RpcResponse resp = (RpcResponse) message;
            RpcResponseCallback listener = outstandingRpcs.get(resp.requestId);
            if (listener == null) {
                logger.warn("Ignoring response for RPC {} from {} ({} bytes) since it is not outstanding",
                        resp.requestId, getRemoteAddress(channel), resp.body().size());
            } else {
                outstandingRpcs.remove(resp.requestId);
                try {
                    listener.onSuccess(resp.body().nioByteBuffer());
                } finally {
                    resp.body().release();
                }
            }
        } else if (message instanceof RpcFailure) {
            RpcFailure resp = (RpcFailure) message;
            RpcResponseCallback listener = outstandingRpcs.get(resp.requestId);
            if (listener == null) {
                logger.warn("Ignoring response for RPC {} from {} ({}) since it is not outstanding", resp.requestId,
                        getRemoteAddress(channel), resp.errorString);
            } else {
                outstandingRpcs.remove(resp.requestId);
                listener.onFailure(new RuntimeException(resp.errorString));
            }
        } else if (message instanceof StreamResponse) {
            StreamResponse resp = (StreamResponse) message;
            Pair<String, StreamCallback> entry = streamCallbacks.poll();
            if (entry != null) {
                StreamCallback callback = entry.getValue();
                if (resp.byteCount > 0) {
                    StreamInterceptor interceptor = new StreamInterceptor(this, resp.streamId, resp.byteCount,
                            callback);
                    try {
                        TransportFrameDecoder frameDecoder = (TransportFrameDecoder) channel.pipeline()
                                .get(TransportFrameDecoder.HANDLER_NAME);
                        frameDecoder.setInterceptor(interceptor);
                        streamActive = true;
                    } catch (Exception e) {
                        logger.error("Error installing stream handler.", e);
                        deactivateStream();
                    }
                } else {
                    try {
                        callback.onComplete(resp.streamId);
                    } catch (Exception e) {
                        logger.warn("Error in stream handler onComplete().", e);
                    }
                }
            } else {
                logger.error("Could not find callback for StreamResponse.");
            }
        } else if (message instanceof StreamFailure) {
            StreamFailure resp = (StreamFailure) message;
            Pair<String, StreamCallback> entry = streamCallbacks.poll();
            if (entry != null) {
                StreamCallback callback = entry.getValue();
                try {
                    callback.onFailure(resp.streamId, new RuntimeException(resp.error));
                } catch (IOException ioe) {
                    logger.warn("Error in stream failure handler.", ioe);
                }
            } else {
                logger.warn("Stream failure with unknown callback: {}", resp.error);
            }
        } else {
            throw new IllegalStateException("Unknown response type: " + message.type());
        }
    }

    /** Returns total number of outstanding requests (fetch requests + rpcs) */
    public int numOutstandingRequests() {
        return outstandingFetches.size() + outstandingRpcs.size() + streamCallbacks.size() + (streamActive ? 1 : 0);
    }

    /** Returns the time in nanoseconds of when the last request was sent out. */
    public long getTimeOfLastRequestNs() {
        return timeOfLastRequestNs.get();
    }

    /** Updates the time of the last request to the current system time. */
    public void updateTimeOfLastRequest() {
        timeOfLastRequestNs.set(System.nanoTime());
    }

}