Compare the performance of different metrics on unbalanced datasets. The strength of the batch effect is held constant and only the batch proportions are changed. We test this at different overall batch strengths (strong, moderate, and no batch effect).
suppressPackageStartupMessages({
library(CellMixS)
library(purrr)
library(tidyr)
library(dplyr)
library(gridExtra)
library(scran)
library(cowplot)
library(jcolors)
library(ggpubr)
library(circlize)
library(viridis)
library(ComplexHeatmap)
library(stringr)
library(magrittr)
library(colorspace)
library(corrplot)
library(RColorBrewer)
library(hrbrthemes)
library(corrplot)
library(RColorBrewer)
library(ggforce)
})
# Render bitmap graphics through cairo.
options(bitmapType='cairo')

# Full simulation names, passed as one comma-separated string.
sce_whole_name <- unlist(c(strsplit(params$sce_name, ",")))

# Dataset names: drop the "un_<level>_sim_" prefix and any "__<n>_sce..."
# suffix from the simulation names, then de-duplicate.
sce_name <- gsub("^un_[0-9].*_sim_", "", sce_whole_name) %>%
  gsub("__[0-9]_sce.*", "", .) %>% unique()

# Metric names, passed as one comma-separated string.
metrics <- unlist(c(strsplit(params$metrics, ",")))

# Simulation names per dataset. The trailing "_" keeps a dataset such as
# "x_0" from also matching the simulations of "x_0.3".
sce_sim_list <- lapply(sce_name, function(dataset){
  sce_whole_name[grepl(paste0(dataset, "_"), sce_whole_name)]
}) %>% set_names(sce_name)

# Numeric unbalance level per simulation, parsed from the text between
# "un_" and "_sim". Derived from `sce_sim_list` so both lists always refer to
# the same simulations (the former bare `grepl(dataset, ...)` also picked up
# simulations of datasets whose name merely starts with `dataset`).
sim_list <- lapply(sce_sim_list, function(sim_vec){
  gsub("^un_", "", sim_vec) %>% gsub("_sim.*", "", .) %>%
    as.numeric()
})

# File path of each simulation's sce object, named by simulation name.
sce_list <- lapply(sce_name, function(dataset){
  paste0(params$sce, sce_sim_list[[dataset]], "_", params$last, "_sce.rds") %>%
    set_names(sce_sim_list[[dataset]])
}) %>% set_names(sce_name)
# Output locations taken from the report parameters.
out_path_cor <- params$out_cor
out_path_res <- params$out_res
out_path_fig <- params$fig_res

# Unnamed colour vector used for datasets.
cols_data <- unname(c(
  c(jcolors('pal6'), jcolors('pal8'))[c(1, 8, 14, 5, 2:4, 6, 7, 9:13, 15:20)],
  jcolors('pal4')
))

# Unnamed colour vector used for metrics (named by metric further below).
cols <- unname(c(
  c(jcolors('pal6'), jcolors('pal8'), jcolors('pal7'))[
    c(12, 18, 1, 25, 27, 2, 4, 7, 3, 6, 8, 14, 9, 20)
  ],
  jcolors('pal4')
))

