suppressPackageStartupMessages({
library(scater)
library(CellMixS)
library(purrr)
library(tidyr)
library(dplyr)
library(gridExtra)
library(scran)
library(cowplot)
library(jcolors)
library(ggpubr)
library(stringr)
library(viridis)
library(scales)
library(hrbrthemes)
library(magrittr)
library(forcats)
library(corrplot)
library(RColorBrewer)
library(ggforce)
})
# Use the cairo device for bitmap graphics output.
options(bitmapType = "cairo")

# Suffix inserted before "_sce.rds" when the SCE file paths are built below.
last <- params$last

# Simulation names are passed as one comma-separated string.
sce_whole_name <- unlist(strsplit(params$sce_name, ","))

# Dataset key of every simulation: the simulation name without a leading
# "sim_" and without everything from the first "_<digit>" onwards.
sim_dataset <- gsub("_[0-9].*", "", gsub("^sim_", "", sce_whole_name))
sce_name <- unique(sim_dataset)

# Metric identifiers are passed as one comma-separated string as well.
metrics <- unlist(strsplit(params$metrics, ","))

# Simulations belonging to each dataset. The dataset key is compared exactly;
# a regex/substring match would also pick up simulations of any other dataset
# whose name merely contains this one (e.g. "pbmc" vs. "pbmc2").
sce_sim_list <- lapply(sce_name, function(dataset) {
  sce_whole_name[sim_dataset == dataset]
}) %>% set_names(sce_name)

# Per dataset: named vector of SCE file paths, one per simulation.
sce_list <- lapply(sce_name, function(dataset) {
  sims <- sce_sim_list[[dataset]]
  paste0(params$sce, sims, "_", last, "_sce.rds") %>% set_names(sims)
}) %>% set_names(sce_name)

# Per dataset: named vector of meta file paths, one per simulation.
meta_list <- lapply(sce_name, function(dataset) {
  sims <- sce_sim_list[[dataset]]
  paste0(params$meta, sims, "_meta.rds") %>% set_names(sims)
}) %>% set_names(sce_name)
# Output locations taken from the report parameters.
out_path_cor <- params$out_cor
out_path_res <- params$out_res
out_path_fig <- params$fig_res

# Unnamed colour vector for datasets: a reordered selection from the
# concatenated 'pal6' and 'pal8' palettes, followed by all of 'pal4'.
pal_6_8 <- c(jcolors("pal6"), jcolors("pal8"))
data_col_idx <- c(1, 8, 14, 5, 2:4, 6, 7, 9:13, 15:20)
cols_data <- unname(c(pal_6_8[data_col_idx], jcolors("pal4")))

# Unnamed colour vector for metrics: a selection from the concatenated
# 'pal6', 'pal8' and 'pal7' palettes, followed by all of 'pal4'.
pal_6_8_7 <- c(jcolors("pal6"), jcolors("pal8"), jcolors("pal7"))
metric_col_idx <- c(1, 12, 18, 25, 27, 2, 4, 7, 3, 6, 8, 14, 9, 20)
cols <- unname(c(pal_6_8_7[metric_col_idx], jcolors("pal4")))
How do the metric scores scale with simulated batch effects that are gradually increased?
# Metric identifiers as used in the column names: underscores become dots and
# the literal "default" is removed (so e.g. "cms_default" becomes "cms.").
metric_vec <- gsub("default", "", gsub("_", ".", metrics))
# Collect the per-cell metric scores of all simulations of one dataset.
#
# dataset: name of the dataset; must be a key of `sce_sim_list` and `sce_list`.
#
# Returns a list named by metric. Each element is a wide table with one row
# per cell (column `cell`) and the metric's score columns spread over the
# simulations of the dataset.
dataset_sim_tab <- function(dataset) {
  sim_names <- sce_sim_list[[dataset]]
  metric_pattern <- paste(metric_vec, collapse = "|")

  sim_list <- lapply(sim_names, function(sim_name) {
    sce <- readRDS(sce_list[[dataset]][[sim_name]])
    # The meta file was previously read here as well but never used, so the
    # read has been dropped.
    metric_nam <- grep(metric_pattern, names(colData(sce)), value = TRUE)
    # Logical filtering: `x[-grep(...)]` would drop *every* metric whenever no
    # column name contains "smooth", because `x[-integer(0)]` is empty.
    metric_nam <- metric_nam[!grepl("smooth", metric_nam)]
    # drop = FALSE keeps the column name when only one metric is selected.
    res <- as.data.frame(colData(sce)[, metric_nam, drop = FALSE])
    res$cell <- colnames(sce)
    res
  }) %>% set_names(sim_names)

  sim_tab <- bind_rows(sim_list, .id = "dataset")

  # Local copy only: "cms." (from "cms_default") is looked up as
  # "cms.batch_id" in the column names.
  metric_vec[metric_vec %in% "cms."] <- "cms.batch_id"

  lapply(metric_vec, function(met) {
    met_col <- colnames(sim_tab)[grepl(paste0("^", met), colnames(sim_tab))]
    metr_long <- sim_tab[, c("dataset", "cell", met_col)]
    metr_long$dataset <- as.factor(metr_long$dataset)
    # all_of() marks `met_col` as an external character vector of column names.
    pivot_wider(metr_long, names_from = dataset, values_from = all_of(met_col))
  }) %>% set_names(metric_vec)
}
# Per dataset: list of wide per-cell score tables, one per metric.
res_by_met <- lapply(sce_name, dataset_sim_tab) %>% set_names(sce_name)

