---
title: "Create and search image embeddings using torch, mobilenet v3, rchroma and R"
---
Prepare torch and torchvision, and set up MobileNet V3 as a feature extractor for the embeddings.
```{r}
if (!rlang::is_installed("torch")) {
# https://torch.mlverse.org/docs/articles/installation#pre-built
options(timeout = 90000)
kind <- "cu128"
version <- available.packages()["torch","Version"]
options(repos = c(
torch = sprintf("https://torch-cdn.mlverse.org/packages/%s/%s/", kind, version),
CRAN = "https://cloud.r-project.org"
))
install.packages("torch")
torch::install_torch()
}
if (!rlang::is_installed("torchvision")) {
install.packages("torchvision")
}
library(torch)
library(torchvision)
model <- model_mobilenet_v3_small(pretrained = TRUE)
# model <- model_mobilenet_v3_large(pretrained = TRUE)
model$classifier <- nn_identity()
model$eval()
image_to_tensor <- function(img_url) {
batch <- base_loader(img_url) |>
transform_to_tensor() |>
transform_resize(c(224, 224)) |>
# transform_center_crop(c(224, 224)) |>
transform_normalize(
mean = c(0.485, 0.456, 0.406),
std = c(0.229, 0.224, 0.225)
) |>
torch_unsqueeze(1)
batch
}
extract_embedding <- function(img_url) {
img <- image_to_tensor(img_url)
with_no_grad({
emb <- model(img)
})
emb
}
l2_normalize <- function(x) {
x / torch_norm(x, dim = 2, keepdim = TRUE)
}
img1 <- "burger.jpg"
img2 <- "burger_crop.jpg"
emb1 <- extract_embedding(img1) |> l2_normalize()
emb2 <- extract_embedding(img2) |> l2_normalize()
similarity <- torch_cosine_similarity(emb1, emb2, dim = 2)
as.numeric(similarity)
```
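As a quick sanity check you can inspect the embedding shape and confirm that an image compared with itself scores 1 (a minimal sketch; the 576-feature size is what the small variant should produce once the classifier is removed, the large one should give 960):
```{r}
# The model returns a 1 x N feature matrix; with nn_identity() as the
# classifier, MobileNet V3 small should give 576 features (large gives 960).
dim(emb1)

# Cosine similarity of an embedding with itself is (numerically) 1;
# unrelated images typically score noticeably lower.
as.numeric(torch_cosine_similarity(emb1, emb1, dim = 2))
```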
Install rchroma, connect to the Chroma vector database and create a collection.
```{r}
if (!rlang::is_installed("rchroma")) {
install.packages("rchroma")
}
library(rchroma)
chroma_docker_run()
client <- chroma_connect()
heartbeat(client)
version(client)
# Re-use the collection if it already exists, otherwise create it.
tryCatch(get_collection(client, "my_beans"), error = \(x) {
  create_collection(client, "my_beans")
})
list_collections(client)
```
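The tryCatch() above makes the script safe to re-run: it re-uses the collection when it already exists. If you prefer, the pattern can be wrapped in a small helper (a sketch using only the rchroma calls shown above; get_or_create_collection() is not part of the package):
```{r}
# Hypothetical helper (not part of rchroma): return the collection if it
# already exists, otherwise create it, so re-running the script never fails.
get_or_create_collection <- function(client, name) {
  tryCatch(
    get_collection(client, name),
    error = \(e) create_collection(client, name)
  )
}
get_or_create_collection(client, "my_beans")
```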
Get the data:
```{r}
download.file("https://huggingface.co/datasets/AI-Lab-Makerere/beans/resolve/main/data/train.zip", "beans_train.zip")
unzip("beans_train.zip", exdir = "beans")
download.file("https://huggingface.co/datasets/AI-Lab-Makerere/beans/resolve/main/data/validation.zip", "beans_validation.zip")
unzip("beans_validation.zip", exdir = "beans")
download.file("https://huggingface.co/datasets/AI-Lab-Makerere/beans/resolve/main/data/validation.zip", "beans_test.zip")
unzip("beans_test.zip", exdir = "beans")
```
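Before computing embeddings it is worth confirming that the zips unpacked into the expected class folders (a minimal sketch; it only assumes the images use the .jpg extension):
```{r}
library(fs)
# Count the .jpg files per folder to confirm the downloads and unzips worked.
dir_ls("beans", recurse = TRUE, glob = "*.jpg") |>
  path_dir() |>
  table()
```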
Build the document list (the image paths) and compute an embedding for each one:
```{r}
library(fs)
library(purrr)
library(stringr)
docs <- path("beans", "train", "healthy") |>
  dir_ls() |>
  as.character() |>
  keep(\(x) str_detect(x, "\\.jpg$"))

ids <- paste0("doc", seq_along(docs))

embeddings <-
  map(.progress = TRUE, docs, \(x) {
    x |>
      extract_embedding() |>
      l2_normalize() |>
      torch_squeeze() |> # Removes the batch dimension of size 1
      as_array()
  })

# Keep the first embedding (as a one-element list) on disk to use as a query later.
embeddings[1] |> readr::write_rds("beans/embed_0.rds")
```
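At this point embeddings should be a list of plain numeric vectors, one per image and all of the same length (a quick sanity check; 576 is the expected size for the small MobileNet V3 with its classifier removed):
```{r}
# One embedding per document, all of identical length.
length(embeddings)
lengths(embeddings) |> unique()
```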
Send the documents and their embeddings to Chroma:
```{r}
add_documents(
client,
"my_beans",
documents = docs,
ids = ids,
embeddings = embeddings
)
```
Query the collection and retrieve the 4 most similar images:
```{r}
first_embed <- readr::read_rds("beans/embed_0.rds")
# Query similar documents using embeddings
results <- query(
  client,
  "my_beans",
  query_embeddings = first_embed,
  n_results = 4
)
results$documents[[1]]
if (!rlang::is_installed("magick")) {
install.packages("magick")
}
library(magick)
img_paths <- unlist(results$documents[[1]])
img1 <- image_read(img_paths[1])
img2 <- image_read(img_paths[2])
img3 <- image_read(img_paths[3])
img4 <- image_read(img_paths[4])
# Append the four nearest neighbours side by side for visual comparison.
comparison <- image_append(c(img1, img2, img3, img4))
plot(comparison)
```
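The same workflow works for images that are not in the collection, for example one from the validation split, re-using the helpers defined above (a sketch; it assumes the validation zip unpacks to beans/validation/healthy like the training split does):
```{r}
# Embed a new image on the fly and look up its nearest neighbours in Chroma.
new_img <- dir_ls(path("beans", "validation", "healthy"), glob = "*.jpg")[1]

new_embed <- new_img |>
  extract_embedding() |>
  l2_normalize() |>
  torch_squeeze() |>
  as_array()

query(
  client,
  "my_beans",
  query_embeddings = list(new_embed),
  n_results = 4
)
```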