ArrowFlightSqlClientHandler.java
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.arrow.driver.jdbc.client;
import com.google.common.collect.ImmutableMap;
import java.io.IOException;
import java.net.URI;
import java.security.GeneralSecurityException;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import org.apache.arrow.driver.jdbc.client.utils.ClientAuthenticationUtils;
import org.apache.arrow.flight.CallOption;
import org.apache.arrow.flight.CallStatus;
import org.apache.arrow.flight.CloseSessionRequest;
import org.apache.arrow.flight.FlightClient;
import org.apache.arrow.flight.FlightClientMiddleware;
import org.apache.arrow.flight.FlightEndpoint;
import org.apache.arrow.flight.FlightInfo;
import org.apache.arrow.flight.FlightRuntimeException;
import org.apache.arrow.flight.FlightStatusCode;
import org.apache.arrow.flight.Location;
import org.apache.arrow.flight.LocationSchemes;
import org.apache.arrow.flight.SessionOptionValue;
import org.apache.arrow.flight.SessionOptionValueFactory;
import org.apache.arrow.flight.SetSessionOptionsRequest;
import org.apache.arrow.flight.SetSessionOptionsResult;
import org.apache.arrow.flight.auth2.BearerCredentialWriter;
import org.apache.arrow.flight.auth2.ClientBearerHeaderHandler;
import org.apache.arrow.flight.auth2.ClientIncomingAuthHeaderMiddleware;
import org.apache.arrow.flight.client.ClientCookieMiddleware;
import org.apache.arrow.flight.grpc.CredentialCallOption;
import org.apache.arrow.flight.sql.FlightSqlClient;
import org.apache.arrow.flight.sql.impl.FlightSql.SqlInfo;
import org.apache.arrow.flight.sql.util.TableRef;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.util.AutoCloseables;
import org.apache.arrow.util.Preconditions;
import org.apache.arrow.util.VisibleForTesting;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.types.pojo.Schema;
import org.apache.calcite.avatica.Meta.StatementType;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/** A {@link FlightSqlClient} handler. */
public final class ArrowFlightSqlClientHandler implements AutoCloseable {
/** Logger for warnings raised by this handler (session option errors, suppressed close errors). */
private static final Logger LOGGER = LoggerFactory.getLogger(ArrowFlightSqlClientHandler.class);
// Key under which the catalog is sent as a Flight session option; it mirrors the JDBC connection
// string query parameter of the same name.
private static final String CATALOG = "catalog";
/** The wrapped client through which every RPC of this handler is issued. */
private final FlightSqlClient sqlClient;
/** Call options (builder options plus credential options) passed along with every RPC. */
private final Set<CallOption> options = new HashSet<>();
/** The builder this handler was created from; getStreams copies it to reach other endpoints. */
private final Builder builder;
/** The catalog to set as a session option, if one was configured. */
private final Optional<String> catalog;
ArrowFlightSqlClientHandler(
final FlightSqlClient sqlClient,
final Builder builder,
final Collection<CallOption> credentialOptions,
final Optional<String> catalog) {
this.options.addAll(builder.options);
this.options.addAll(credentialOptions);
this.sqlClient = Preconditions.checkNotNull(sqlClient);
this.builder = builder;
this.catalog = catalog;
}
/**
 * Wraps the provided {@code client} in a {@link FlightSqlClient}, creates a new {@link
 * ArrowFlightSqlClientHandler} around it and, when a catalog is present, sets it as a session
 * option on the server.
 *
 * @param client the {@link FlightClient} to manage under a {@link FlightSqlClient} wrapper.
 * @param builder the builder the handler is being created from.
 * @param options the {@link CallOption}s to persist in between subsequent client calls.
 * @param catalog the catalog to set as a session option, if any.
 * @return a new {@link ArrowFlightSqlClientHandler}.
 */
static ArrowFlightSqlClientHandler createNewHandler(
    final FlightClient client,
    final Builder builder,
    final Collection<CallOption> options,
    final Optional<String> catalog) {
  final FlightSqlClient wrappedClient = new FlightSqlClient(client);
  final ArrowFlightSqlClientHandler created =
      new ArrowFlightSqlClientHandler(wrappedClient, builder, options, catalog);
  created.setSetCatalogInSessionIfPresent();
  return created;
}
/**
 * Snapshots the {@link #options} of this handler into a fresh array so they can be passed as
 * varargs to the subsequent client calls.
 *
 * @return the {@link CallOption}s.
 */
private CallOption[] getOptions() {
  final CallOption[] snapshot = new CallOption[options.size()];
  return options.toArray(snapshot);
}
/**
 * Makes an RPC "getStream" request for every endpoint of the provided {@link FlightInfo} object.
 * Retrieves the result of the query previously prepared with "getInfo."
 *
 * <p>If a stream cannot be opened for any endpoint, every stream opened so far is closed before
 * the failure is reported.
 *
 * @param flightInfo The {@link FlightInfo} instance from which to fetch results.
 * @return one {@link CloseableEndpointStreamPair} per endpoint, in endpoint order.
 * @throws SQLException if a stream could not be opened for one of the endpoints.
 */
public List<CloseableEndpointStreamPair> getStreams(final FlightInfo flightInfo)
    throws SQLException {
  final ArrayList<CloseableEndpointStreamPair> endpoints =
      new ArrayList<>(flightInfo.getEndpoints().size());
  try {
    for (FlightEndpoint endpoint : flightInfo.getEndpoints()) {
      endpoints.add(openEndpointStream(endpoint));
    }
  } catch (Exception outerException) {
    // Do not leak the streams (and clients) that were opened before the failure.
    try {
      AutoCloseables.close(endpoints);
    } catch (Exception innerEx) {
      outerException.addSuppressed(innerEx);
    }
    if (outerException instanceof SQLException) {
      throw (SQLException) outerException;
    }
    throw new SQLException(outerException);
  }
  return endpoints;
}

/**
 * Opens the stream for a single endpoint, trying its locations in order until one of them works.
 *
 * @param endpoint the endpoint to fetch.
 * @return the opened stream, paired with the dedicated client when one had to be created.
 * @throws Exception the first failure, with the later ones suppressed, if no location worked.
 */
private CloseableEndpointStreamPair openEndpointStream(final FlightEndpoint endpoint)
    throws Exception {
  if (endpoint.getLocations().isEmpty()) {
    // Create a stream using the current client only and do not close the client at the end.
    return new CloseableEndpointStreamPair(
        sqlClient.getStream(endpoint.getTicket(), getOptions()), null);
  }
  // GH-38574: Currently a new FlightClient will be made for each partition that returns a
  // non-empty Location then disposed of. It may be better to cache clients because a server may
  // report the same Locations. It would also be good to identify when the reported location is
  // the same as the original connection's Location and skip creating a FlightClient in that
  // scenario.
  final List<Exception> exceptions = new ArrayList<>();
  for (Location location : endpoint.getLocations()) {
    final URI endpointUri = location.getUri();
    if (endpointUri.getScheme().equals(LocationSchemes.REUSE_CONNECTION)) {
      return new CloseableEndpointStreamPair(
          sqlClient.getStream(endpoint.getTicket(), getOptions()), null);
    }
    // Clone the builder and then set the new endpoint on it.
    final Builder builderForEndpoint =
        new Builder(ArrowFlightSqlClientHandler.this.builder)
            .withHost(endpointUri.getHost())
            .withPort(endpointUri.getPort())
            .withEncryption(endpointUri.getScheme().equals(LocationSchemes.GRPC_TLS));
    ArrowFlightSqlClientHandler endpointHandler = null;
    try {
      endpointHandler = builderForEndpoint.build();
      final CloseableEndpointStreamPair stream =
          new CloseableEndpointStreamPair(
              endpointHandler.sqlClient.getStream(
                  endpoint.getTicket(), endpointHandler.getOptions()),
              endpointHandler.sqlClient);
      // Make sure we actually get data from the server
      stream.getStream().getSchema();
      // Only a stream that passed the check above is ever handed back to the caller.
      return stream;
    } catch (Exception ex) {
      if (endpointHandler != null) {
        try {
          AutoCloseables.close(endpointHandler);
        } catch (Exception closeEx) {
          // A cleanup failure must not prevent falling back to the remaining locations.
          ex.addSuppressed(closeEx);
        }
      }
      exceptions.add(ex);
    }
  }
  if (exceptions.isEmpty()) {
    // This should never happen...
    throw new IllegalStateException("Could not connect to endpoint and no errors occurred");
  }
  final Exception ex = exceptions.remove(0);
  while (!exceptions.isEmpty()) {
    ex.addSuppressed(exceptions.remove(exceptions.size() - 1));
  }
  throw ex;
}
/**
 * Makes an RPC "getInfo" request that executes the provided {@code query}.
 *
 * @param query The query.
 * @return the {@link FlightInfo} returned by the server for the query.
 */
public FlightInfo getInfo(final String query) {
  final CallOption[] callOptions = getOptions();
  return sqlClient.execute(query, callOptions);
}
/**
 * Closes the server session (only when a catalog was set on it) and then the underlying client.
 *
 * <p>The client is closed even when closing the session fails; in that case the session failure
 * is rethrown afterwards, or attached as suppressed if closing the client fails as well.
 *
 * @throws SQLException if the client resources could not be released.
 */
@Override
public void close() throws SQLException {
  RuntimeException sessionFailure = null;
  if (catalog.isPresent()) {
    try {
      sqlClient.closeSession(new CloseSessionRequest(), getOptions());
    } catch (final RuntimeException e) {
      // Remember the failure but keep going: the client must be released regardless.
      sessionFailure = e;
    }
  }
  try {
    AutoCloseables.close(sqlClient);
  } catch (final Exception e) {
    final SQLException closeFailure =
        new SQLException("Failed to clean up client resources.", e);
    if (sessionFailure != null) {
      closeFailure.addSuppressed(sessionFailure);
    }
    throw closeFailure;
  }
  if (sessionFailure != null) {
    throw sessionFailure;
  }
}
/** A prepared statement handler. */
public interface PreparedStatement extends AutoCloseable {
/**
* Executes this {@link PreparedStatement}.
*
* @return the {@link FlightInfo} representing the outcome of this query execution.
* @throws SQLException on error.
*/
FlightInfo executeQuery() throws SQLException;
/**
* Executes a {@link StatementType#UPDATE} query.
*
* @return the number of rows affected.
*/
long executeUpdate();
/**
* Gets the {@link StatementType} of this {@link PreparedStatement}.
*
* @return the Statement Type.
*/
StatementType getType();
/**
* Gets the {@link Schema} of this {@link PreparedStatement}.
*
* @return {@link Schema}.
*/
Schema getDataSetSchema();
/**
* Gets the {@link Schema} of the parameters for this {@link PreparedStatement}.
*
* @return {@link Schema}.
*/
Schema getParameterSchema();
/**
* Sets the parameters of this {@link PreparedStatement}.
*
* @param parameters the {@link VectorSchemaRoot} holding the parameters.
*/
void setParameters(VectorSchemaRoot parameters);
/**
* Closes this {@link PreparedStatement}. Unlike {@link AutoCloseable#close()}, this declaration
* throws no checked exception.
*/
@Override
void close();
}
/**
 * Sets the configured catalog, if any, as the {@code catalog} session option on the server.
 *
 * <p>Does nothing when no catalog was configured. When the server reports errors for the
 * request, each error is logged and the call fails.
 *
 * @throws FlightRuntimeException with status {@code INVALID_ARGUMENT} if the server reports
 *     errors while setting the session option.
 */
private void setSetCatalogInSessionIfPresent() {
  if (!catalog.isPresent()) {
    return;
  }
  final String catalogName = catalog.get();
  final SetSessionOptionsRequest setSessionOptionRequest =
      new SetSessionOptionsRequest(
          ImmutableMap.<String, SessionOptionValue>builder()
              .put(CATALOG, SessionOptionValueFactory.makeSessionOptionValue(catalogName))
              .build());
  final SetSessionOptionsResult result =
      sqlClient.setSessionOptions(setSessionOptionRequest, getOptions());
  if (!result.hasErrors()) {
    return;
  }
  final Map<String, SetSessionOptionsResult.Error> errors = result.getErrors();
  for (Map.Entry<String, SetSessionOptionsResult.Error> error : errors.entrySet()) {
    LOGGER.warn(error.toString());
  }
  // Format the catalog name itself; formatting the Optional would print "Optional[name]".
  throw CallStatus.INVALID_ARGUMENT
      .withDescription(
          String.format(
              "Cannot set session option for catalog = %s. Check log for details.",
              catalogName))
      .toRuntimeException();
}
/**
 * Creates a new {@link PreparedStatement} for the given {@code query}.
 *
 * <p>The returned statement delegates every operation to the underlying {@link
 * FlightSqlClient.PreparedStatement}, passing this handler's call options along.
 *
 * @param query the SQL query.
 * @return a new prepared statement.
 */
public PreparedStatement prepare(final String query) {
  final FlightSqlClient.PreparedStatement preparedStatement =
      sqlClient.prepare(query, getOptions());
  return new PreparedStatement() {
    @Override
    public FlightInfo executeQuery() throws SQLException {
      return preparedStatement.execute(getOptions());
    }

    @Override
    public long executeUpdate() {
      return preparedStatement.executeUpdate(getOptions());
    }

    @Override
    public StatementType getType() {
      // A statement without result columns is treated as an update.
      final Schema schema = preparedStatement.getResultSetSchema();
      return schema.getFields().isEmpty() ? StatementType.UPDATE : StatementType.SELECT;
    }

    @Override
    public Schema getDataSetSchema() {
      return preparedStatement.getResultSetSchema();
    }

    @Override
    public Schema getParameterSchema() {
      return preparedStatement.getParameterSchema();
    }

    @Override
    public void setParameters(VectorSchemaRoot parameters) {
      preparedStatement.setParameters(parameters);
    }

    @Override
    public void close() {
      try {
        preparedStatement.close(getOptions());
      } catch (FlightRuntimeException fre) {
        // ARROW-17785: suppress exceptions caused by flaky gRPC layer
        final FlightStatusCode code = fre.status().code();
        // The message may be null; guard so a NullPointerException cannot mask the real error.
        final String message = fre.getMessage();
        final boolean closedAfterGoAway =
            code == FlightStatusCode.INTERNAL
                && message != null
                && message.contains("Connection closed after GOAWAY");
        if (code == FlightStatusCode.UNAVAILABLE || closedAfterGoAway) {
          LOGGER.warn("Supressed error closing PreparedStatement", fre);
          return;
        }
        throw fre;
      }
    }
  };
}
/**
 * Makes an RPC "getCatalogs" request.
 *
 * @return the {@link FlightInfo} returned by the server for the request.
 */
public FlightInfo getCatalogs() {
  final CallOption[] callOptions = getOptions();
  return sqlClient.getCatalogs(callOptions);
}
/**
 * Makes an RPC "getImportedKeys" request based on the provided info.
 *
 * @param catalog The catalog name. Must match the catalog name as it is stored in the database.
 *     "" retrieves those without a catalog. Null means that the catalog name should not be used
 *     to narrow the search.
 * @param schema The schema name. Must match the schema name as it is stored in the database. ""
 *     retrieves those without a schema. Null means that the schema name should not be used to
 *     narrow the search.
 * @param table The table name. Must match the table name as it is stored in the database.
 * @return the {@link FlightInfo} returned by the server for the request.
 */
public FlightInfo getImportedKeys(final String catalog, final String schema, final String table) {
  final TableRef tableRef = TableRef.of(catalog, schema, table);
  return sqlClient.getImportedKeys(tableRef, getOptions());
}
/**
 * Makes an RPC "getExportedKeys" request based on the provided info.
 *
 * @param catalog The catalog name. Must match the catalog name as it is stored in the database.
 *     "" retrieves those without a catalog. Null means that the catalog name should not be used
 *     to narrow the search.
 * @param schema The schema name. Must match the schema name as it is stored in the database. ""
 *     retrieves those without a schema. Null means that the schema name should not be used to
 *     narrow the search.
 * @param table The table name. Must match the table name as it is stored in the database.
 * @return the {@link FlightInfo} returned by the server for the request.
 */
public FlightInfo getExportedKeys(final String catalog, final String schema, final String table) {
  final TableRef tableRef = TableRef.of(catalog, schema, table);
  return sqlClient.getExportedKeys(tableRef, getOptions());
}
/**
 * Makes an RPC "getSchemas" request based on the provided info.
 *
 * @param catalog The catalog name. Must match the catalog name as it is stored in the database.
 *     "" retrieves those without a catalog. Null means that the catalog name should not be used
 *     to narrow the search.
 * @param schemaPattern The schema name pattern. Must match the schema name as it is stored in the
 *     database. Null means that schema name should not be used to narrow down the search.
 * @return the {@link FlightInfo} returned by the server for the request.
 */
public FlightInfo getSchemas(final String catalog, final String schemaPattern) {
  final CallOption[] callOptions = getOptions();
  return sqlClient.getSchemas(catalog, schemaPattern, callOptions);
}
/**
 * Makes an RPC "getTableTypes" request.
 *
 * @return the {@link FlightInfo} returned by the server for the request.
 */
public FlightInfo getTableTypes() {
  final CallOption[] callOptions = getOptions();
  return sqlClient.getTableTypes(callOptions);
}
/**
 * Makes an RPC "getTables" request based on the provided info.
 *
 * @param catalog The catalog name. Must match the catalog name as it is stored in the database.
 *     "" retrieves those without a catalog. Null means that the catalog name should not be used
 *     to narrow the search.
 * @param schemaPattern The schema name pattern. Must match the schema name as it is stored in the
 *     database. "" retrieves those without a schema. Null means that the schema name should not
 *     be used to narrow the search.
 * @param tableNamePattern The table name pattern. Must match the table name as it is stored in
 *     the database.
 * @param types The list of table types, which must be from the list of table types to include.
 *     Null returns all types.
 * @param includeSchema Whether to include schema.
 * @return the {@link FlightInfo} returned by the server for the request.
 */
public FlightInfo getTables(
    final String catalog,
    final String schemaPattern,
    final String tableNamePattern,
    final List<String> types,
    final boolean includeSchema) {
  final CallOption[] callOptions = getOptions();
  return sqlClient.getTables(
      catalog, schemaPattern, tableNamePattern, types, includeSchema, callOptions);
}
/**
 * Makes an RPC "getSqlInfo" request for the given pieces of SQL info.
 *
 * @param info the {@link SqlInfo} entries to request.
 * @return the {@link FlightInfo} returned by the server for the request.
 */
public FlightInfo getSqlInfo(SqlInfo... info) {
  final CallOption[] callOptions = getOptions();
  return sqlClient.getSqlInfo(info, callOptions);
}
/**
 * Makes an RPC "getPrimaryKeys" request based on the provided info.
 *
 * @param catalog The catalog name; must match the catalog name as it is stored in the database.
 *     "" retrieves those without a catalog. Null means that the catalog name should not be used
 *     to narrow the search.
 * @param schema The schema name; must match the schema name as it is stored in the database. ""
 *     retrieves those without a schema. Null means that the schema name should not be used to
 *     narrow the search.
 * @param table The table name. Must match the table name as it is stored in the database.
 * @return the {@link FlightInfo} returned by the server for the request.
 */
public FlightInfo getPrimaryKeys(final String catalog, final String schema, final String table) {
  final TableRef tableRef = TableRef.of(catalog, schema, table);
  return sqlClient.getPrimaryKeys(tableRef, getOptions());
}
/**
 * Makes an RPC "getCrossReference" request based on the provided info.
 *
 * @param pkCatalog The catalog name. Must match the catalog name as it is stored in the database.
 *     "" retrieves those without a catalog. Null means that the catalog name should not be used
 *     to narrow the search.
 * @param pkSchema The schema name. Must match the schema name as it is stored in the database. ""
 *     retrieves those without a schema. Null means that the schema name should not be used to
 *     narrow the search.
 * @param pkTable The table name. Must match the table name as it is stored in the database.
 * @param fkCatalog The catalog name. Must match the catalog name as it is stored in the database.
 *     "" retrieves those without a catalog. Null means that the catalog name should not be used
 *     to narrow the search.
 * @param fkSchema The schema name. Must match the schema name as it is stored in the database. ""
 *     retrieves those without a schema. Null means that the schema name should not be used to
 *     narrow the search.
 * @param fkTable The table name. Must match the table name as it is stored in the database.
 * @return the {@link FlightInfo} returned by the server for the request.
 */
public FlightInfo getCrossReference(
    String pkCatalog,
    String pkSchema,
    String pkTable,
    String fkCatalog,
    String fkSchema,
    String fkTable) {
  final TableRef primaryKeyTable = TableRef.of(pkCatalog, pkSchema, pkTable);
  final TableRef foreignKeyTable = TableRef.of(fkCatalog, fkSchema, fkTable);
  return sqlClient.getCrossReference(primaryKeyTable, foreignKeyTable, getOptions());
}
/** Builder for {@link ArrowFlightSqlClientHandler}. */
public static final class Builder {
// Middleware factories supplied by the caller; build() works on a copy of this set.
private final Set<FlightClientMiddleware.Factory> middlewareFactories = new HashSet<>();
// Call options supplied by the caller; used for the authentication calls in build() and adopted
// by the handler that build() creates.
private final Set<CallOption> options = new HashSet<>();
// Host and port of the server to connect to.
private String host;
private int port;
// Credentials: username/password are only used by build() when no token is set.
@VisibleForTesting String username;
@VisibleForTesting String password;
// Trust store settings, consulted by build() only when encryption is enabled.
@VisibleForTesting String trustStorePath;
@VisibleForTesting String trustStorePassword;
@VisibleForTesting String token;
// TLS is enabled by default.
@VisibleForTesting boolean useEncryption = true;
@VisibleForTesting boolean disableCertificateVerification;
@VisibleForTesting boolean useSystemTrustStore = true;
@VisibleForTesting String tlsRootCertificatesPath;
// Client certificate and key paths; build() applies them only when both are set.
@VisibleForTesting String clientCertificatePath;
@VisibleForTesting String clientKeyPath;
@VisibleForTesting private BufferAllocator allocator;
// Whether a copy of this builder shares the cookie / auth factory of this builder.
@VisibleForTesting boolean retainCookies = true;
@VisibleForTesting boolean retainAuth = true;
@VisibleForTesting Optional<String> catalog = Optional.empty();
// These two middleware are for internal use within build() and should not be exposed by builder
// APIs.
// Note that these middleware may not necessarily be registered.
@VisibleForTesting
ClientIncomingAuthHeaderMiddleware.Factory authFactory =
new ClientIncomingAuthHeaderMiddleware.Factory(new ClientBearerHeaderHandler());
@VisibleForTesting
ClientCookieMiddleware.Factory cookieFactory = new ClientCookieMiddleware.Factory();
/**
* Creates a builder with the default settings: encryption and the system trust store enabled,
* cookies and auth retained, and no catalog.
*/
public Builder() {}
/**
 * Copies the builder.
 *
 * <p>All settings are copied, including the {@code retainCookies} and {@code retainAuth} flags.
 * The cookie and auth middleware factories are shared with {@code original} only when the
 * corresponding flag is set on it; otherwise the copy keeps its own fresh factory. The middleware
 * and call option sets are copied, while the allocator reference is shared.
 *
 * @param original The builder to base this copy off of.
 */
@VisibleForTesting
Builder(Builder original) {
  this.middlewareFactories.addAll(original.middlewareFactories);
  this.options.addAll(original.options);
  this.host = original.host;
  this.port = original.port;
  this.username = original.username;
  this.password = original.password;
  this.trustStorePath = original.trustStorePath;
  this.trustStorePassword = original.trustStorePassword;
  this.token = original.token;
  this.useEncryption = original.useEncryption;
  this.disableCertificateVerification = original.disableCertificateVerification;
  this.useSystemTrustStore = original.useSystemTrustStore;
  this.tlsRootCertificatesPath = original.tlsRootCertificatesPath;
  this.clientCertificatePath = original.clientCertificatePath;
  this.clientKeyPath = original.clientKeyPath;
  this.allocator = original.allocator;
  this.catalog = original.catalog;
  // Carry the flags over so that a copy of this copy honours the same retention settings
  // instead of silently falling back to the defaults.
  this.retainCookies = original.retainCookies;
  this.retainAuth = original.retainAuth;
  if (original.retainCookies) {
    this.cookieFactory = original.cookieFactory;
  }
  if (original.retainAuth) {
    this.authFactory = original.authFactory;
  }
}
/**
* Sets the host of the server this handler connects to.
*
* @param host the host.
* @return this instance.
*/
public Builder withHost(final String host) {
this.host = host;
return this;
}
/**
* Sets the port of the server this handler connects to.
*
* @param port the port.
* @return this instance.
*/
public Builder withPort(final int port) {
this.port = port;
return this;
}
/**
* Sets the username for this handler. Username/password authentication is only used by {@link
* #build()} when a username is set and no token is set.
*
* @param username the username.
* @return this instance.
*/
public Builder withUsername(final String username) {
this.username = username;
return this;
}
/**
* Sets the password for this handler, used together with the username when {@link #build()}
* authenticates with username/password.
*
* @param password the password.
* @return this instance.
*/
public Builder withPassword(final String password) {
this.password = password;
return this;
}
/**
* Sets the trust store (KeyStore) path for this handler. {@link #build()} only reads it when
* encryption is enabled, certificate verification is not disabled, no TLS root certificates
* path is set and the system trust store is not used.
*
* @param trustStorePath the KeyStore path.
* @return this instance.
*/
public Builder withTrustStorePath(final String trustStorePath) {
this.trustStorePath = trustStorePath;
return this;
}
/**
* Sets the trust store (KeyStore) password for this handler. {@link #build()} passes it along
* both when loading certificates from the system trust store and when loading them from the
* trust store path.
*
* @param trustStorePassword the KeyStore password.
* @return this instance.
*/
public Builder withTrustStorePassword(final String trustStorePassword) {
this.trustStorePassword = trustStorePassword;
return this;
}
/**
* Sets whether to use TLS encryption in this handler. Encryption is enabled unless this is
* called with {@code false}.
*
* @param useEncryption whether to use TLS encryption.
* @return this instance.
*/
public Builder withEncryption(final boolean useEncryption) {
this.useEncryption = useEncryption;
return this;
}
/**
* Sets whether to disable the certificate verification in this handler. The flag is only
* consulted by {@link #build()} when encryption is enabled; when it is set, none of the trusted
* certificate sources are loaded.
*
* @param disableCertificateVerification whether to disable certificate verification.
* @return this instance.
*/
public Builder withDisableCertificateVerification(
final boolean disableCertificateVerification) {
this.disableCertificateVerification = disableCertificateVerification;
return this;
}
/**
* Sets whether to use the certificates from the operating system. Enabled by default; {@link
* #build()} only consults it when no TLS root certificates path is set.
*
* @param useSystemTrustStore whether to use the system operating certificates.
* @return this instance.
*/
public Builder withSystemTrustStore(final boolean useSystemTrustStore) {
this.useSystemTrustStore = useSystemTrustStore;
return this;
}
/**
* Sets the TLS root certificate path as an alternative to using the System or other Trust
* Store. The path must contain a valid PEM file. When set, {@link #build()} prefers it over the
* system trust store and the trust store path.
*
* @param tlsRootCertificatesPath the TLS root certificate path (if TLS is required).
* @return this instance.
*/
public Builder withTlsRootCertificates(final String tlsRootCertificatesPath) {
this.tlsRootCertificatesPath = tlsRootCertificatesPath;
return this;
}
/**
* Sets the mTLS client certificate path (if mTLS is required). {@link #build()} only applies it
* when encryption is enabled and the client key path is set as well.
*
* @param clientCertificatePath the mTLS client certificate path (if mTLS is required).
* @return this instance.
*/
public Builder withClientCertificate(final String clientCertificatePath) {
this.clientCertificatePath = clientCertificatePath;
return this;
}
/**
* Sets the mTLS client certificate private key path (if mTLS is required). {@link #build()}
* only applies it when encryption is enabled and the client certificate path is set as well.
*
* @param clientKeyPath the mTLS client certificate private key path (if mTLS is required).
* @return this instance.
*/
public Builder withClientKey(final String clientKeyPath) {
this.clientKeyPath = clientKeyPath;
return this;
}
/**
* Sets the token used in the token authentication. When a token is set, {@link #build()} uses
* it instead of username/password authentication.
*
* @param token the token value.
* @return this builder instance.
*/
public Builder withToken(final String token) {
this.token = token;
return this;
}
/**
 * Sets the {@link BufferAllocator} to use in this handler. The handler does not use the given
 * allocator directly: a child allocator named "ArrowFlightSqlClientHandler" with the same limit
 * is created from it.
 *
 * @param allocator the parent allocator.
 * @return this instance.
 */
public Builder withBufferAllocator(final BufferAllocator allocator) {
  final long parentLimit = allocator.getLimit();
  this.allocator = allocator.newChildAllocator("ArrowFlightSqlClientHandler", 0, parentLimit);
  return this;
}
/**
* Indicates if cookies should be re-used by connections spawned for getStreams() calls. When
* true (the default), a copy of this builder shares this builder's cookie middleware factory;
* otherwise the copy keeps its own factory.
*
* @param retainCookies The flag indicating if cookies should be re-used.
* @return this builder instance.
*/
public Builder withRetainCookies(boolean retainCookies) {
this.retainCookies = retainCookies;
return this;
}
/**
* Indicates if bearer tokens negotiated should be re-used by connections spawned for
* getStreams() calls. When true (the default), a copy of this builder shares this builder's
* auth middleware factory; otherwise the copy keeps its own factory.
*
* @param retainAuth The flag indicating if auth tokens should be re-used.
* @return this builder instance.
*/
public Builder withRetainAuth(boolean retainAuth) {
this.retainAuth = retainAuth;
return this;
}
/**
 * Adds the provided {@code factories} to the set of {@link #middlewareFactories} of this
 * handler. Convenience overload that forwards to the {@link Collection} variant.
 *
 * @param factories the factories to add.
 * @return this instance.
 */
public Builder withMiddlewareFactories(final FlightClientMiddleware.Factory... factories) {
  final List<FlightClientMiddleware.Factory> factoryList = Arrays.asList(factories);
  return withMiddlewareFactories(factoryList);
}
/**
* Adds the provided {@code factories} to the set of {@link #middlewareFactories} of this
* handler. Factories added by earlier calls are kept.
*
* @param factories the factories to add.
* @return this instance.
*/
public Builder withMiddlewareFactories(
final Collection<FlightClientMiddleware.Factory> factories) {
this.middlewareFactories.addAll(factories);
return this;
}
/**
 * Adds the provided {@link CallOption}s to this handler. Convenience overload that forwards to
 * the {@link Collection} variant.
 *
 * @param options the options
 * @return this instance.
 */
public Builder withCallOptions(final CallOption... options) {
  final List<CallOption> optionList = Arrays.asList(options);
  return withCallOptions(optionList);
}
/**
* Adds the provided {@link CallOption}s to this handler. Options added by earlier calls are
* kept.
*
* @param options the options
* @return this instance.
*/
public Builder withCallOptions(final Collection<CallOption> options) {
this.options.addAll(options);
return this;
}
/**
 * Sets the catalog for this handler. Passing {@code null} clears any previously set catalog.
 *
 * @param catalog the catalog, or {@code null} for none.
 * @return this instance.
 */
public Builder withCatalog(@Nullable final String catalog) {
  if (catalog == null) {
    this.catalog = Optional.empty();
  } else {
    this.catalog = Optional.of(catalog);
  }
  return this;
}
/**
 * Builds a new {@link ArrowFlightSqlClientHandler} from the provided fields.
 *
 * <p>Creates the underlying {@link FlightClient} (with TLS settings when encryption is enabled),
 * authenticates with the token or with username/password, and wraps the client in a handler. A
 * token takes priority over username/password. The fields of this builder are not modified.
 *
 * @return a new client handler.
 * @throws SQLException on error; the client is closed if it had already been created.
 */
public ArrowFlightSqlClientHandler build() throws SQLException {
  // Copy middleware so that the build method doesn't change the state of the builder fields
  // itself.
  final Set<FlightClientMiddleware.Factory> buildTimeMiddlewareFactories =
      new HashSet<>(this.middlewareFactories);
  FlightClient client = null;
  final boolean isUsingUserPasswordAuth = username != null && token == null;
  try {
    // Token should take priority since some apps pass in a username/password even when a token
    // is provided
    if (isUsingUserPasswordAuth) {
      buildTimeMiddlewareFactories.add(authFactory);
    }
    final FlightClient.Builder clientBuilder = FlightClient.builder().allocator(allocator);
    // Register this builder's own cookie factory rather than a throw-away one, so that builder
    // copies which share the factory (retainCookies) actually re-use the cookies.
    buildTimeMiddlewareFactories.add(cookieFactory);
    buildTimeMiddlewareFactories.forEach(clientBuilder::intercept);
    final Location location;
    if (useEncryption) {
      location = Location.forGrpcTls(host, port);
      clientBuilder.useTls();
    } else {
      location = Location.forGrpcInsecure(host, port);
    }
    clientBuilder.location(location);
    if (useEncryption) {
      if (disableCertificateVerification) {
        clientBuilder.verifyServer(false);
      } else {
        // Trusted certificate sources, in order of precedence: explicit PEM file, system trust
        // store, trust store file.
        if (tlsRootCertificatesPath != null) {
          clientBuilder.trustedCertificates(
              ClientAuthenticationUtils.getTlsRootCertificatesStream(tlsRootCertificatesPath));
        } else if (useSystemTrustStore) {
          clientBuilder.trustedCertificates(
              ClientAuthenticationUtils.getCertificateInputStreamFromSystem(
                  trustStorePassword));
        } else if (trustStorePath != null) {
          clientBuilder.trustedCertificates(
              ClientAuthenticationUtils.getCertificateStream(
                  trustStorePath, trustStorePassword));
        }
      }
      if (clientCertificatePath != null && clientKeyPath != null) {
        clientBuilder.clientCertificate(
            ClientAuthenticationUtils.getClientCertificateStream(clientCertificatePath),
            ClientAuthenticationUtils.getClientKeyStream(clientKeyPath));
      }
    }
    client = clientBuilder.build();
    final ArrayList<CallOption> credentialOptions = new ArrayList<>();
    if (isUsingUserPasswordAuth) {
      // If the authFactory has already been used for a handshake, use the existing token.
      // This can occur if the authFactory is being re-used for a new connection spawned for
      // getStream().
      if (authFactory.getCredentialCallOption() != null) {
        credentialOptions.add(authFactory.getCredentialCallOption());
      } else {
        // Otherwise do the handshake and get the token if possible.
        credentialOptions.add(
            ClientAuthenticationUtils.getAuthenticate(
                client, username, password, authFactory, options.toArray(new CallOption[0])));
      }
    } else if (token != null) {
      credentialOptions.add(
          ClientAuthenticationUtils.getAuthenticate(
              client,
              new CredentialCallOption(new BearerCredentialWriter(token)),
              options.toArray(new CallOption[0])));
    }
    return ArrowFlightSqlClientHandler.createNewHandler(
        client, this, credentialOptions, catalog);
  } catch (final IllegalArgumentException
      | GeneralSecurityException
      | IOException
      | FlightRuntimeException e) {
    final SQLException originalException = new SQLException(e);
    if (client != null) {
      try {
        client.close();
      } catch (final InterruptedException interruptedException) {
        // Restore the interrupt flag so callers further up can still observe the interruption.
        Thread.currentThread().interrupt();
        originalException.addSuppressed(interruptedException);
      }
    }
    throw originalException;
  }
}
}
}