#### ---------- Order by metric type ----------------------------------#######
# (manual: needs to be adjusted if new metrics are added)
# Metric names are taken from the first dataset; each index vector holds the
# positions of the names matching one metric family.
metric_names <- names(res_by_met[[1]])
find_metric <- function(pattern) grep(pattern, metric_names)

cms_ind <- find_metric("cms")
lisi_ind <- find_metric("isi")
ent_ind <- find_metric("entropy")
mm_ind <- find_metric("mm")
asw_ind <- find_metric("sw")
kbet_ind <- find_metric("kbet")
graph_ind <- find_metric("graph")
pcr_ind <- find_metric("pcr")

# Display order of the metrics, grouped by family.
family_order <- c(
  cms_ind, lisi_ind, ent_ind, mm_ind,
  kbet_ind, asw_ind, graph_ind, pcr_ind
)
metric_order <- metric_names[family_order]

# Name the metric colours after the metrics (in their original order).
names(cols) <- metric_names
#cols <- cols[names(res_by_met[[1]])]
####--------------------------------------------------------------------########
# Per dataset: table with one row per simulation holding the mean score of
# every metric (NAs ignored) plus the numeric `batch` value parsed from the
# simulation names.
mean_tab <- lapply(sce_name, function(dataset) {
  per_metric <- res_by_met[[dataset]]

  # Local copy: "cms." is stored under "cms.batch_id".
  metric_ids <- metric_vec
  metric_ids[metric_ids %in% "cms."] <- "cms.batch_id"

  # Column means over cells for every metric; the `cell` column is excluded.
  col_means <- lapply(metric_ids, function(met) {
    score_mat <- as.matrix(select(per_metric[[met]], -cell))
    colMeans(score_mat, na.rm = TRUE)
  })
  met_mean <- set_colnames(bind_cols(col_means), metric_ids)

  # The batch value is the part of the column name between two underscores
  # that starts with a digit; columns without such a part (e.g. `cell`) give
  # NA and are skipped.
  matched <- str_match(colnames(per_metric[[1]]), "_([0-9].*?)_")
  batch_chr <- matched[, 2]
  met_mean$batch <- as.numeric(batch_chr[!is.na(batch_chr)])
  met_mean
}) %>% set_names(sce_name)
# Scale means: per dataset, flip the sign of every metric except "mm", "kbet"
# and "pcr", then centre and scale each metric column with scale(). The
# `batch` column is carried over unscaled.
scale_means <- lapply(names(mean_tab), function(dataset) {
  dat <- mean_tab[[dataset]]
  keep_sign <- c("batch", "mm", "kbet", "pcr")
  flip_idx <- which(!(colnames(dat) %in% keep_sign))
  dat[, flip_idx] <- dat[, flip_idx] * -1

  scaled <- as.data.frame(scale(select(dat, -batch)))
  scaled$batch <- dat$batch
  scaled
}) %>% set_names(names(mean_tab))
# Line/point plot of the scaled mean score of every metric against the batch
# value for one dataset.
#
# dataset: key of `scale_means`.
# Returns a ggplot object.
plot_trends <- function(dataset) {
  long_scores <- pivot_longer(
    scale_means[[dataset]],
    -batch,
    names_to = "metric",
    values_to = "scaled_score"
  )
  long_scores$metric <- factor(long_scores$metric, levels = metric_order)

  ggplot(long_scores, aes(x = batch, y = scaled_score, color = metric)) +
    geom_line() +
    geom_point() +
    theme_bw() +
    scale_color_manual(values = cols) +
    ggtitle(dataset)
}
# Chunk template: a heading plus a chunk calling plot_trends() for dataset
# {{nm}}.
template <- c(
  "#### {{nm}}\n",
  "```{r scaling{{nm}}, echo = FALSE}\n",
  "plot_trends('{{nm}}')\n",
  "```\n",
  "\n"
)

