diff --git a/core/src/main/kotlin/org/evomaster/core/problem/rest/classifier/probabilistic/AbstractProbabilistic400EndpointModel.kt b/core/src/main/kotlin/org/evomaster/core/problem/rest/classifier/probabilistic/AbstractProbabilistic400EndpointModel.kt index dbf1972f6a..bc52e2665d 100644 --- a/core/src/main/kotlin/org/evomaster/core/problem/rest/classifier/probabilistic/AbstractProbabilistic400EndpointModel.kt +++ b/core/src/main/kotlin/org/evomaster/core/problem/rest/classifier/probabilistic/AbstractProbabilistic400EndpointModel.kt @@ -82,6 +82,12 @@ abstract class AbstractProbabilistic400EndpointModel( return } + if (modelKeys != null && dimension != null) { + require(warmup > 0) { "Warmup must be positive" } + initialized = true + return + } + val encoder = InputEncoderUtilWrapper(input, encoderType = encoderType) val allParamsPathsAndEncodedValues = encoder.getAllParamsPathsAndEncodedValues() diff --git a/core/src/main/kotlin/org/evomaster/core/problem/rest/classifier/probabilistic/InputEncoderUtilWrapper.kt b/core/src/main/kotlin/org/evomaster/core/problem/rest/classifier/probabilistic/InputEncoderUtilWrapper.kt index ede4265ec5..de642d89a2 100644 --- a/core/src/main/kotlin/org/evomaster/core/problem/rest/classifier/probabilistic/InputEncoderUtilWrapper.kt +++ b/core/src/main/kotlin/org/evomaster/core/problem/rest/classifier/probabilistic/InputEncoderUtilWrapper.kt @@ -23,6 +23,7 @@ import org.evomaster.core.search.gene.numeric.NumberGene import org.evomaster.core.search.gene.string.StringGene import kotlin.math.sqrt import kotlin.reflect.KClass +import java.util.IdentityHashMap /** * Utility object for encoding the genes of a [org.evomaster.core.problem.rest.data.RestCallAction] into a numerical representation. @@ -79,24 +80,22 @@ class InputEncoderUtilWrapper( * Builds a string representing the gene name and all its parents. * This string is used as a unique identifier for the gene in the AI models. */ - private fun genePath(g: Gene): String { - - val names = mutableListOf() + private val _genePathCache = IdentityHashMap() - var current: Gene? = g - - while (current != null) { - names.add(current.name) - current = current.parent as? Gene + private fun genePath(g: Gene): String { + return _genePathCache.getOrPut(g) { + val names = mutableListOf() + var current: Gene? = g + while (current != null) { + names.add(current.name) + current = current.parent as? Gene + } + val path = names.reversed() + if (path.size > 1) + path.dropLast(1).joinToString("/") + else + path.joinToString("/") } - - val path = names.reversed() - - return if (path.size > 1) - path.dropLast(1).joinToString("/") //ignore the last name, which is the repetition of gene itself as its own parent - else - path.joinToString("/") - } /** @@ -127,8 +126,9 @@ class InputEncoderUtilWrapper( */ fun getAllParamsPathsAndEncodedValues(): Map { - val paramPaths = endPointToGeneList().map { it.paramPath } - val encodedValues = encode() + val geneList = endPointToGeneList() + val paramPaths = geneList.map { it.paramPath } + val encodedValues = encode(geneList) return paramPaths.zip(encodedValues).toMap() } @@ -186,10 +186,10 @@ class InputEncoderUtilWrapper( * * @return a list of doubles representing the encoded feature vector */ - fun encode(): List { + fun encode(geneList: List = endPointToGeneList()): List { val sentinel = -1e6 // for null handling val neutral = 0.0 // for handling unsupported genes - val listGenes = endPointToGeneList().map { it.gene } + val listGenes = geneList.map { it.gene } val rawEncodedFeatures = mutableListOf() for (g in listGenes) {