ServerAuthInterceptor.java
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.arrow.flight.auth;
import io.grpc.Context;
import io.grpc.Contexts;
import io.grpc.Metadata;
import io.grpc.ServerCall;
import io.grpc.ServerCall.Listener;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
import io.grpc.Status;
import io.grpc.StatusRuntimeException;
import java.util.Optional;
import org.apache.arrow.flight.FlightRuntimeException;
import org.apache.arrow.flight.grpc.StatusUtils;
/** GRPC Interceptor for performing authentication. */
public class ServerAuthInterceptor implements ServerInterceptor {
private final ServerAuthHandler authHandler;
public ServerAuthInterceptor(ServerAuthHandler authHandler) {
this.authHandler = authHandler;
}
@Override
public <ReqT, RespT> Listener<ReqT> interceptCall(
ServerCall<ReqT, RespT> call, Metadata headers, ServerCallHandler<ReqT, RespT> next) {
if (!call.getMethodDescriptor()
.getFullMethodName()
.equals(AuthConstants.HANDSHAKE_DESCRIPTOR_NAME)) {
final Optional<String> peerIdentity;
// Allow customizing the response code by throwing FlightRuntimeException
try {
peerIdentity = isValid(headers);
} catch (FlightRuntimeException e) {
final Status grpcStatus = StatusUtils.toGrpcStatus(e.status());
call.close(grpcStatus, new Metadata());
return new NoopServerCallListener<>();
} catch (StatusRuntimeException e) {
Metadata trailers = e.getTrailers();
call.close(e.getStatus(), trailers == null ? new Metadata() : trailers);
return new NoopServerCallListener<>();
}
if (!peerIdentity.isPresent()) {
// Send back a description along with the status code
call.close(
Status.UNAUTHENTICATED.withDescription(
"Unauthenticated (invalid or missing auth token)"),
new Metadata());
return new NoopServerCallListener<>();
}
return Contexts.interceptCall(
Context.current().withValue(AuthConstants.PEER_IDENTITY_KEY, peerIdentity.get()),
call,
headers,
next);
}
return next.startCall(call, headers);
}
private Optional<String> isValid(Metadata headers) {
byte[] token = headers.get(AuthConstants.TOKEN_KEY);
return authHandler.isValid(token);
}
private static class NoopServerCallListener<T> extends ServerCall.Listener<T> {}
}