# One expanded template per dataset; `nm` is handed to knit_expand() directly.
plots <- lapply(
  names(scale_means),
  function(nm) knitr::knit_expand(text = template, nm = nm)
)
To compare the metrics' limits and their relative correlation with batch strength/size, we define an upper limit at 95% of the maximal metric score and a minimal limit at the minimal metric score, and use the corresponding batch size factors as batch limits. We check each metric's correlation with the batch size factor between its limits. Be aware that these limits are dataset-specific and do not represent fixed limits. They can only be used to compare metrics, but are not informative about a metric's absolute sensitivity.
# Per dataset: 4 x n matrix of batch limits with rows "lower", "upper",
# "limit", "limit2" and one column per column of the mean table (including
# `batch` itself).
limit_list <- lapply(names(mean_tab), function(dataset) {
  mean_res <- mean_tab[[dataset]]

  # Batch limits for one column `x` of mean scores. Returns, in this order:
  #   1. smallest batch value with a score more than 10% of the score range
  #      below the maximum, among batch values larger than the one at which
  #      the maximum is reached,
  #   2. smallest batch value with a score more than 10% of the score range
  #      above the minimum, among batch values larger than the one at which
  #      the minimum is reached,
  #   3. smallest batch value with a score within 1% of the range of the
  #      minimum,
  #   4. smallest batch value with a score within 1% of the range of the
  #      maximum.
  # An empty selection makes min() return Inf (with a warning).
  # NOTE(review): the accompanying text speaks of 95% of the maximal score,
  # whereas a 10% margin is used here -- confirm which one is intended.
  lim_rev <- function(x) {
    x_max <- max(x)
    x_min <- min(x)
    percent <- (x_max - x_min) / 100

    lower <- which(x < (x_max - 10 * percent))
    upper <- which(x > (x_min + 10 * percent))
    upper_lim <- which(x < (x_min + 1 * percent))
    lower_lim <- which(x > (x_max - 1 * percent))

    min_bat <- min(mean_res$batch[which(x == x_max)])
    max_bat <- min(mean_res$batch[which(x == x_min)])

    res_low <- min(mean_res$batch[lower][mean_res$batch[lower] > min_bat])
    res_up <- min(mean_res$batch[upper][mean_res$batch[upper] > max_bat])
    res_up_limit <- min(mean_res$batch[upper_lim])
    res_low_limit <- min(mean_res$batch[lower_lim])

    c(res_low, res_up, res_up_limit, res_low_limit)
  }

  batch_lim <- apply(mean_res, 2, lim_rev) %>%
    set_rownames(c("lower", "upper", "limit", "limit2"))

  # For "kbet", "mm" and "pcr" the rows are swapped pairwise
  # (lower <-> upper, limit <-> limit2). Only metrics that are actually
  # present are touched, so a run without one of them no longer fails with
  # "subscript out of bounds".
  rev_metrics <- intersect(c("kbet", "mm", "pcr"), colnames(batch_lim))
  if (length(rev_metrics) > 0) {
    batch_lim[, rev_metrics] <- batch_lim[c(2, 1, 4, 3), rev_metrics]
  }
  batch_lim
}) %>% set_names(names(mean_tab))
# Extract the "lower" row of a dataset's limit matrix as a vector named by
# the matrix columns.
#
# dataset: key of `limit_list`.
lim_fun <- function(dataset) {
  limits <- limit_list[[dataset]]
  lim_low <- setNames(limits["lower", ], colnames(limits))
}
# Wide table of "lower" limits: one column per dataset, one row per column of
# the limit matrices, labelled in `metric`.
lower_limits <- lapply(names(limit_list), lim_fun)
names(lower_limits) <- names(limit_list)
limit_tab <- mutate(
  bind_cols(lower_limits),
  metric = colnames(limit_list[[1]])
)