# Metric names as they are matched against colData columns: underscores become
# dots and the "default" suffix is removed.
metric_vec <- gsub("default", "", gsub("_", ".", metrics))
# Collect the per-cell metric scores of all simulations of one dataset.
#
# dataset: dataset name; must be a name of the globals `sce_sim_list` and
#   `sce_list`.
# Reads the globals `sce_sim_list`, `sce_list` and `metric_vec`.
# Returns a list named by metric. Each element is a wide table with one row
# per cell and the metric's values spread over the simulations.
dataset_sim_tab <- function(dataset){
  sim_list <- lapply(sce_sim_list[[dataset]], function(sim_name){
    sce <- readRDS(sce_list[[dataset]][[sim_name]])
    metric_nam <- grep(paste(metric_vec, collapse="|"),
                       names(colData(sce)), value=TRUE)
    # Logical filter: `x[-grep(...)]` would return an empty vector whenever no
    # column contains "smooth".
    metric_nam <- metric_nam[!grepl("smooth", metric_nam)]
    # drop = FALSE keeps the column name when only one metric column is left.
    res <- as.data.frame(colData(sce)[, metric_nam, drop = FALSE])
    res$cell <- colnames(sce)
    res
  }) %>% set_names(sce_sim_list[[dataset]])
  # Stack all simulations; the simulation name ends up in column "dataset".
  sim_tab <- sim_list %>% bind_rows(.id = "dataset")
  # Local copy only: "cms." is looked up under its column prefix.
  metric_vec[metric_vec %in% "cms."] <- "cms.batch_id"
  sim_wide_list <- lapply(metric_vec, function(met){
    met_col <- colnames(sim_tab)[grepl(paste0("^", met), colnames(sim_tab))]
    metr_long <- sim_tab[, c("dataset", "cell", met_col)]
    metr_long$dataset <- as.factor(metr_long$dataset)
    pivot_wider(metr_long, names_from = dataset,
                values_from = all_of(met_col))
  }) %>% set_names(metric_vec)
  sim_wide_list
}
# Per dataset: list of wide per-cell score tables, one per metric.
res_by_met <- set_names(lapply(sce_name, dataset_sim_tab), sce_name)

#### ---------- Order by metric type ----------------------------------#######
# The patterns below are maintained by hand and need to be extended whenever a
# new metric is added.
metric_names <- names(res_by_met[[1]])
find_metric <- function(pattern) grep(pattern, metric_names)

cms_ind <- find_metric("cms")
lisi_ind <- find_metric("isi")
ent_ind <- find_metric("entropy")
mm_ind <- find_metric("mm")
asw_ind <- find_metric("sw")
kbet_ind <- find_metric("kbet")
graph_ind <- find_metric("graph")
pcr_ind <- find_metric("pcr")

# Display order of the metrics, grouped by metric type.
metric_order <- metric_names[c(
  cms_ind, lisi_ind, ent_ind, mm_ind,
  kbet_ind, asw_ind, graph_ind, pcr_ind
)]

