Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 30 additions & 18 deletions v3/src/main/java/com/skyflow/VaultClient.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
package com.skyflow;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.TimeUnit;

import com.skyflow.config.Credentials;
import com.skyflow.config.VaultConfig;
import com.skyflow.enums.UpsertType;
Expand All @@ -19,19 +23,17 @@
import com.skyflow.utils.Utils;
import com.skyflow.utils.logger.LogUtil;
import com.skyflow.utils.validations.Validations;
import com.skyflow.vault.data.DeleteTokensRequest;
import com.skyflow.vault.data.DetokenizeRequest;
import com.skyflow.vault.data.InsertRecord;
import com.skyflow.vault.data.DeleteTokensRequest;
import com.skyflow.vault.data.TokenizeRequest;
import com.skyflow.vault.data.TokenizeRecord;
import com.skyflow.vault.data.TokenizeRequest;
import io.github.cdimascio.dotenv.Dotenv;
import io.github.cdimascio.dotenv.DotenvException;
import okhttp3.ConnectionPool;
import okhttp3.OkHttpClient;
import okhttp3.Request;

import java.util.ArrayList;
import java.util.List;


public class VaultClient {
private final VaultConfig vaultConfig;
Expand All @@ -41,6 +43,8 @@ public class VaultClient {
private Credentials finalCredentials;
private String token;
private String apiKey;
private OkHttpClient sharedHttpClient = null;
private String currentVaultURL = null;

protected VaultClient(VaultConfig vaultConfig, Credentials credentials) throws SkyflowException {
super();
Expand Down Expand Up @@ -78,10 +82,12 @@ protected void setBearerToken() throws SkyflowException {
} else {
LogUtil.printInfoLog(InfoLogs.REUSE_BEARER_TOKEN.getLog());
}
updateExecutorInHTTP(); // update executor
this.apiClient = this.apiClientBuilder.build();
}

if (apiClient == null) {
updateExecutorInHTTP();
this.apiClient = this.apiClientBuilder.build();
}
}
private void updateVaultURL() throws SkyflowException {
// Fetch vaultURL from ENV
String vaultURL = Utils.getEnvVaultURL();
Expand All @@ -96,6 +102,10 @@ private void updateVaultURL() throws SkyflowException {
vaultURL = Utils.getVaultURL(this.vaultConfig.getClusterId(), this.vaultConfig.getEnv());
}
this.apiClientBuilder.url(vaultURL);
if (!vaultURL.equals(this.currentVaultURL)) {
this.currentVaultURL = vaultURL;
this.apiClient = null;
}
}

private void prioritiseCredentials() throws SkyflowException {
Expand Down Expand Up @@ -132,16 +142,18 @@ private void prioritiseCredentials() throws SkyflowException {
}

protected void updateExecutorInHTTP() {
OkHttpClient httpClient = new OkHttpClient.Builder()
.addInterceptor(chain -> {
Request original = chain.request();
Request requestWithAuth = original.newBuilder()
.header("Authorization", "Bearer " + this.token)
.build();
return chain.proceed(requestWithAuth);
})
.build();
apiClientBuilder.httpClient(httpClient);
if (sharedHttpClient == null) {
sharedHttpClient = new OkHttpClient.Builder()
.connectionPool(new ConnectionPool(10, 1, TimeUnit.MINUTES))
.addInterceptor(chain -> {
Request requestWithAuth = chain.request().newBuilder()
.header("Authorization", "Bearer " + this.token)
.build();
return chain.proceed(requestWithAuth);
})
.build();
apiClientBuilder.httpClient(sharedHttpClient);
}
}

protected V1InsertRequest getBulkInsertRequestBody(com.skyflow.vault.data.InsertRequest request, VaultConfig config) {
Expand Down
119 changes: 119 additions & 0 deletions v3/src/test/java/com/skyflow/VaultClientTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
import static org.junit.Assert.*;
import com.skyflow.config.Credentials;
import com.skyflow.config.VaultConfig;
import okhttp3.ConnectionPool;
import okhttp3.OkHttpClient;
import java.util.concurrent.TimeUnit;
import com.skyflow.enums.Env;
import com.skyflow.enums.UpsertType;
import com.skyflow.errors.ErrorCode;
Expand Down Expand Up @@ -477,6 +480,71 @@ public void testUpdateExecutorInHTTP() throws Exception {
Assert.assertNotNull(recordsApi);
}

@Test
public void testConnectionPoolMaxIdleConnections() throws Exception {
Credentials credentials = new Credentials();
credentials.setApiKey("sky-ab123-abcd1234cdef1234abcd4321cdef4321");
VaultConfig config = new VaultConfig();
config.setVaultId(vaultID);
config.setClusterId(clusterID);
config.setEnv(Env.PROD);
config.setCredentials(credentials);

VaultClient client = new VaultClient(config, credentials);
client.updateExecutorInHTTP();

OkHttpClient httpClient = (OkHttpClient) getPrivateField(client, "sharedHttpClient");
Assert.assertNotNull(httpClient);

ConnectionPool pool = httpClient.connectionPool();
Assert.assertNotNull(pool);

Object delegate = getPrivateField(pool, "delegate");
int maxIdleConnections = (int) getPrivateField(delegate, "maxIdleConnections");
Assert.assertEquals(10, maxIdleConnections);
}

@Test
public void testConnectionPoolKeepAliveDuration() throws Exception {
Credentials credentials = new Credentials();
credentials.setApiKey("sky-ab123-abcd1234cdef1234abcd4321cdef4321");
VaultConfig config = new VaultConfig();
config.setVaultId(vaultID);
config.setClusterId(clusterID);
config.setEnv(Env.PROD);
config.setCredentials(credentials);

VaultClient client = new VaultClient(config, credentials);
client.updateExecutorInHTTP();

OkHttpClient httpClient = (OkHttpClient) getPrivateField(client, "sharedHttpClient");
ConnectionPool pool = httpClient.connectionPool();

Object delegate = getPrivateField(pool, "delegate");
long keepAliveDurationNs = (long) getPrivateField(delegate, "keepAliveDurationNs");
Assert.assertEquals(TimeUnit.MINUTES.toNanos(1), keepAliveDurationNs);
}

@Test
public void testSharedHttpClientIsSingleton() throws Exception {
Credentials credentials = new Credentials();
credentials.setApiKey("sky-ab123-abcd1234cdef1234abcd4321cdef4321");
VaultConfig config = new VaultConfig();
config.setVaultId(vaultID);
config.setClusterId(clusterID);
config.setEnv(Env.PROD);
config.setCredentials(credentials);

VaultClient client = new VaultClient(config, credentials);
client.updateExecutorInHTTP();
OkHttpClient first = (OkHttpClient) getPrivateField(client, "sharedHttpClient");

client.updateExecutorInHTTP();
OkHttpClient second = (OkHttpClient) getPrivateField(client, "sharedHttpClient");

Assert.assertSame("sharedHttpClient must not be rebuilt on repeated calls", first, second);
}

// Helper methods for reflection field access
private Object getPrivateField(Object obj, String fieldName) throws Exception {
java.lang.reflect.Field field = obj.getClass().getDeclaredField(fieldName);
Expand Down Expand Up @@ -533,6 +601,57 @@ public void testPrioritiseCredentialsThrowsWhenNoCredentials() throws Exception
}
}

@Test
public void testApiClientNotReinitializedOnSubsequentCalls() throws Exception {
Credentials credentials = new Credentials();
credentials.setApiKey("sky-ab123-abcd1234cdef1234abcd4321cdef4321");
VaultConfig config = new VaultConfig();
config.setVaultId(vaultID);
config.setClusterId(clusterID);
config.setEnv(Env.PROD);
config.setCredentials(credentials);

VaultClient client = new VaultClient(config, credentials);

client.setBearerToken();
Object apiClientFirst = getPrivateField(client, "apiClient");
Assert.assertNotNull(apiClientFirst);

client.setBearerToken();
Object apiClientSecond = getPrivateField(client, "apiClient");

Assert.assertSame("apiClient must not be rebuilt on repeated setBearerToken calls", apiClientFirst, apiClientSecond);
}

@Test
public void testTokenUpdatedWithoutReinitializingApiClient() throws Exception {
Credentials credentials = new Credentials();
credentials.setApiKey("sky-ab123-abcd1234cdef1234abcd4321cdef4321");
VaultConfig config = new VaultConfig();
config.setVaultId(vaultID);
config.setClusterId(clusterID);
config.setEnv(Env.PROD);
config.setCredentials(credentials);

VaultClient client = new VaultClient(config, credentials);
client.setBearerToken();
Object apiClientBefore = getPrivateField(client, "apiClient");
String tokenBefore = (String) getPrivateField(client, "token");

// Rotate to a different API key
Credentials newCredentials = new Credentials();
newCredentials.setApiKey("sky-ab123-1234567890abcdef1234567890abcdef");
config.setCredentials(newCredentials);

client.setBearerToken();
Object apiClientAfter = getPrivateField(client, "apiClient");
String tokenAfter = (String) getPrivateField(client, "token");

Assert.assertSame("apiClient must not be rebuilt after credential rotation", apiClientBefore, apiClientAfter);
Assert.assertNotEquals("token must reflect the new credential", tokenBefore, tokenAfter);
Assert.assertEquals("sky-ab123-1234567890abcdef1234567890abcdef", tokenAfter);
}

@Test
public void testPrioritiseCredentialsSetsFinalCredentialsFromSysCredentials() throws Exception {
VaultConfig config = new VaultConfig();
Expand Down
Loading