# Long format without the `batch` row: one row per metric and dataset.
lim_long <- pivot_longer(
  filter(limit_tab, !metric %in% "batch"),
  -metric,
  names_to = "dataset",
  values_to = "batch_limit"
)
# Per dataset (one row each): absolute Pearson correlation between every
# column of the mean table and `batch`, using only the rows whose batch value
# does not exceed that column's "limit" entry in `limit_list`.
cor_list <- lapply(names(mean_tab), function(dataset) {
  mean_res <- mean_tab[[dataset]]
  limit_res <- limit_list[[dataset]]

  cor_fun <- function(metric) {
    # Looked up outside filter() so that `metric` cannot be mistaken for a
    # column of the data.
    max_batch <- limit_res["limit", metric]
    mean_res %>%
      filter(batch <= max_batch) %>%
      # all_of() marks `metric` as an external column name instead of relying
      # on the ambiguous (and deprecated) bare-variable selection.
      select(all_of(metric), "batch") %>%
      cor(method = "pearson") %>%
      min() %>%   # the off-diagonal entry of the correlation matrix
      abs()
    #%>% filter(batch >= limit_res["lower", metric])
  }

  map(colnames(mean_res), cor_fun) %>%
    bind_cols() %>%
    set_colnames(colnames(mean_res))
}) %>%
  bind_rows() %>%
  set_rownames(names(mean_tab))

# Long format: one row per dataset and metric. The rows of `cor_list` follow
# names(mean_tab), so the labels are taken from there rather than from row
# names, which are deprecated on tibbles.
cor_long <- cor_list %>%
  mutate(dataset = names(mean_tab)) %>%
  select(-batch) %>%
  pivot_longer(-dataset, names_to = "metric", values_to = "correlation")
# Per dataset (one row each): absolute Spearman correlation between every
# column of the mean table and `batch`, over all batch values.
cor_list_all <- lapply(names(mean_tab), function(dataset) {
  mean_res <- mean_tab[[dataset]]

  cor_fun <- function(metric) {
    mean_res %>%
      # all_of() marks `metric` as an external column name instead of relying
      # on the ambiguous (and deprecated) bare-variable selection.
      select(all_of(metric), "batch") %>%
      cor(method = "spearman") %>%
      min() %>%   # the off-diagonal entry of the correlation matrix
      abs()
  }

  map(colnames(mean_res), cor_fun) %>%
    bind_cols() %>%
    set_colnames(colnames(mean_res))
}) %>%
  bind_rows() %>%
  set_rownames(names(mean_tab))

# Long format: one row per dataset and metric. The rows of `cor_list_all`
# follow names(mean_tab), so the labels are taken from there rather than from
# row names, which are deprecated on tibbles.
cor_long_all <- cor_list_all %>%
  mutate(dataset = names(mean_tab)) %>%
  select(-batch) %>%
  pivot_longer(-dataset, names_to = "metric", values_to = "correlation")
### Mutual information (disabled)
# mut_list <- lapply(names(mean_tab), function(dataset){
# mean_res <- mean_tab[[dataset]]
# mut <- cmi(mean_res)
# norm_mi <- mut$mi[ncol(mut$mi), nrow(mut$mi)]
# norm_vec <- rep(sqrt(norm_mi * norm_mi), ncol(mean_res))
# norm_mut <- mut$mi[,ncol(mut$mi)]/norm_vec %>% set_names(colnames(mean_res))
# }) %>% bind_rows() %>% set_rownames(names(mean_tab))
#
# mut_long <- mut_list %>% mutate(dataset = rownames(.)) %>% select(-batch) %>%
# pivot_longer(-dataset, names_to = "metric", values_to = "MI")

