Skip to content

Commit e2d951d

Browse files
committed
reproduce fig 18.6 #207
1 parent f83b921 commit e2d951d

File tree

3 files changed

+103
-1
lines changed

3 files changed

+103
-1
lines changed
107 KB
Loading

code/fig18.5/err_vs_log_lambda.png

16.1 KB
Loading

code/fig18.5/script.jl

+103-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
using DelimitedFiles
22
using GLMNet
3-
3+
using StatsBase
4+
using Plots
5+
using LaTeXStrings
46
# import data
57
folder = "data/Leukemia/"
68
raw_X = readdlm(folder*"data_set_ALL_AML_train.txt", '\t')
@@ -18,3 +20,103 @@ train_y = hcat(vcat(fill(1, 27, 1), fill(0, 38-27, 1)),
1820
# Penalty grid shared by every fit below, and the lasso / elastic-net
# (alpha = 0.8) logistic-regression paths fitted on the training set.
grid = collect(range(exp(-8), exp(-1), length=100))
lasso_path = glmnet(train_X, train_y, Binomial(), lambda=grid)
elnet_path = glmnet(train_X, train_y, Binomial(), lambda=grid, alpha=0.8)

# calculate training error
# `predict` is evaluated once per path; the stored values are reused both for
# the class predictions here and for the deviance curve later on.
lasso_train_val = predict(lasso_path, train_X)
elnet_train_val = predict(elnet_path, train_X)
# a sample is assigned to the positive class when its predicted value exceeds 0
lasso_train_yhat = lasso_train_val .> 0
elnet_train_yhat = elnet_train_val .> 0
# per-column fraction of training samples that disagree with column 2 of `train_y`
lasso_train_err = [sum(lasso_train_yhat[:, i] .!= train_y[:, 2]) for i in axes(lasso_train_yhat, 2)] / size(train_y, 1)
elnet_train_err = [sum(elnet_train_yhat[:, i] .!= train_y[:, 2]) for i in axes(elnet_train_yhat, 2)] / size(train_y, 1)
31+
32+
# calculate deviance
33+
# p * log(p) with the limiting value 0 at p == 0, so that saturated
# probabilities contribute 0 instead of `0 * -Inf == NaN`.
_xlogx(p::Real) = iszero(p) ? zero(p) : p * log(p)

"""
    calc_dev(pred_val, grid)

Return a vector `dev` of length `length(grid)` where `dev[i]` is computed from
column `i` of `pred_val`: each entry `eta` is mapped to `p = 1 / (1 + exp(eta))`
and `dev[i] = -sum(p * log(p) + (1 - p) * log(1 - p))`.

# Arguments
- `pred_val`: matrix of predicted values; needs at least `length(grid)` columns.
- `grid`: only its length is used (one output entry per element).

# Notes
- No observed labels enter this quantity; it depends on the predictions only.
- Entries large enough in magnitude to make `p` exactly `0` or `1` contribute
  `0` to the sum (the limit of `p * log(p)`), rather than turning the whole
  sum into `NaN`.
"""
function calc_dev(pred_val, grid)
    dev = zeros(length(grid))
    for i in eachindex(dev)
        phat = 1 ./ (1 .+ exp.(pred_val[:, i]))
        dev[i] = -sum(_xlogx.(phat) .+ _xlogx.(1 .- phat))
    end
    return dev
end
41+
lasso_train_dev = calc_dev(lasso_train_val, grid)

# #############################
# test
# #############################

# import test data
raw_test_X = readdlm(folder*"data_set_ALL_AML_independent.txt", '\t')
# keep every second column starting from the third; the first remaining row
# is read as the sample ids, the rows below it as the measurements
test_X = raw_test_X[:, 3:2:end]
test_id = Int.(test_X[1, :])
# transpose so that rows are samples
test_X = Array{Float64, 2}(test_X[2:end, :]')
# evaluate each path once and reuse the values for labels and deviance
lasso_test_val = predict(lasso_path, test_X)
elnet_test_val = predict(elnet_path, test_X)
lasso_yhat = lasso_test_val .> 0
elnet_yhat = elnet_test_val .> 0
# The labels could also be read from `Leuk_ALL_AML.test.cls`, but that file
# lists 35 test samples whereas every other source has 34, so it is not used:
# raw_test_y = readdlm(folder*"Leuk_ALL_AML.test.cls")
# test_y = raw_test_y[2,:]

# get from `table_ALL_AML_samples.txt`
# labels listed in increasing sample-id order (ids 39 to 72, 34 samples)
test_y = vcat(fill(0, 49-39+1), fill(1, 54-50+1), fill(0, 56-55+1), fill(1, 58-57+1), 0, fill(1, 66-60+1), fill(0, 72-67+1))
# Reorder the labels to the column order of the data file: the rank of each id
# is its position in the id-sorted label vector above. (`sortperm` would give
# the inverse permutation, which is not what is needed here.)
test_y = test_y[Int.(tiedrank(test_id))]
lasso_test_err = [sum(lasso_yhat[:, i] .!= test_y) for i in axes(lasso_yhat, 2)] / length(test_y)
elnet_test_err = [sum(elnet_yhat[:, i] .!= test_y) for i in axes(elnet_yhat, 2)] / length(test_y)

lasso_test_dev = calc_dev(lasso_test_val, grid)
69+
# 10-fold on the train set
70+
# cross validation (cf. https://github.com/szcf-weiya/ESL-CN/blob/99f1e9fd4b8c8c80fa0ff281f0e082cb810a54e0/code/LDA/diagonalLDA.jl#L99-L128)
71+
# X is Nxp
72+
"""
    cv_err(X, y, lambda; nfold = 10)

Run `nfold`-fold cross-validation of a binomial `glmnet` fit over the penalty
values `lambda`.

# Arguments
- `X`: `N × p` matrix, one row per sample.
- `y`: two-column response matrix as passed to `glmnet(..., Binomial())`;
  column 2 is compared against the thresholded predictions.
- `lambda`: penalty values handed to `glmnet`.

# Keywords
- `nfold`: number of folds; the rows `1:N` are split by `div_into_folds`.

# Returns
A two-element vector `[err, dev]`, each a `1 × length(lambda)` matrix: `err` is
the number of held-out misclassifications summed over folds and divided by `N`,
`dev` is the sum over folds of `calc_dev` on the held-out predictions.
"""
function cv_err(X::Array{Float64, 2}, y::Array{Int, 2}, lambda::Array{Float64, 1}; nfold = 10)
    N = size(X, 1)
    folds = div_into_folds(N, K = nfold)
    err = zeros(nfold, length(lambda))
    dev = zeros(nfold, length(lambda))
    for k in 1:nfold
        test_idx = folds[k]
        train_idx = setdiff(1:N, test_idx)
        cvlasso = glmnet(X[train_idx, :], y[train_idx, :], Binomial(), lambda = lambda)
        # predictions for the held-out rows, thresholded at 0
        pred_val = predict(cvlasso, X[test_idx, :])
        yhat = pred_val .> 0
        held_out = y[test_idx, 2]
        err[k, :] = [sum(yhat[:, i] .!= held_out) for i in 1:length(lambda)]
        # use this function's own `lambda`, not the script-level `grid`, so the
        # result does not depend on a global that may differ from the argument
        dev[k, :] = calc_dev(pred_val, lambda)
    end
    return [sum(err, dims=1) / N, sum(dev, dims=1)]
end
89+
90+
# divide the indices 1:N into K contiguous folds
91+
"""
    div_into_folds(N; K = 10)

Split the indices `1:N` into `K` contiguous blocks, returned as a vector of
index vectors. The leading folds hold `ceil(N / K)` indices each and the
remaining folds hold one index fewer, so the sizes differ by at most one.
"""
function div_into_folds(N::Int; K::Int = 10)
    # size of the larger folds
    quota = ceil(Int, N / K)
    # how many folds receive the full quota; the other K - nfull get quota - 1
    nfull = N - (quota - 1) * K
    folds = Vector{Vector{Int}}(undef, K)
    for j in 1:nfull
        hi = quota * j
        folds[j] = collect((hi - quota + 1):hi)
    end
    # the smaller folds start right after the indices used by the full ones
    shift = quota * nfull
    small = quota - 1
    for j in 1:(K - nfull)
        hi = small * j
        folds[nfull + j] = collect((hi - small + 1):hi) .+ shift
    end
    return folds
end
106+
107+
# 10-fold cv error
108+
# 10-fold cv error
lasso_cv10_err, lasso_cv10_dev = cv_err(train_X, train_y, grid)

# first panel: misclassification error of the lasso path against log(lambda)
p1 = plot(
    log.(grid), lasso_cv10_err[:];
    color = "purple", label = "10-fold CV", linewidth = 3, legend = :topleft,
)
plot!(p1, log.(grid), lasso_train_err; color = "orange", label = "Training", linewidth = 3)
plot!(p1, log.(grid), lasso_test_err; color = "skyblue", label = "Test", linewidth = 3)
xlabel!(p1, L"\log \lambda")
ylabel!(p1, "Misclassification Error")
# savefig("err_vs_log_lambda.png")

# second panel: the `calc_dev` curves with the same colour coding, no legend
p2 = plot(
    log.(grid), lasso_cv10_dev[:];
    color = "purple", linewidth = 3, legend = false,
    xlab = L"\log\lambda", ylab = "Deviance", fontsize = 20,
)
plot!(p2, log.(grid), lasso_train_dev; color = "orange", linewidth = 3)
plot!(p2, log.(grid), lasso_test_dev; color = "skyblue", linewidth = 3)

# combine both panels into one figure and write it to disk
plot(p1, p2; dpi = 300)
# savefig("err_and_dev_vs_log_lambda.pdf")
savefig("err_and_dev_vs_log_lambda.png")

0 commit comments

Comments
 (0)