Last active
February 3, 2026 16:44
-
-
Save jrosell/48f25557faad075b0bbc601138732c22 to your computer and use it in GitHub Desktop.
Create and search image embeddings using torch, mobilenet v3, rchroma and R
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| --- | |
| title: "Create and search image embeddings using torch, mobilenet v3, rchroma and R" | |
| --- | |
| Install and load torch and torchvision, then set up MobileNet V3 to compute image embeddings. | |
| ```{r} | |
# Install torch (pre-built binaries) and torchvision only when missing.
# See https://torch.mlverse.org/docs/articles/installation#pre-built
if (!rlang::is_installed("torch")) {
  # The pre-built binary download is large: raise the download timeout
  # (seconds).
  options(timeout = 90000)
  local({
    cran <- "https://cloud.r-project.org"
    # Build flavour of the pre-built binaries, as in the linked docs; check
    # the docs for the builds available for your hardware.
    kind <- "cu128"
    # Ask CRAN explicitly: a bare available.packages() fails when no mirror
    # is configured (e.g. when knitting non-interactively).
    torch_version <- available.packages(repos = cran)["torch", "Version"]
    cdn <- sprintf(
      "https://torch-cdn.mlverse.org/packages/%s/%s/", kind, torch_version
    )
    # Pass the repositories to this call only, instead of overwriting
    # options(repos) for the rest of the session. Keeping these helpers
    # inside local() also avoids a global `version` that would shadow
    # base::version.
    install.packages("torch", repos = c(torch = cdn, CRAN = cran))
  })
  torch::install_torch()
}
if (!rlang::is_installed("torchvision")) {
  install.packages("torchvision")
}
library(torch)
library(torchvision)

# Pretrained MobileNet V3 (small) backbone. For the larger variant use
# model_mobilenet_v3_large(pretrained = TRUE) instead.
model <- local({
  backbone <- model_mobilenet_v3_small(pretrained = TRUE)
  # Swap the classification head for a pass-through so the forward pass
  # returns the features that would otherwise have fed the classifier.
  backbone$classifier <- nn_identity()
  backbone
})
# Switch the module to inference mode (training = FALSE).
model$eval()
# Load one image and turn it into a normalized, batched torch tensor
# (a leading batch dimension of size 1, spatial size 224 x 224).
#
# img_url: passed straight to torchvision::base_loader(); presumably a local
#   image file path despite the name — TODO confirm whether URLs work.
# Returns: a torch tensor ready to be fed to `model`.
image_to_tensor <- function(img_url) {
  # Standard ImageNet per-channel statistics expected by the pretrained weights
  channel_mean <- c(0.485, 0.456, 0.406)
  channel_std <- c(0.229, 0.224, 0.225)

  pixels <- base_loader(img_url)
  chw <- transform_to_tensor(pixels)
  # Resize straight to 224 x 224 (aspect ratio is not preserved). An
  # alternative is transform_center_crop(chw, c(224, 224)) at this point.
  chw <- transform_resize(chw, c(224, 224))
  chw <- transform_normalize(chw, mean = channel_mean, std = channel_std)
  torch_unsqueeze(chw, 1)
}
# Run one image through the global `model` and return its embedding.
#
# img_url: image path, forwarded to image_to_tensor().
# Returns: the model output for the single-image batch, computed without
#   gradient tracking; presumably shaped (1, n_features), since callers
#   normalize along dim 2 and then squeeze — TODO confirm.
extract_embedding <- function(img_url) {
  batch <- image_to_tensor(img_url)
  # Inference only: no autograd graph is needed.
  with_no_grad(model(batch))
}
# Scale slices of a tensor to unit L2 norm.
#
# x:   a torch tensor; with the default dim = 2 each row is normalized
#      (e.g. each embedding in a (batch, features) tensor).
# dim: dimension along which the L2 norm is computed (default 2, as before).
# eps: lower bound applied to the norm so an all-zero vector yields zeros
#      instead of NaN from 0 / 0 (the same guard nnf_normalize() uses).
# Returns: a tensor with the same shape as `x`.
l2_normalize <- function(x, dim = 2, eps = 1e-12) {
  norms <- torch_norm(x, dim = dim, keepdim = TRUE)
  x / torch_clamp(norms, min = eps)
}
# Sanity check: cosine similarity between two images that, judging by the
# file names, are a picture and a crop of it.
img1 <- "burger.jpg"
img2 <- "burger_crop.jpg"
emb1 <- l2_normalize(extract_embedding(img1))
emb2 <- l2_normalize(extract_embedding(img2))
similarity <- torch_cosine_similarity(emb1, emb2, dim = 2)
as.numeric(similarity)
| ``` | |
| Install rchroma, connect to the Chroma vector database, and create a collection in it. | |
| ```{r} | |
if (!rlang::is_installed("rchroma")) install.packages("rchroma")
library(rchroma)

# Start the Chroma server via Docker and open a client connection to it.
chroma_docker_run()
client <- chroma_connect()

# Check the server responds and report its version.
heartbeat(client)
version(client)

# Reuse the "my_beans" collection when it exists; if fetching it fails,
# create it instead.
tryCatch(
  get_collection(client, "my_beans"),
  error = function(cnd) create_collection(client, "my_beans")
)
list_collections(client)
| ``` | |
| Download and unzip the beans dataset: | |
| ```{r} | |
# Download each split of the beans dataset and unzip it into "beans/".
# Each split must come from its own archive: previously beans_test.zip was
# fetched from validation.zip, so the "test" split duplicated validation.
# mode = "wb" because the archives are binary (text mode can corrupt them on
# Windows).
beans_base_url <-
  "https://huggingface.co/datasets/AI-Lab-Makerere/beans/resolve/main/data"
for (beans_split in c("train", "validation", "test")) {
  zip_file <- paste0("beans_", beans_split, ".zip")
  download.file(
    paste0(beans_base_url, "/", beans_split, ".zip"),
    zip_file,
    mode = "wb"
  )
  unzip(zip_file, exdir = "beans")
}
| ``` | |
| Compute an embedding for every image in beans/train/healthy: | |
| ```{r} | |
library(fs)
library(purrr)
library(stringr)

# Paths of the healthy-leaf training images. The pattern is escaped and
# anchored so only file names ending in ".jpg" match (an unescaped ".jpg"
# would match any character followed by "jpg" anywhere in the path).
docs <- path("beans", "train", "healthy") |>
  dir_ls() |>
  as.character() |>
  str_subset("\\.jpg$")
ids <- paste0("doc", seq_along(docs))

# One L2-normalized embedding per image, as a plain numeric vector.
embeddings <- map(
  docs,
  \(x) {
    x |>
      extract_embedding() |>
      l2_normalize() |>
      torch_squeeze() |> # drop the size-1 batch dimension
      as_array()
  },
  .progress = TRUE
)

# Cache the first embedding, kept as a length-1 list, so a later query can
# reuse it without recomputing.
embeddings[1] |> readr::write_rds("beans/embed_0.rds")
| ``` | |
| Add the image paths (as documents) and their embeddings to the Chroma collection: | |
| ```{r} | |
# Store every image in the "my_beans" collection: each document's text is
# the image path, paired with its id and precomputed embedding.
add_documents(
  client,
  "my_beans",
  ids = ids,
  embeddings = embeddings,
  documents = docs
)
| ``` | |
| Query the collection and retrieve the 4 most similar images (the query image is itself in the collection, so it should normally come back first): | |
| ```{r} | |
# Query the collection with the cached embedding of the first image and show
# the closest matches side by side.
n_similar <- 4L
first_embed <- readr::read_rds("beans/embed_0.rds")

# Query similar documents using embeddings
results <- query(
  client,
  "my_beans",
  query_embeddings = first_embed,
  n_results = n_similar
)
results$documents[[1]]

if (!rlang::is_installed("magick")) {
  install.packages("magick")
}
library(magick)

# The documents are the stored image paths. Read however many hits came back
# (fewer than n_similar when the collection is small) instead of indexing a
# fixed four, which fails on missing (NA) paths. image_read() accepts a
# vector of paths and returns one multi-frame image.
img_paths <- unlist(results$documents[[1]])
if (length(img_paths) == 0) {
  stop("The query returned no documents.", call. = FALSE)
}
comparison <- image_append(image_read(img_paths))
plot(comparison)
| ``` |
Author
jrosell
commented
Feb 3, 2026
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment