I want to apply a function (ratefunc()) to a grouped data frame; the function returns different column names depending on the result:
library(data.table)
library(dplyr)
# Example data: two groups ("A", "B") of three rows each.
# data.table() is exported, so it is reached with `::` (`:::` is only meant
# for a package's internal, unexported objects).
dt <- data.table::data.table(
  group = rep(c("A", "B"), each = 3),
  x = c(1:3, 6:8),
  y = c(4:6, 9:11)
)
# Return a one-column data frame with one row per row of `data`.
# The column name and its value are both the label "a" when the sum of
# data[[x_col]] + data[[y_col]] is below 25, and "b" otherwise.
ratefunc <- function(data, x_col="x", y_col="y") {
  total <- sum(data[[x_col]] + data[[y_col]])
  label <- if (total < 25) "a" else "b"
  out <- data.frame(rep(label, nrow(data)))
  # The column is named after the label, so callers see "a" or "b".
  names(out) <- label
  out
}
dplyr returns the desired result by
- assigning a new column for each column name returned and
- setting NA for the observations of the other groups:
# Apply ratefunc() to each group's rows (.x excludes the grouping column);
# group_modify() binds the per-group results, filling missing columns with NA.
group_modify(
  group_by(dt, group),
  function(.x, .y) ratefunc(data = .x)
)
Console output:
# A tibble: 6 × 3
# Groups: group [2]
group a b
<chr> <chr> <chr>
1 A a NA
2 A a NA
3 A a NA
4 B NA b
5 B NA b
6 B NA b
data.table, however, coerces the results into a single column, ignoring the differing column names:
# Call ratefunc() once per group on that group's .SD and stack the results.
dt[, ratefunc(data = .SD), by = "group"]
Console output:
group a
<char> <char>
1: A a
2: A a
3: A a
4: B b
5: B b
6: B b
How can I get the same result as with dplyr when using data.table?
And how should the approach of selecting the data frame columns by name in ratefunc() (i.e. x_col and y_col being string inputs) be evaluated in contrast to selecting them directly (i.e. x_col and y_col as vector/column inputs)?
Your ratefunc is not something I would write. I would do something like this instead; it could be wrapped in a function if necessary.
# Label every row with its group's result ("a" if the group's x + y total is
# below 25, else "b") and record the original row number as ind.
# Building this into a new table leaves `dt` itself free of helper columns.
lab <- dt[, .(ind = .I,
              tmp = if (sum(Reduce(`+`, .SD)) < 25) "a" else "b"),
          by = group, .SDcols = c("x", "y")]
# Widen: each label becomes its own column, NA where the row has the other one.
dcast(lab, group + ind ~ tmp, value.var = "tmp")
#Key: <group, ind>
# group ind a b
# <char> <int> <char> <char>
#1: A 1 a <NA>
#2: A 2 a <NA>
#3: A 3 a <NA>
#4: B 4 <NA> b
#5: B 5 <NA> b
#6: B 6 <NA> b
Edit addressing additional requirements from the comments:
# Map each result column name to the value it should hold.
lookup <- data.table(colnames = c("col1", "col2"), colvalues = c("a", "b"))
# Per group, pick the target column name (col1 if the x + y total is below 25,
# else col2) and keep the original row number; `dt` is left unmodified.
lab <- dt[, .(ind = .I,
              colnames = if (sum(Reduce(`+`, .SD)) < 25) "col1" else "col2"),
          by = group, .SDcols = c("x", "y")]
# Update-join: attach the value belonging to each row's column name.
lab[lookup, colvalues := i.colvalues, on = .(colnames)]
dcast(lab, group + ind ~ colnames, value.var = "colvalues")
# Key: <group, ind>
# group ind col1 col2
# <char> <int> <char> <char>
# 1: A 1 a <NA>
# 2: A 2 a <NA>
# 3: A 3 a <NA>
# 4: B 4 <NA> b
# 5: B 5 <NA> b
# 6: B 6 <NA> b
4
I don't think it's a good idea to create a function for this problem, but if you believe otherwise, you will need to modify it (see the end of this answer):
Find a way to solve your problem below:
In case you would like to add the new columns to your dataset, use
# Groups whose x + y total is below 25 get "a" in column a (NA otherwise); the
# rows still NA in a then get "b" in column b. An explicit NA_character_ else
# branch avoids returning NULL for a group inside a grouped `:=`.
dt[, a := if (sum(.SD) < 25L) "a" else NA_character_,
   by = group, .SDcols = c("x", "y")][is.na(a), b := "b"]
or the function below (more dynamic/flexible):
# Add columns nm1/nm2 to `dt` by reference and return it for printing.
# Per group: if the x_col + y_col total is below 25, column nm1 gets val1,
# otherwise an NA of val1's type. Rows left NA in nm1 then get val2 in column
# nm2. Column names reach the data.table expressions through `env`; I() keeps
# the values as literals instead of turning them into column symbols.
ratefunc <- function(dt, x_col = "x", y_col = "y", nm1, nm2, val1, val2) {
  dt[, (nm1) := if (sum(x_col, y_col) < 25L) val1 else val1[NA],
     by = group, env = list(x_col = x_col, y_col = y_col, val1 = I(val1))]
  dt[is.na(nm1), (nm2) := val2, env = list(nm1 = nm1, val2 = I(val2))][]
}
ratefunc(dt = dt, nm1 = "a", nm2 = "b", val1 = "a", val2 = "b")
#     group     x     y      a      b
#    <char> <int> <int> <char> <char>
# 1:      A     1     4      a   <NA>
# 2:      A     2     5      a   <NA>
# 3:      A     3     6      a   <NA>
# 4:      B     6     9   <NA>      b
# 5:      B     7    10   <NA>      b
# 6:      B     8    11   <NA>      b
Otherwise, use:
# One row per input row: groups whose x + y total is below 25 get "a" in
# column a and NA in b; all other groups get NA in a and "b" in b.
dt[, if (sum(.SD) < 25)
       .(a = rep("a", .N), b = NA_character_)
     else
       .(a = NA_character_, b = rep("b", .N)),
   by = group,
   .SDcols = c("x", "y")]
#     group      a      b
#    <char> <char> <char>
# 1:      A      a   <NA>
# 2:      A      a   <NA>
# 3:      A      a   <NA>
# 4:      B   <NA>      b
# 5:      B   <NA>      b
# 6:      B   <NA>      b
Find the modified version of your function below in case you prefer to use it instead.
# Return a list with elements a and b for use as a data.table `j` result.
# If data[[x_col]] + data[[y_col]] sums to 25 or more, b holds "b" for every
# row and a is a single NA. Otherwise a holds "a" for every row and b is a
# single NA. The length-1 NA is recycled by data.table when the list is
# returned per group.
ratefunc <- function(data, x_col="x", y_col="y") {
  n <- nrow(data)
  if (sum(data[[x_col]] + data[[y_col]]) >= 25) {
    return(list(a = NA_character_, b = rep("b", n)))
  }
  list(a = rep("a", n), b = NA_character_)
}
# Per group, splice the list returned by ratefunc() into the result columns.
dt[, ratefunc(data = .SD), by = "group"]
7
You can try a better ratefunc.
# Return a data frame with character columns a and b and one row per row of
# `data`. If data[[x_col]] + data[[y_col]] sums to less than 25, column a is
# filled with "a"; otherwise column b is filled with "b". The other column
# stays NA.
ratefunc <- function(data, x_col = "x", y_col = "y") {
  res <- sum(data[[x_col]] + data[[y_col]])
  # Start from a character NA matrix so both columns are character even when
  # `data` has zero rows.
  o <- matrix(NA_character_, nrow = nrow(data), ncol = 2L,
              dimnames = list(NULL, c("a", "b")))
  if (res < 25) {
    o[, "a"] <- "a"
  } else {
    o[, "b"] <- "b"
  }
  as.data.frame(o)
}
# Apply ratefunc() to each group's .SD and stack the per-group data frames.
dt[, ratefunc(data = .SD), by = group]
#     group      a      b
#    <char> <char> <char>
# 1:      A      a   <NA>
# 2:      A      a   <NA>
# 3:      A      a   <NA>
# 4:      B   <NA>      b
# 5:      B   <NA>      b
# 6:      B   <NA>      b
3
Here is another approach:
library(data.table)
# data.table() is exported, so `::` (not the internal-access `:::`) is used.
dt <- data.table::data.table(
  group = rep(c("A", "B"), each = 3),
  x = c(1:3, 6:8),
  y = c(4:6, 9:11)
)
# Return a copy of `data` with two new columns and x_col/y_col dropped.
# Per group (grouping column named by group_col, default "group"), column a
# holds "a" when the group's x_col + y_col total is below 25 (NA otherwise),
# and column b holds "b" for the remaining groups (NA otherwise).
ratefunc <- function(data, x_col = "x", y_col = "y", group_col = "group") {
  DT <- copy(data)  # avoid modifying the caller's table by reference
  DT[, c("a", "b") := {
    below <- sum(.SD) < 25L  # evaluate the group condition once
    list(if (below) "a" else NA_character_,
         if (below) NA_character_ else "b")
  }, .SDcols = c(x_col, y_col), by = group_col]
  DT[, c(x_col, y_col) := NULL]
  DT[]
}
resultDT <- ratefunc(dt, x_col = "x", y_col = "y")
resultDT[]
Alternative: apply it directly to dt:
x_col <- "x"
y_col <- "y"
# In place on dt: per group, a = "a" / b = NA when the x_col + y_col total is
# below 25, otherwise a = NA / b = "b"; afterwards the source columns are
# removed from dt.
dt[, c("a", "b") := {
  below <- sum(.SD) < 25L
  list(fifelse(below, yes = "a", no = NA_character_),
       fifelse(below, yes = NA_character_, no = "b"))
}, .SDcols = c(x_col, y_col), by = group]
dt[, c(x_col, y_col) := NULL]
A benchmark from a single run (timings are machine-dependent; you might need to check how it scales with your data):
Unit: milliseconds
expr min lq mean median uq max neval
B.ChristianKamgang 1.0797 1.0797 1.0797 1.0797 1.0797 1.0797 1
Roland 4.5069 4.5069 4.5069 4.5069 4.5069 4.5069 1
jay.sf 15.7749 15.7749 15.7749 15.7749 15.7749 15.7749 1
ismirsehregal 1.5379 1.5379 1.5379 1.5379 1.5379 1.5379 1
library(data.table)
library(microbenchmark)

