Nvidia Merlin fast inference

I would like to train a TwoTowerModel and treat it as a pair of encoders rather than as a retrieval model. That is, I want to do the retrieval separately with Redis ANN search instead of using FAISS or a similar in-memory index on a Triton server.

The system I want to build looks like this:

  • Deploy an item tower (encoder) to a Triton server
    • this encoder will receive JSON containing a single item’s features (e.g., category, price, etc.)
    • this encoder will return an n-dimensional item embedding
    • these embeddings will be saved to a Redis vector index (FLAT or HNSW)
  • Deploy a query tower (encoder) to a Triton server
    • this encoder will receive JSON containing a single user’s features (e.g., state, gender, etc.)
    • this tower will return an n-dimensional user embedding
  • The returned user embedding will then be used for an ANN search in Redis, which will return the k most relevant items for that user. Redis also has some hybrid search features that I would like to use.
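
For reference, here is roughly what the Redis side of this flow has to do, as a stdlib-only sketch: pack each embedding as float32 bytes (the blob format typically stored in a hash field declared as a FLOAT32 vector), build a KNN query string with an optional hybrid pre-filter, and rank items the way a FLAT index does, by exact cosine distance. The field name `embedding`, the parameter name `vec`, and the `@category:{...}` tag filter are hypothetical. As far as I know, the query must be sent with DIALECT 2 and the packed query vector passed as the `vec` parameter; an HNSW index returns an approximation of the same ranking.

```python
import math
import struct


def to_float32_bytes(vec):
    """Pack an embedding as little-endian float32 bytes."""
    return struct.pack(f"<{len(vec)}f", *vec)


def from_float32_bytes(blob):
    """Unpack bytes produced by to_float32_bytes back into a list of floats."""
    return list(struct.unpack(f"<{len(blob) // 4}f", blob))


def knn_query(k, field="embedding", prefilter="*"):
    """Build a RediSearch KNN query string (field name is hypothetical).
    `prefilter` is where a hybrid filter would go; '*' means no filter."""
    base = prefilter if prefilter == "*" else f"({prefilter})"
    return f"{base}=>[KNN {k} @{field} $vec AS score]"


def cosine_distance(a, b):
    """1 - cosine similarity, the form in which Redis reports its COSINE metric."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return 1.0 - dot / (norm_a * norm_b)


def flat_knn(query, items, k):
    """Exact top-k item IDs by cosine distance, i.e. what a FLAT index
    computes; HNSW approximates this ranking."""
    ranked = sorted(items, key=lambda item_id: cosine_distance(query, items[item_id]))
    return ranked[:k]


if __name__ == "__main__":
    item_embeddings = {"item:1": [1.0, 0.0], "item:2": [0.7, 0.7], "item:3": [0.0, 1.0]}
    print(flat_knn([1.0, 0.1], item_embeddings, k=2))  # ['item:1', 'item:2']
    print(knn_query(10, prefilter="@category:{shoes}"))
```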

I have already built a Go framework around this design using a semantic embedding model instead of a recommender system, and it works really well.

One drawback of NVIDIA Merlin that I’m seeing so far is that inference on a single tower can be slow. I’m aiming for 10-100 ms per request, but when I encode a single user from my user_features, inference takes a couple of seconds. Is what I’m describing possible with Merlin, or am I doing something wrong? I’m running on a GPU, but even on a CPU I wouldn’t expect embedding a single user to take multiple seconds.
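
To make sure I’m comparing like with like, here is a generic timing sketch that separates warm-up calls (which can include one-time costs such as TensorFlow graph tracing or GPU initialization) from steady-state per-request latency. `fn` is a placeholder for whatever single-user encode call is being measured, and the warm-up and run counts are arbitrary choices, not Merlin settings.

```python
import statistics
import time


def measure_latency(fn, payload, warmup=3, runs=20):
    """Call fn(payload) `warmup` times to absorb one-time costs, then time
    `runs` steady-state calls and return (median_ms, max_ms)."""
    for _ in range(warmup):
        fn(payload)
    samples = []
    for _ in range(runs):
        start = time.perf_counter()
        fn(payload)
        samples.append((time.perf_counter() - start) * 1000.0)
    return statistics.median(samples), max(samples)


if __name__ == "__main__":
    # Stand-in for a real single-user encode call (hypothetical).
    def encode_one(features):
        return [0.0] * 64

    median_ms, worst_ms = measure_latency(encode_one, {"state": "CA", "gender": "F"})
    print(f"median {median_ms:.3f} ms, max {worst_ms:.3f} ms")
```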

Here is some sample code.

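```python
import merlin.models.tf as ml
import merlin.schema
import nvtabular as nvt
import tensorflow as tf
from merlin.io import Dataset
from merlin.models.utils.dataset import unique_rows_by_features

# schema, arch, batch_norm, topk, TARGET, valid_schema and user_features
# are defined earlier in the notebook (not shown).

## define model
user_schema = schema.select_by_tag(merlin.schema.Tags.USER)
user_inputs = ml.InputBlockV2(user_schema)
query = ml.Encoder(user_inputs, ml.MLPBlock(arch,
                                            normalization='batch_norm' if batch_norm else None))

item_schema = schema.select_by_tag(merlin.schema.Tags.ITEM)
item_inputs = ml.InputBlockV2(item_schema)
candidate = ml.Encoder(item_inputs, ml.MLPBlock(arch,
                                                normalization='batch_norm' if batch_norm else None))

model = ml.TwoTowerModelV2(query, candidate)

## compile model
opt = tf.optimizers.Adam(learning_rate=1e-2)
model.compile(optimizer=opt, metrics=[ml.RecallAt(k=topk)])

## prep data for training
train_transformed_new = nvt.Dataset('train', engine='parquet', schema=valid_schema)
valid_transformed_new = nvt.Dataset('valid', engine='parquet', schema=valid_schema)

# use the same train dataset defined above (was `train_transformed`)
candidate_features = unique_rows_by_features(train_transformed_new, merlin.schema.Tags.ITEM, merlin.schema.Tags.ITEM_ID)
train_eval_loader = ml.Loader(train_transformed_new, batch_size=2048).map(ml.ToTarget(valid_schema, TARGET))
valid_eval_loader = ml.Loader(valid_transformed_new, batch_size=2048).map(ml.ToTarget(valid_schema, TARGET))

## train model
model.fit(train_transformed_new, batch_size=4096, epochs=2)

## encode a single user with the query tower
query_tower = model.query_encoder

# The Dataset(...) and encode(...) arguments below are assumed;
# user_features holds the row(s) of the user(s) to encode.
queries = Dataset(user_features, schema=user_schema)
query_embeddings = query_tower.encode(queries, index=merlin.schema.Tags.USER_ID, batch_size=1)
```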
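
One thing I suspect is that `encode` is built for batch scoring: as far as I can tell it takes a `Dataset` and runs batch prediction over it, so a single user may pay that setup cost on every call. A way to test this is to turn one JSON request into a batch of one and call the tower directly. Below is a minimal, framework-free sketch of the request-parsing half, assuming the JSON already carries preprocessed values (e.g., categorified IDs from the NVTabular workflow). `json_to_batch` and the column names are hypothetical, and converting the lists to tensors before calling `query_tower(...)` is not shown.

```python
import json


def json_to_batch(payload, required_columns):
    """Turn one JSON record such as {"state": "CA", "gender": "F"} into a
    column -> [value] dict (a batch of size one). Missing features raise
    KeyError instead of silently producing a partial row; extra keys are ignored."""
    record = json.loads(payload) if isinstance(payload, str) else dict(payload)
    missing = [c for c in required_columns if c not in record]
    if missing:
        raise KeyError(f"missing user features: {missing}")
    return {c: [record[c]] for c in required_columns}


if __name__ == "__main__":
    request = '{"state": "CA", "gender": "F"}'
    print(json_to_batch(request, ["state", "gender"]))
    # {'state': ['CA'], 'gender': ['F']}
```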