# Snakefile
USE_SLURM = config.get("USE_SLURM", False)
N_ITER = config.get("N_ITER", 100)
MODELS = ["dummy", "linear", "mlp", "knn", "gbdt"]
rule all:
    input:
        "results/cas_dataset.csv",
        expand("results/{model_name}_model/predictions.csv", model_name=MODELS),
        "results/summary",
rule download_data:
    output:
        "data/cas_dataset.csv"
    shell:
        "cas_data download {output}"

rule prepare_data:
    input:
        "data/cas_dataset.csv"
    output:
        "results/cas_dataset.csv"
    shell:
        "cas_data prepare {input} -o {output}"

rule fit:
    input:
        "results/cas_dataset.csv"
    output:
        "results/{model_type}_model/model.pickle"
    threads: 1 if USE_SLURM else 8
    params:
        cores_per_worker=8 if USE_SLURM else 4,
        n_workers=lambda wildcards, threads: 50 if USE_SLURM else max(1, threads // 4),
        use_slurm="--use-slurm" if USE_SLURM else ""
    shell:
        """
        models fit {input} {output} --model-type {wildcards.model_type} \
            --n-iter {N_ITER} \
            --n-workers {params.n_workers} \
            --cores-per-worker {params.cores_per_worker} \
            --mem-per-worker "6GB" \
            {params.use_slurm}
        """

rule fit_knn:
    input:
        "results/cas_dataset.csv"
    output:
        "results/knn_model/model.pickle"
    threads: 1 if USE_SLURM else 8
    params:
        n_workers=lambda wildcards, threads: 10 if USE_SLURM else max(1, threads // 4),
        use_slurm="--use-slurm" if USE_SLURM else ""
    shell:
        """
        models fit {input} {output} --model-type knn \
            --n-iter {N_ITER} \
            --n-workers {params.n_workers} \
            --mem-per-worker "10GB" \
            {params.use_slurm}
        """

rule fit_mlp:
    input:
        "results/cas_dataset.csv"
    output:
        "results/mlp_model/model.pickle"
    threads: 1 if USE_SLURM else 8
    params:
        n_workers=lambda wildcards, threads: 50 if USE_SLURM else max(1, threads // 4),
        use_slurm="--use-slurm" if USE_SLURM else ""
    shell:
        """
        models fit {input} {output} --model-type mlp \
            --n-iter {N_ITER} \
            --n-workers {params.n_workers} \
            --cores-per-worker 4 \
            --mem-per-worker "4GB" \
            --walltime 0-03:00 \
            {params.use_slurm}
        """

rule fit_gbdt:
    input:
        "results/cas_dataset.csv"
    output:
        "results/gbdt_model/model.pickle"
    threads: 1 if USE_SLURM else 8
    params:
        cores_per_worker=8 if USE_SLURM else 4,
        n_workers=lambda wildcards, threads: 25 if USE_SLURM else max(1, threads // 4),
        use_slurm="--use-slurm" if USE_SLURM else ""
    shell:
        """
        models fit {input} {output} --model-type gbdt \
            --n-iter {N_ITER} \
            --n-workers {params.n_workers} \
            --cores-per-worker {params.cores_per_worker} \
            --mem-per-worker "8GB" \
            {params.use_slurm}
        """

rule predict:
    input:
        "results/cas_dataset.csv",
        "results/{model_name}_model/model.pickle"
    output:
        "results/{model_name}_model/predictions.csv"
    shell:
        "models predict {input} -o {output}"

rule evaluate:
    input:
        "results/cas_dataset.csv",
        expand("results/{model_name}_model/predictions.csv", model_name=MODELS)
    output:
        directory("results/summary")
    params:
        labels=" ".join(MODELS)
    shell:
        "evaluate {output} {input} -l {params.labels}"