# Attach one colour to each metric and keep only those colours.
names(cols) <- metric_order
cols <- cols[metric_order]
####--------------------------------------------------------------------########
# Mean score of every metric per simulation, one table per dataset. Each table
# has one column per metric plus `unbalanced`, the level parsed from the
# simulation names.
mean_tab <- lapply(sce_name, function(dataset){
  score_tabs <- res_by_met[[dataset]]
  # "cms." is stored under the name "cms.batch_id".
  met_names <- replace(metric_vec, metric_vec %in% "cms.", "cms.batch_id")

  # Column means over all cells, ignoring missing scores.
  mean_list <- lapply(met_names, function(met){
    score_mat <- as.matrix(select(score_tabs[[met]], -cell))
    colMeans(score_mat, na.rm = TRUE)
  })
  met_mean <- set_colnames(bind_cols(mean_list), met_names)

  # Level = first "_<number...>_" group in each column name; columns without
  # such a group (e.g. "cell") are dropped.
  level_chr <- str_match(colnames(score_tabs[[1]]), "_([0-9].*?)_")[, 2]
  met_mean$unbalanced <- as.numeric(level_chr[!is.na(level_chr)])
  met_mean
}) %>% set_names(sce_name)
## scale means
# Scale scores by centering them on the 0% unbalanced (fully balanced) result
# and dividing by each metric's total range.
scale_means <- lapply(names(mean_tab), function(dataset){
  mean_res <- mean_tab[[dataset]]
  metric_cols <- setdiff(colnames(mean_res), "unbalanced")

  # Number of batches of the dataset being scaled, taken from its first
  # simulation (previously always read from the first dataset).
  sce <- readRDS(sce_list[[dataset]][[1]])
  n_bids <- nlevels(as.factor(sce$batch_id))

  # Scaling vector holding each method's range: 1 by default, 200 for "mm"
  # metrics and (number of batches - 1) for "isi" metrics.
  scal_vec <- rep(1, length(metric_cols))
  scal_vec[grep("mm", metric_cols)] <- 200
  scal_vec[grep("isi", metric_cols)] <- n_bids - 1

  # Balanced reference: exactly one row with unbalanced == 0 is required.
  max_res <- mean_res[mean_res$unbalanced == 0, metric_cols]
  if (nrow(max_res) != 1) {
    stop("Expected exactly one simulation with unbalanced == 0 for dataset '",
         dataset, "', found ", nrow(max_res), ".", call. = FALSE)
  }

  # scale() expects a plain numeric centering vector, not a one-row table.
  scale_res <- mean_res %>% select(-unbalanced) %>%
    scale(., center = unlist(max_res, use.names = FALSE), scale = scal_vec) %>%
    as.data.frame()
  scale_res$unbalanced <- mean_res$unbalanced
  scale_res
}) %>% set_names(names(mean_tab))
# Replace dataset names by descriptive names according to their suffix
# ("_0", "_0.3", "_1"). Names without one of these suffixes are kept as is.
ind_no_batch <- grep("_0$", names(mean_tab))
# The dot is escaped so that only a literal "_0.3" suffix matches.
ind_int_batch <- grep("_0\\.3$", names(mean_tab))
ind_strong_batch <- grep("_1$", names(mean_tab))
d_names <- names(mean_tab)
d_names[ind_no_batch] <- "no_batch_effect"
d_names[ind_int_batch] <- "moderate_batch_effect"
d_names[ind_strong_batch] <- "separated_batches"
names(scale_means) <- d_names
# Line plot of all scaled metric means against the unbalance level.
#
# dataset: name of an element of the global `scale_means`.
# Returns a ggplot object with one coloured line (and points) per metric.
plot_trends <- function(dataset){
  trend_long <- pivot_longer(
    scale_means[[dataset]],
    -unbalanced,
    names_to = "metric",
    values_to = "scaled_result"
  )
  # Fix the metric order (and thereby the colour assignment).
  trend_long$metric <- factor(trend_long$metric, levels = metric_order)

  trend_plot <- ggplot(trend_long, aes(x = unbalanced, y = scaled_result)) +
    geom_line(aes(color = metric)) +
    geom_point(aes(color = metric)) +
    theme_bw() +
    scale_color_manual(values = cols) +
    ggtitle(dataset)
  trend_plot
}
# Chunk template: a heading plus a code chunk calling plot_trends() for the
# dataset substituted for {{nm}}.
template <- c(
  "#### {{nm}}\n",
  "```{r scaling{{nm}}, echo = FALSE}\n",
  "plot_trends('{{nm}}')\n",
  "```\n",
  "\n"
)

# One expanded chunk per scaled dataset.
plots <- lapply(names(scale_means), function(nm) {
  knitr::knit_expand(text = template, nm = nm)
})
Plot trends per metric
# Get lowest level of unbalance that deviates more than 5% from the balanced value.
# Result: one column per dataset, one row per column of the scaled tables
# (including "unbalanced"), plus a `metric` column naming the row.
limit_tab <- lapply(names(scale_means), function(dataset){
  mean_res <- scale_means[[dataset]]
  dev_lim <- function(x){
    dev_all <- which(abs(x) > 0.05)
    # No deviation at all: report Inf explicitly instead of calling min() on
    # an empty vector, which returned Inf with a warning per column.
    if (length(dev_all) == 0L) {
      return(Inf)
    }
    min(mean_res$unbalanced[dev_all])
  }
  # Column-wise over the data frame; no matrix coercion as with apply().
  vapply(mean_res, dev_lim, numeric(1))
}) %>% set_names(names(scale_means)) %>% bind_cols() %>%
  mutate(metric = colnames(scale_means[[1]]))