# Order the metrics by metric type in all long tables.
lim_long <- mutate(lim_long, metric = factor(metric, levels = metric_order))
cor_long <- mutate(cor_long, metric = factor(metric, levels = metric_order))
cor_long_all <- mutate(
  cor_long_all,
  metric = factor(metric, levels = metric_order)
)
#mut_long$metric <- factor(mut_long$metric,levels = metric_order)
# Box plot with overlaid dot plot of the batch limits per metric (one dot per
# dataset). The box fill is set to the first colours of `cols`, one per
# metric level.
p <- ggplot(lim_long, aes(x = metric, y = batch_limit, fill = metric)) +
  geom_boxplot(
    fill = cols[seq_len(nlevels(as.factor(lim_long$metric)))],
    alpha = 0.5
  ) +
  geom_dotplot(binaxis = "y", stackdir = "center", dotsize = 0.5) +
  scale_fill_manual(values = cols) +
  theme_ipsum(base_family = "Helvetica") +
  theme(axis.text.x = element_text(size = 10, angle = 45, hjust = 1))
p
## `stat_bindot()` using `bins = 30`. Pick better value with `binwidth`.
saveRDS(p, paste0(out_path_fig, "_limits.rds"))

# Same plot for the correlations computed up to each metric's limit.
p <- ggplot(cor_long, aes(x = metric, y = correlation, fill = metric)) +
  geom_boxplot(
    fill = cols[seq_len(nlevels(as.factor(cor_long$metric)))],
    alpha = 0.5
  ) +
  geom_dotplot(binaxis = "y", stackdir = "center", dotsize = 0.5) +
  scale_fill_manual(values = cols) +
  theme_ipsum(base_family = "Helvetica") +
  theme(axis.text.x = element_text(size = 10, angle = 45, hjust = 1))
p
## `stat_bindot()` using `bins = 30`. Pick better value with `binwidth`.
saveRDS(p, paste0(out_path_fig, "_rel_cor.rds"))

# Same plot for the correlations over all batch values. The number of fill
# colours is now derived from `cor_long_all`, the table that is plotted
# (previously it was taken from `cor_long`).
p <- ggplot(cor_long_all, aes(x = metric, y = correlation, fill = metric)) +
  geom_boxplot(
    fill = cols[seq_len(nlevels(as.factor(cor_long_all$metric)))],
    alpha = 0.5
  ) +
  geom_dotplot(binaxis = "y", stackdir = "center", dotsize = 0.5) +
  scale_fill_manual(values = cols) +
  theme_ipsum(base_family = "Helvetica") +
  theme(axis.text.x = element_text(size = 10, angle = 45, hjust = 1))
p
## `stat_bindot()` using `bins = 30`. Pick better value with `binwidth`.
saveRDS(p, paste0(out_path_fig, "_cor.rds"))
# Join the correlations over all batch values with the batch limits
# (one row per dataset and metric).
trade_long <- full_join(cor_long_all, lim_long, by = c("dataset", "metric"))

# Per metric: mean and standard deviation over datasets of the correlation
# and of the batch limit.
trade_sorted <- trade_long %>%
  group_by(metric) %>%
  summarise(
    mean_correlation = mean(correlation),
    mean_limit = mean(batch_limit),
    correlation_sd = sd(correlation),
    batch_limit_sd = sd(batch_limit)
  )

# Kept unchanged; not used in the code shown here.
lev_metric_level <- c(cell = c("cms.bmin", "cms.batch_id", "cms.kmin",
                               "isi", "wisi", "lisi", "entropy", "mm"),
                      celltype = c("kbet", "asw", "graph.connectivity"),
                      global = "pcr")