# data.table() is exported, so `::` (not the internal-access `:::`) is used.
dt <- data.table::data.table(
  group = rep(c("A", "B"), each = 3),
  x = c(1:3, 6:8),
  y = c(4:6, 9:11)
)

# jay.sf's function, benchmarked as posted.
ratefunc_jay.sf <- function(data, x_col = "x", y_col = "y") {
  res <- sum(data[[x_col]] + data[[y_col]])
  o <- t(array(, c(2L, nrow(data)), list(c('a', 'b'), NULL)))
  if (res < 25) {
    o[, 'a'] <- 'a'
  } else {
    o[, 'b'] <- 'b'
  }
  o |> as.data.frame()
}

# Every expression starts from its own fresh copy of dt. Two of the approaches
# modify their input by reference (ismirsehregal's even drops x and y), so a
# table shared across iterations would make repeated runs fail or differ.
# Copying in all four expressions keeps the comparison even.
microbenchmark(
  B.ChristianKamgang = {
    d <- copy(dt)
    d[, if (sum(.SD) < 25)
          .(a = rep("a", .N), b = NA_character_)
        else
          .(a = NA_character_, b = rep("b", .N)),
      by = group,
      .SDcols = c("x", "y")]
  },
  Roland = {
    d <- copy(dt)
    d[, tmp := fifelse(sum(Reduce(`+`, .SD)) < 25, "a", "b"),
      by = group, .SDcols = c("x", "y")]
    d[, ind := .I]
    dcast(d, group + ind ~ tmp, value.var = "tmp")
  },
  jay.sf = {
    d <- copy(dt)
    d[, ratefunc_jay.sf(data = .SD), by = group]
  },
  ismirsehregal = {
    d <- copy(dt)
    d[, c("a", "b") := .(
      fifelse(sum(.SD) < 25L, yes = "a", no = NA_character_),
      fifelse(sum(.SD) < 25L, yes = NA_character_, no = "b")
    ), .SDcols = c("x", "y"), by = group][, c("x", "y") := NULL]
  },
  times = 100L  # repeated runs give stable estimates now that inputs are fresh
)
2