# Metrics that never deviate get the maximal limit of 1.
limit_tab[limit_tab == Inf] <- 1
# Plot the scaled trend of each metric in its own facet, with all metrics drawn
# in grey in the background and the metric's unbalanced limit marked.
#
# dataset: name of an element of the global `scale_means`; also used as a
#   column name of the global `limit_tab`.
# Reads the globals `scale_means`, `limit_tab`, `metric_order` and `cols`.
# Returns a ggplot object.
sep_trends <- function(dataset){
  mean_res <- scale_means[[dataset]]
  mean_long <- mean_res %>% pivot_longer(-unbalanced, names_to = "metric",
                                         values_to = "scaled_result")
  # Copy of the metric column that remains after `metric` is dropped for the
  # grey background layer, so that layer is not split by the facets.
  mean_long <- mean_long %>% mutate(metric2 = metric)
  # One colour per metric column, each repeated once per row of `mean_res`.
  # seq_len() replaces `1: ncol(mean_res)-1`, which parsed as
  # `(1:ncol(mean_res)) - 1` and only worked because index 0 is ignored.
  cols_rep <- rep(cols[seq_len(ncol(mean_res) - 1)], each = nrow(mean_res))
  mean_long$metric <- factor(mean_long$metric,levels = metric_order)
  # Limits restricted to the plotted metrics, in plotting order.
  lim_ordered <- limit_tab %>% filter(metric %in% metric_order) %>%
    arrange(match(metric, metric_order))
  # Scaled score of metric `met` at its unbalanced limit.
  find_scal <- function(met){
    lim <- lim_ordered[lim_ordered$metric %in% met, dataset]
    mean_res[mean_res$unbalanced == lim[[1]], met]
  }
  mean_lim <- metric_order %>% map(find_scal) %>% unlist()
  # Position and label of the limit marker of each metric.
  unb_limit <- data.frame("unbalanced" = lim_ordered[,dataset][[1]],
                          "metric" = metric_order,
                          "scaled_score" = mean_lim,
                          "label" = paste0("limit = ", lim_ordered[,dataset][[1]]))
  p <- ggplot(mean_long, aes(x = unbalanced, y = scaled_result)) +
    geom_line(data=mean_long %>% dplyr::select(-metric), aes(group=metric2),
              color="grey", size=0.5, alpha=0.5) +
    geom_line( aes(color=metric), color=cols_rep, size=1.2 ) +
    scale_color_viridis(discrete = TRUE) +
    theme_ipsum(base_family = 'Helvetica') +
    theme(
      legend.position="none",
      plot.title = element_text(size=14),
      panel.grid.minor = element_blank()
    ) +
    geom_mark_circle(data = unb_limit, aes(label = label, x = unbalanced, y = scaled_score),
                     label.colour = "black", inherit.aes = FALSE,
                     label.fontsize = 20,
                     size = 1,
                     expand = unit(1, "mm"),
                     label.minwidth = unit(20, "mm")) +
    ggtitle(dataset) +
    facet_wrap(~metric)
  p
}
# Chunk template: a heading plus a code chunk that draws sep_trends() for the
# dataset substituted for {{nm}} and stores the plot as an rds file.
template_sep <- c(
  "#### {{nm}}\n",
  "```{r scaling sep {{nm}}, echo = FALSE, fig.width = 8, fig.height = 7}\n",
  "p <- sep_trends('{{nm}}')\n",
  "p",
  "saveRDS(p, paste0(out_path_fig, '_{{nm}}.rds'))",
  "```\n",
  "\n"
)

# One expanded chunk per scaled dataset.
plots_sep <- lapply(names(scale_means), function(nm) {
  knitr::knit_expand(text = template_sep, nm = nm)
})
# Get lowest level of unbalance that deviates more than 5% from the balanced value.
# Result: one column per dataset, one row per column of the scaled tables
# (including "unbalanced"), plus a `metric` column naming the row.
limit_tab <- lapply(names(scale_means), function(dataset){
  mean_res <- scale_means[[dataset]]
  dev_lim <- function(x){
    dev_all <- which(abs(x) > 0.05)
    # No deviation at all: report Inf explicitly instead of calling min() on
    # an empty vector, which returns Inf with a warning.
    if (length(dev_all) == 0L) {
      return(Inf)
    }
    min(mean_res$unbalanced[dev_all])
  }
  # Column-wise over the data frame; no matrix coercion as with apply().
  vapply(mean_res, dev_lim, numeric(1))
}) %>% set_names(names(scale_means)) %>% bind_cols() %>%
  mutate(metric = colnames(scale_means[[1]]))
