Compare the performance of different metrics on unbalanced datasets. The strength of the batch effect is kept constant and only the batch proportions are changed. We test this at different overall batch effect strengths (high, intermediate, and no batch effect).
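
A purely hypothetical sketch (not part of this workflow) of what changing only the batch proportions means: at a fixed batch effect strength, an unbalance level u shifts the two batch proportions away from 50/50.

# Hypothetical illustration only: an unbalance level u maps to batch proportions
# 0.5 + u/2 and 0.5 - u/2, while the batch effect strength itself is unchanged.
unbalance_levels <- seq(0, 0.8, by = 0.2)
data.frame(
  unbalanced   = unbalance_levels,
  prop_batch_1 = 0.5 + unbalance_levels / 2,
  prop_batch_2 = 0.5 - unbalance_levels / 2
)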

suppressPackageStartupMessages({
  library(CellMixS)
  library(purrr)
  library(tidyr)
  library(dplyr)
  library(gridExtra)
  library(scran)
  library(cowplot)
  library(jcolors)
  library(ggpubr)
  library(circlize)
  library(viridis)
  library(ComplexHeatmap)
  library(stringr)
  library(magrittr)
  library(colorspace)
  library(corrplot)
  library(RColorBrewer)
  library(hrbrthemes)
  library(ggforce)
})

options(bitmapType='cairo')

Dataset and metrics

# Simulation names are passed via params as a comma-separated string.
sce_whole_name <- unlist(c(strsplit(params$sce_name, ",")))
# Strip the unbalance prefix ("un_<level>_sim_") and the simulation suffix
# ("__<run>_sce...") to recover the unique dataset names.
sce_name <- gsub("^un_[0-9].*_sim_", "", sce_whole_name) %>% 
    gsub("__[0-9]_sce.*", "", .) %>% unique()

metrics <- unlist(c(strsplit(params$metrics, ",")))

# All simulation names belonging to each dataset
sce_sim_list <- lapply(sce_name, function(dataset){
    sce_whole_name[grepl(paste0(dataset, "_"), sce_whole_name)]
}) %>% set_names(sce_name)

# Unbalance levels per dataset (numeric part of the "un_" prefix)
sim_list <- lapply(sce_name, function(dataset){
    sim_vec <- sce_whole_name[grepl(dataset, sce_whole_name)]
    gsub("^un_", "", sim_vec) %>% gsub("_sim.*", "", .) %>% 
        as.numeric()
}) %>% set_names(sce_name)


# Paths to the simulated SCE objects, per dataset
sce_list <- lapply(sce_name, function(dataset){
  paste0(params$sce, sce_sim_list[[dataset]], "_", params$last, "_sce.rds") %>%
             set_names(sce_sim_list[[dataset]])
}) %>% set_names(sce_name)

# Output paths
out_path_cor <- params$out_cor
out_path_res <- params$out_res
out_path_fig <- params$fig_res

# Colour palettes (names removed so colours are assigned by position)
cols_data <- c(c(jcolors('pal6'), jcolors('pal8'))[c(1,8,14,5,2:4,6,7,9:13,15:20)], jcolors('pal4'))
names(cols_data) <- NULL

cols <- c(c(jcolors('pal6'), jcolors('pal8'), jcolors('pal7'))[c(12,18,1,25,27,2,4,7,3,6,8,14,9,20)], jcolors('pal4'))
names(cols) <- NULL
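
The regular expressions above assume a simulation naming scheme of the form un_<level>_sim_<dataset>__<run>_sce...; this scheme is inferred from the patterns themselves and is illustrated here on a made-up name:

# Made-up example name (inferred scheme, not taken from the pipeline output)
demo_name <- "un_0.3_sim_pbmc__1_sce"
gsub("^un_[0-9].*_sim_", "", demo_name) %>% gsub("__[0-9]_sce.*", "", .)  # dataset: "pbmc"
gsub("^un_", "", demo_name) %>% gsub("_sim.*", "", .) %>% as.numeric()    # unbalance level: 0.3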

Plot deviation from mean per metric

Metrics deviation

# Get the lowest level of unbalance at which a metric deviates by more than 5% 
# from its value on the balanced dataset.
limit_tab <- lapply(names(scale_means), function(dataset){
    mean_res <- scale_means[[dataset]]
    dev_lim <- function(x){
        dev_all <- which(abs(x) > 0.05)
        # smallest unbalance level with an absolute deviation above 5%
        min(mean_res$unbalanced[dev_all])
    }
    apply(mean_res, 2, dev_lim)
}) %>% set_names(names(scale_means)) %>% bind_cols() %>% 
  mutate(metric = colnames(scale_means[[1]]))

# Metrics that never deviate by more than 5% return Inf (min of an empty set);
# set these to the maximum unbalance level of 1.
limit_tab[limit_tab == Inf] <- 1 

lim_long <- limit_tab %>% filter(!metric %in% "unbalanced") %>% 
    pivot_longer(-metric, names_to = "dataset", values_to = "unbalanced_limit")

## order by metrics type
lim_long$metric <- factor(lim_long$metric,levels = metric_order)

p <- ggplot(lim_long, aes(x = metric , y = unbalanced_limit, fill = dataset)) + 
    geom_boxplot(fill = cols[1:length(levels(as.factor(lim_long$metric)))], alpha = 0.5) + 
    geom_dotplot(binaxis='y', stackdir='center', dotsize=0.5) +
    scale_fill_manual(values=cols_data) + 
    theme_ipsum(base_family = 'Helvetica') +
    theme(axis.text.x = element_text(face="bold", size=10, angle=45))

p
## `stat_bindot()` using `bins = 30`. Pick better value with `binwidth`.

saveRDS(p, paste0(out_path_fig, "_limits.rds"))

#save mean results
mean_unb_lim <- lim_long %>% group_by(metric) %>% 
  summarise(unb_limit = mean(unbalanced_limit),
            unb_limit_sd = sd(unbalanced_limit))

saveRDS(mean_unb_lim, out_path_res)
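
A toy example of the 5% rule with made-up numbers: the limit reported for a metric is the smallest unbalance level at which its scaled deviation from the balanced value exceeds 0.05.

# Toy data (made-up values): scaled deviation of one metric at increasing
# unbalance levels; the limit is the first level exceeding 5%.
toy <- data.frame(unbalanced = c(0, 0.25, 0.5, 0.75, 1),
                  metric_x   = c(0.00, 0.02, 0.08, 0.20, 0.35))
min(toy$unbalanced[abs(toy$metric_x) > 0.05])  # 0.5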

As lollipop plot

# Harmonise metric names for plotting
lim_long$metric <- recode(lim_long$metric, cms.kmin = "cms_kmin", cms.bmin = "cms_bmin",
       cms.batch_id = "cms_default", graph.connectivity = "graph", kbet = "kBet")

p <- ggplot(lim_long) +
  geom_segment(aes(x = dataset, xend = dataset, y = 0, yend = unbalanced_limit), color = "black") +
  geom_point(aes(x = dataset, y = unbalanced_limit, color = metric, shape = dataset), size = 4) +
  theme_ipsum(base_family = 'Helvetica') +
  theme(
    legend.position = "top",
    panel.border = element_blank(),
    panel.spacing = unit(0.1, "lines")
  ) +
  scale_colour_manual(values = cols[1:length(levels(as.factor(lim_long$metric)))], guide = "none") +
  xlab("") +
  ylab("Unbalanced limit") +
  facet_wrap(~metric, ncol = length(levels(as.factor(lim_long$metric))), strip.position = "bottom")

p

saveRDS(p, paste0(out_path_fig, "_limits_lolli.rds"))

Correlation between metrics

# Spearman correlation between metrics (based on mean scores), per dataset
cor_met <- lapply(names(mean_tab), function(dataset){
    mean_res <- mean_tab[[dataset]] %>% select(-unbalanced)
    cor(mean_res, use = "complete.obs", method = "spearman") 
})
## Warning in cor(mean_res, use = "complete.obs", method = "spearman"): the
## standard deviation is zero
# Element-wise mean of the per-dataset correlation matrices
cor_met_mean <- Reduce(`+`, cor_met)/length(cor_met)

corrplot(cor_met_mean, 
         type="upper", 
         order="original",
         hclust.method = "complete",
         col=brewer.pal(n=8, name="PuOr"),
         addgrid.col = NA,
         addCoef.col = "black",
         diag = FALSE)

#save correlation
saveRDS(cor_met_mean, out_path_cor)
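
The zero standard deviation warning above means that at least one metric is constant within a dataset, so its Spearman correlation is undefined and appears as NA. A possible guard, not used in this workflow, is to drop zero-variance metrics per dataset before computing the correlations:

# Sketch of an alternative (not used above): remove metrics with zero variance
# in a given dataset before computing Spearman correlations.
cor_met_safe <- lapply(names(mean_tab), function(dataset){
    mean_res <- mean_tab[[dataset]] %>% select(-unbalanced)
    keep <- sapply(mean_res, function(x) sd(x, na.rm = TRUE) > 0)
    cor(mean_res[, keep, drop = FALSE], use = "complete.obs", method = "spearman")
})

Note that the resulting matrices can then differ in dimension across datasets, so the element-wise averaging above would need to be restricted to the shared metrics.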