# Level at which each metric is evaluated.
metric_level_groups <- list(
  cell = c("cms.bmin", "cms.batch_id", "cms.kmin", "isi", "wisi", "lisi",
           "entropy", "mm"),
  celltype = c("kbet", "asw", "sw", "casw", "graph.connectivity"),
  global = "pcr"
)

# Collapse only levels that are actually present: fct_collapse() warns about
# unknown levels (previously "Unknown levels in `f`: sw, casw"). Metrics not
# listed in any group keep their own level, as before.
metric_fct <- factor(trade_sorted$metric)
present_groups <- lapply(metric_level_groups, intersect, levels(metric_fct))
present_groups <- present_groups[lengths(present_groups) > 0]
trade_sorted$metric_level <- do.call(
  fct_collapse,
  c(list(metric_fct), present_groups)
)
# Trade-off plot: mean correlation against mean batch limit per metric, with
# +/- one standard deviation drawn in both directions.
#
# trade_dat: table with the columns of `trade_sorted`.
# Returns a ggplot object.
plot_trade_off_ranges <- function(trade_dat) {
  ggplot(
    trade_dat,
    aes(x = mean_limit, y = mean_correlation, color = metric)
  ) +
    geom_pointrange(aes(
      ymin = mean_correlation - correlation_sd,
      ymax = mean_correlation + correlation_sd
    )) +
    geom_pointrange(aes(
      xmin = mean_limit - batch_limit_sd,
      xmax = mean_limit + batch_limit_sd
    )) +
    scale_colour_manual(values = cols) +
    ggtitle("Trade off: batch detection limit - correlation") +
    theme_ipsum(base_family = "Helvetica")
}

# All metrics.
plot_trade_off_ranges(trade_sorted)

# Only metrics with a mean limit below 0.6 and a mean correlation above 0.7.
trade_top <- trade_sorted %>% filter(mean_limit < 0.6, mean_correlation > 0.7)
plot_trade_off_ranges(trade_top)
# Trade-off scatter plot: mean correlation against mean batch limit, coloured
# by metric and shaped by metric level. A slightly larger black point is drawn
# underneath each coloured point.
p <- ggplot(
  trade_sorted,
  aes(
    x = mean_limit,
    y = mean_correlation,
    color = metric,
    shape = metric_level
  )
) +
  geom_point(size = 8, color = "black") +
  geom_point(size = 7) +
  ggtitle("Trade off: Sensitivity - scaling") +
  scale_colour_manual(values = cols) +
  theme_ipsum(base_family = "Helvetica") +
  theme(legend.box = "vertical")
p
saveRDS(p, paste0(out_path_fig, "_trade_off.rds"))

# NOTE(review): the "_trade_off_relative" file receives exactly the same plot
# as "_trade_off" -- both are built from `trade_sorted`. Presumably the
# relative version was meant to use different data; confirm.
p
saveRDS(p, paste0(out_path_fig, "_trade_off_relative.rds"))
Check mean correlation
# Scale means: per dataset, centre and scale the mean score of every metric
# (no sign flip here); the `batch` column is dropped.
scale_m <- lapply(mean_tab, function(mean_res) {
  as.data.frame(scale(select(mean_res, -batch)))
})

# Spearman correlation between the metrics, one matrix per dataset.
cor_met <- lapply(seq_along(scale_m), function(i) {
  cor(scale_m[[i]], use = "complete.obs", method = "spearman")
})

# Element-wise mean of the correlation matrices over all datasets.
cor_met_mean <- Reduce(`+`, cor_met) / length(cor_met)

corrplot(
  cor_met_mean,
  type = "upper",
  order = "original",
  hclust.method = "complete",
  col = brewer.pal(n = 8, name = "PuOr"),
  addgrid.col = NA,
  addCoef.col = "black",
  diag = FALSE
)