# Metrics that never deviate get the maximal limit of 1.
limit_tab[limit_tab == Inf] <- 1
# Long format of the limits: one row per metric and dataset. The "unbalanced"
# row is not a metric and is removed.
lim_long <- limit_tab %>% filter(!metric %in% "unbalanced") %>%
  pivot_longer(-metric, names_to = "dataset", values_to = "unbalanced_limit")
## order by metrics type
lim_long$metric <- factor(lim_long$metric,levels = metric_order)
# Boxplot of the limits per metric with the single datasets as dots. One fill
# colour per metric level; seq_len(nlevels()) replaces the fragile
# `1:length(levels(as.factor(...)))`.
p <- ggplot(lim_long, aes(x = metric , y = unbalanced_limit, fill = dataset)) +
  geom_boxplot(fill = cols[seq_len(nlevels(lim_long$metric))], alpha = 0.5) +
  geom_dotplot(binaxis='y', stackdir='center', dotsize=0.5) +
  scale_fill_manual(values=cols_data) +
  theme_ipsum(base_family = 'Helvetica') +
  theme(axis.text.x = element_text(face="bold", size=10, angle=45))
p
# Printing reports: `stat_bindot()` using `bins = 30`. Pick better value with `binwidth`.
saveRDS(p, paste0(out_path_fig, "_limits.rds"))
# Save mean and standard deviation of the limit per metric across datasets.
mean_unb_lim <- lim_long %>% group_by(metric) %>%
  summarise(unb_limit = mean(unbalanced_limit),
            unb_limit_sd = sd(unbalanced_limit))
saveRDS(mean_unb_lim, out_path_res)
The same unbalanced limits shown as a lollipop plot
# Rename metric levels to their display names.
lim_long$metric <- recode(lim_long$metric, cms.kmin = "cms_kmin", cms.bmin = "cms_bmin",
                          cms.batch_id = "cms_default", graph.connectivity = "graph", kbet = "kBet")
# Colours are matched to metrics by position from here on, not by name.
names(cols) <- NULL
# Lollipop plot: one facet per metric, one stick per dataset.
# NOTE(review): the y axis label "Value of Y" looks like a placeholder; the
# axis shows `unbalanced_limit`. Confirm before changing the label.
p <- ggplot(lim_long) +
  geom_segment( aes(x=dataset, xend=dataset, y=0, yend=unbalanced_limit), color="black") +
  geom_point( aes(x=dataset, y=unbalanced_limit, color=metric, shape = dataset), size=4 ) +
  #coord_flip()+
  theme_ipsum(base_family = 'Helvetica') +
  theme(
    legend.position = "top",
    panel.border = element_blank(),
    panel.spacing = unit(0.1, "lines")
  ) +
  # guide = "none" replaces the deprecated guide = FALSE; seq_len(nlevels())
  # replaces `1:length(levels(as.factor(...)))`.
  scale_colour_manual(values = c(cols[seq_len(nlevels(lim_long$metric))]), guide = "none") +
  xlab("") +
  ylab("Value of Y") +
  facet_wrap(~metric, ncol = nlevels(lim_long$metric), strip.position = "bottom")
p
saveRDS(p, paste0(out_path_fig, "_limits_lolli.rds"))
# Spearman correlation between the metrics' mean scores, computed per dataset
# across the unbalance levels.
cor_met <- lapply(names(mean_tab), function(dataset_name){
  metric_means <- select(mean_tab[[dataset_name]], -unbalanced)
  cor(metric_means, use = "complete.obs", method = "spearman")
})
# cor() warns "the standard deviation is zero" for metrics whose mean does not
# change across levels.

# Element-wise average of the correlation matrices over all datasets.
cor_met_sum <- Reduce(`+`, cor_met)
cor_met_mean <- cor_met_sum / length(cor_met)

corrplot(
  cor_met_mean,
  type = "upper",
  order = "original",
  hclust.method = "complete",
  col = brewer.pal(n = 8, name = "PuOr"),
  addgrid.col = NA,
  addCoef.col = "black",
  diag = FALSE
)

# Save the averaged correlation matrix.
saveRDS(cor_met_mean, out_path_cor)