diff --git a/v3/src/main/java/com/skyflow/VaultClient.java b/v3/src/main/java/com/skyflow/VaultClient.java index a22fde18..f7319d29 100644 --- a/v3/src/main/java/com/skyflow/VaultClient.java +++ b/v3/src/main/java/com/skyflow/VaultClient.java @@ -1,5 +1,9 @@ package com.skyflow; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.TimeUnit; + import com.skyflow.config.Credentials; import com.skyflow.config.VaultConfig; import com.skyflow.enums.UpsertType; @@ -19,19 +23,17 @@ import com.skyflow.utils.Utils; import com.skyflow.utils.logger.LogUtil; import com.skyflow.utils.validations.Validations; +import com.skyflow.vault.data.DeleteTokensRequest; import com.skyflow.vault.data.DetokenizeRequest; import com.skyflow.vault.data.InsertRecord; -import com.skyflow.vault.data.DeleteTokensRequest; -import com.skyflow.vault.data.TokenizeRequest; import com.skyflow.vault.data.TokenizeRecord; +import com.skyflow.vault.data.TokenizeRequest; import io.github.cdimascio.dotenv.Dotenv; import io.github.cdimascio.dotenv.DotenvException; +import okhttp3.ConnectionPool; import okhttp3.OkHttpClient; import okhttp3.Request; -import java.util.ArrayList; -import java.util.List; - public class VaultClient { private final VaultConfig vaultConfig; @@ -41,6 +43,8 @@ public class VaultClient { private Credentials finalCredentials; private String token; private String apiKey; + private OkHttpClient sharedHttpClient = null; + private String currentVaultURL = null; protected VaultClient(VaultConfig vaultConfig, Credentials credentials) throws SkyflowException { super(); @@ -78,10 +82,12 @@ protected void setBearerToken() throws SkyflowException { } else { LogUtil.printInfoLog(InfoLogs.REUSE_BEARER_TOKEN.getLog()); } - updateExecutorInHTTP(); // update executor - this.apiClient = this.apiClientBuilder.build(); - } + if (apiClient == null) { + updateExecutorInHTTP(); + this.apiClient = this.apiClientBuilder.build(); + } + } private void updateVaultURL() throws SkyflowException { // Fetch vaultURL from ENV String vaultURL = Utils.getEnvVaultURL(); @@ -96,6 +102,10 @@ private void updateVaultURL() throws SkyflowException { vaultURL = Utils.getVaultURL(this.vaultConfig.getClusterId(), this.vaultConfig.getEnv()); } this.apiClientBuilder.url(vaultURL); + if (!vaultURL.equals(this.currentVaultURL)) { + this.currentVaultURL = vaultURL; + this.apiClient = null; + } } private void prioritiseCredentials() throws SkyflowException { @@ -132,16 +142,18 @@ private void prioritiseCredentials() throws SkyflowException { } protected void updateExecutorInHTTP() { - OkHttpClient httpClient = new OkHttpClient.Builder() - .addInterceptor(chain -> { - Request original = chain.request(); - Request requestWithAuth = original.newBuilder() - .header("Authorization", "Bearer " + this.token) - .build(); - return chain.proceed(requestWithAuth); - }) - .build(); - apiClientBuilder.httpClient(httpClient); + if (sharedHttpClient == null) { + sharedHttpClient = new OkHttpClient.Builder() + .connectionPool(new ConnectionPool(10, 1, TimeUnit.MINUTES)) + .addInterceptor(chain -> { + Request requestWithAuth = chain.request().newBuilder() + .header("Authorization", "Bearer " + this.token) + .build(); + return chain.proceed(requestWithAuth); + }) + .build(); + apiClientBuilder.httpClient(sharedHttpClient); + } } protected V1InsertRequest getBulkInsertRequestBody(com.skyflow.vault.data.InsertRequest request, VaultConfig config) { diff --git a/v3/src/test/java/com/skyflow/VaultClientTests.java b/v3/src/test/java/com/skyflow/VaultClientTests.java index a782252d..b52272dc 100644 --- a/v3/src/test/java/com/skyflow/VaultClientTests.java +++ b/v3/src/test/java/com/skyflow/VaultClientTests.java @@ -3,6 +3,9 @@ import static org.junit.Assert.*; import com.skyflow.config.Credentials; import com.skyflow.config.VaultConfig; +import okhttp3.ConnectionPool; +import okhttp3.OkHttpClient; +import java.util.concurrent.TimeUnit; import com.skyflow.enums.Env; import com.skyflow.enums.UpsertType; import com.skyflow.errors.ErrorCode; @@ -477,6 +480,71 @@ public void testUpdateExecutorInHTTP() throws Exception { Assert.assertNotNull(recordsApi); } + @Test + public void testConnectionPoolMaxIdleConnections() throws Exception { + Credentials credentials = new Credentials(); + credentials.setApiKey("sky-ab123-abcd1234cdef1234abcd4321cdef4321"); + VaultConfig config = new VaultConfig(); + config.setVaultId(vaultID); + config.setClusterId(clusterID); + config.setEnv(Env.PROD); + config.setCredentials(credentials); + + VaultClient client = new VaultClient(config, credentials); + client.updateExecutorInHTTP(); + + OkHttpClient httpClient = (OkHttpClient) getPrivateField(client, "sharedHttpClient"); + Assert.assertNotNull(httpClient); + + ConnectionPool pool = httpClient.connectionPool(); + Assert.assertNotNull(pool); + + Object delegate = getPrivateField(pool, "delegate"); + int maxIdleConnections = (int) getPrivateField(delegate, "maxIdleConnections"); + Assert.assertEquals(10, maxIdleConnections); + } + + @Test + public void testConnectionPoolKeepAliveDuration() throws Exception { + Credentials credentials = new Credentials(); + credentials.setApiKey("sky-ab123-abcd1234cdef1234abcd4321cdef4321"); + VaultConfig config = new VaultConfig(); + config.setVaultId(vaultID); + config.setClusterId(clusterID); + config.setEnv(Env.PROD); + config.setCredentials(credentials); + + VaultClient client = new VaultClient(config, credentials); + client.updateExecutorInHTTP(); + + OkHttpClient httpClient = (OkHttpClient) getPrivateField(client, "sharedHttpClient"); + ConnectionPool pool = httpClient.connectionPool(); + + Object delegate = getPrivateField(pool, "delegate"); + long keepAliveDurationNs = (long) getPrivateField(delegate, "keepAliveDurationNs"); + Assert.assertEquals(TimeUnit.MINUTES.toNanos(1), keepAliveDurationNs); + } + + @Test + public void testSharedHttpClientIsSingleton() throws Exception { + Credentials credentials = new Credentials(); + credentials.setApiKey("sky-ab123-abcd1234cdef1234abcd4321cdef4321"); + VaultConfig config = new VaultConfig(); + config.setVaultId(vaultID); + config.setClusterId(clusterID); + config.setEnv(Env.PROD); + config.setCredentials(credentials); + + VaultClient client = new VaultClient(config, credentials); + client.updateExecutorInHTTP(); + OkHttpClient first = (OkHttpClient) getPrivateField(client, "sharedHttpClient"); + + client.updateExecutorInHTTP(); + OkHttpClient second = (OkHttpClient) getPrivateField(client, "sharedHttpClient"); + + Assert.assertSame("sharedHttpClient must not be rebuilt on repeated calls", first, second); + } + // Helper methods for reflection field access private Object getPrivateField(Object obj, String fieldName) throws Exception { java.lang.reflect.Field field = obj.getClass().getDeclaredField(fieldName); @@ -533,6 +601,57 @@ public void testPrioritiseCredentialsThrowsWhenNoCredentials() throws Exception } } + @Test + public void testApiClientNotReinitializedOnSubsequentCalls() throws Exception { + Credentials credentials = new Credentials(); + credentials.setApiKey("sky-ab123-abcd1234cdef1234abcd4321cdef4321"); + VaultConfig config = new VaultConfig(); + config.setVaultId(vaultID); + config.setClusterId(clusterID); + config.setEnv(Env.PROD); + config.setCredentials(credentials); + + VaultClient client = new VaultClient(config, credentials); + + client.setBearerToken(); + Object apiClientFirst = getPrivateField(client, "apiClient"); + Assert.assertNotNull(apiClientFirst); + + client.setBearerToken(); + Object apiClientSecond = getPrivateField(client, "apiClient"); + + Assert.assertSame("apiClient must not be rebuilt on repeated setBearerToken calls", apiClientFirst, apiClientSecond); + } + + @Test + public void testTokenUpdatedWithoutReinitializingApiClient() throws Exception { + Credentials credentials = new Credentials(); + credentials.setApiKey("sky-ab123-abcd1234cdef1234abcd4321cdef4321"); + VaultConfig config = new VaultConfig(); + config.setVaultId(vaultID); + config.setClusterId(clusterID); + config.setEnv(Env.PROD); + config.setCredentials(credentials); + + VaultClient client = new VaultClient(config, credentials); + client.setBearerToken(); + Object apiClientBefore = getPrivateField(client, "apiClient"); + String tokenBefore = (String) getPrivateField(client, "token"); + + // Rotate to a different API key + Credentials newCredentials = new Credentials(); + newCredentials.setApiKey("sky-ab123-1234567890abcdef1234567890abcdef"); + config.setCredentials(newCredentials); + + client.setBearerToken(); + Object apiClientAfter = getPrivateField(client, "apiClient"); + String tokenAfter = (String) getPrivateField(client, "token"); + + Assert.assertSame("apiClient must not be rebuilt after credential rotation", apiClientBefore, apiClientAfter); + Assert.assertNotEquals("token must reflect the new credential", tokenBefore, tokenAfter); + Assert.assertEquals("sky-ab123-1234567890abcdef1234567890abcdef", tokenAfter); + } + @Test public void testPrioritiseCredentialsSetsFinalCredentialsFromSysCredentials() throws Exception { VaultConfig config = new VaultConfig();