# Save the mean correlation matrix and the trade-off summary.
saveRDS(cor_met_mean, out_path_cor)
saveRDS(trade_sorted, out_path_res)
# Faceted trend plot for one dataset: one panel per metric showing its scaled
# mean score against the batch value (coloured), all metrics as grey
# background lines, a label with the correlation from `cor_long`, and a
# marked point at the metric's batch limit from `limit_tab`.
#
# dataset: key of `scale_means`; also a column of `limit_tab` and a value of
#          `cor_long$dataset`.
# Returns a ggplot object.
sep_trends <- function(dataset) {
  mean_res <- scale_means[[dataset]]
  n_metrics <- length(metric_order)

  # Inside filter() `dataset` refers to the column, so the argument is copied
  # to a differently named variable first.
  dat_nam <- dataset
  cor_dat <- cor_long %>%
    dplyr::filter(dataset %in% dat_nam) %>%
    arrange(match(metric, metric_order)) %>%
    select(correlation) %>%
    abs() %>%
    signif(digits = 2)

  # One correlation label per metric at a fixed position in each panel. The
  # number of rows follows `metric_order` instead of a hard-coded 12.
  label <- data.frame(
    scaled_score = rep(-1, n_metrics),
    batch = rep(2.5, n_metrics),
    metric = metric_order,
    label = paste0("R = ", cor_dat$correlation)
  )

  lim_ordered <- limit_tab %>%
    dplyr::filter(metric %in% metric_order) %>%
    arrange(match(metric, metric_order))

  # Scaled score of metric `met` at its batch limit in this dataset.
  find_scal <- function(met) {
    lim <- lim_ordered[lim_ordered$metric %in% met, dataset]
    mean_res[mean_res$batch == lim[[1]], met]
  }
  mean_lim <- metric_order %>% map(find_scal) %>% unlist()

  sens_limit <- data.frame(
    "batch" = lim_ordered[, dataset][[1]],
    "metric" = metric_order,
    "scaled_score" = mean_lim,
    "label" = paste0("limit = ", lim_ordered[, dataset][[1]])
  )

  mean_long <- mean_res %>%
    pivot_longer(-batch, names_to = "metric", values_to = "scaled_score")
  # Copy of the metric name that is not faceted on; used for the grey lines.
  mean_long <- mean_long %>% mutate(metric2 = metric)

  # One colour per metric column (all columns except `batch`), repeated for
  # every row of `mean_res`. `1:ncol(x) - 1` parses as `(1:ncol(x)) - 1` and
  # only gave this result because index 0 is dropped; seq_len() states the
  # intent directly.
  # NOTE(review): the colours are passed as a fixed per-row vector, which
  # relies on the row order ggplot uses for this layer -- confirm.
  cols_rep <- rep(cols[seq_len(ncol(mean_res) - 1)], each = nrow(mean_res))
  mean_long$metric <- factor(mean_long$metric, levels = metric_order)

  p <- ggplot(mean_long, aes(x = batch, y = scaled_score)) +
    geom_line(data = mean_long %>% dplyr::select(-metric),
              aes(group = metric2),
              color = "grey", size = 0.5, alpha = 0.5) +
    geom_line(aes(color = metric), color = cols_rep, size = 1.2) +
    theme_ipsum(base_family = 'Helvetica') +
    theme(
      legend.position = "none",
      plot.title = element_text(size = 14),
      panel.grid = element_blank()
    ) +
    geom_label(data = label, aes(label = label), size = 6) +
    geom_mark_circle(data = sens_limit,
                     aes(label = label, x = batch, y = scaled_score),
                     label.colour = "black", inherit.aes = FALSE,
                     label.fontsize = 16,
                     size = 1,
                     expand = unit(1, "mm"),
                     label.minwidth = unit(50, "mm")) +
    ggtitle(dataset) +
    facet_wrap(~metric)
  p
}
# Chunk template: a heading plus a chunk that draws the faceted trend plot of
# dataset {{nm}} and saves it under the figure output prefix.
template_sep <- c(
  "#### {{nm}}\n",
  "```{r scaling sep {{nm}}, echo = FALSE, fig.width = 8, fig.height = 7}\n",
  "p <- sep_trends('{{nm}}')\n",
  "p",
  "saveRDS(p, paste0(out_path_fig, '_{{nm}}.rds'))",
  "```\n",
  "\n"
)

# One expanded template per dataset; `nm` is handed to knit_expand() directly.
plots_sep <- lapply(
  names(scale_means),
  function(nm) knitr::knit_expand(text = template_sep, nm = nm)
)