|
| 1 | + |
| 2 | +using Pkg; |
| 3 | +Pkg.activate("dev"); |
| 4 | + |
| 5 | +using AlgorithmicRecourseDynamics |
| 6 | +using Colors |
| 7 | +using CounterfactualExplanations |
| 8 | +using Distributions |
| 9 | +using Flux |
| 10 | +using Luxor |
| 11 | +using StatsBase: sample |
| 12 | +using Random |
| 13 | + |
| 14 | +const julia_colors = Dict( |
| 15 | + :blue => Luxor.julia_blue, |
| 16 | + :red => Luxor.julia_red, |
| 17 | + :green => Luxor.julia_green, |
| 18 | + :purple => Luxor.julia_purple, |
| 19 | +) |
| 20 | + |
| 21 | +function get_data(N=1000, xmax=2) |
| 22 | + X, ys = make_blobs( |
| 23 | + N, 2; |
| 24 | + centers=2, as_table=false, center_box=(-xmax => xmax), cluster_std=0.1 |
| 25 | + ) |
| 26 | + ys .= ys .== 2 |
| 27 | + X = X' |
| 28 | + xs = Flux.unstack(X, 2) |
| 29 | + data = zip(xs, ys) |
| 30 | + counterfactual_data = CounterfactualData(X, ys') |
| 31 | + return counterfactual_data, data |
| 32 | +end |
| 33 | +plot() |
| 34 | +scatter!(counterfactual_data) |
| 35 | + |
| 36 | +function logo_picture(; |
| 37 | + ndots=3, |
| 38 | + frame_size=500, |
| 39 | + ms=frame_size // 10, |
| 40 | + mcolor=(:red, :green, :purple), |
| 41 | + margin=0.1, |
| 42 | + fun=f(x) = x * cos(x), |
| 43 | + xmax=2.5, |
| 44 | + noise=0.5, |
| 45 | + ged_data=get_data, |
| 46 | + ntrue=50, |
| 47 | + gt_color=julia_colors[:blue], |
| 48 | + gt_stroke_size=5, |
| 49 | + interval_color=julia_colors[:blue], |
| 50 | + interval_alpha=0.2, |
| 51 | + seed=2022 |
| 52 | +) |
| 53 | + |
| 54 | + # Setup |
| 55 | + n_mcolor = length(mcolor) |
| 56 | + mcolor = getindex.(Ref(julia_colors), mcolor) |
| 57 | + Random.seed!(seed) |
| 58 | + |
| 59 | + # Data |
| 60 | + x, y = get_data(xmax=xmax, noise=noise, fun=fun) |
| 61 | + train, test = partition(eachindex(y), 0.4, 0.4, shuffle=true) |
| 62 | + xtrue = range(-xmax, xmax, ntrue) |
| 63 | + ytrue = fun.(xtrue) |
| 64 | + |
| 65 | + # Conformal Prediction |
| 66 | + Model = @load LinearRegressor pkg = MLJLinearModels |
| 67 | + degree_polynomial = 5 |
| 68 | + polynomial_features(x, degree::Int) = reduce(hcat, map(i -> x .^ i, 1:degree)) |
| 69 | + pipe = (x -> MLJBase.table(polynomial_features(x, degree_polynomial))) |> Model() |
| 70 | + conf_model = conformal_model(pipe; coverage=0.95) |
| 71 | + mach = machine(conf_model, x, y) |
| 72 | + fit!(mach, rows=train) |
| 73 | + yhat = predict(mach, x[test]) |
| 74 | + y_lb = [y[1] for y in yhat] |
| 75 | + y_ub = [y[2] for y in yhat] |
| 76 | + |
| 77 | + # Logo |
| 78 | + idx = sample(test, ndots, replace=false) |
| 79 | + xplot, yplot = (x[idx], y[idx]) |
| 80 | + _scale = (frame_size / (2 * maximum(x))) * (1 - margin) |
| 81 | + |
| 82 | + # Ground truth: |
| 83 | + setline(gt_stroke_size) |
| 84 | + sethue(gt_color) |
| 85 | + true_points = [Point((_scale .* (x, y))...) for (x, y) in zip(xtrue, ytrue)] |
| 86 | + poly(true_points[1:(end-1)], action=:stroke) |
| 87 | + |
| 88 | + # Data |
| 89 | + data_plot = zip(xplot, yplot) |
| 90 | + for i = 1:length(data_plot) |
| 91 | + _x, _y = _scale .* collect(data_plot)[i] |
| 92 | + color_idx = i % n_mcolor == 0 ? n_mcolor : i % n_mcolor |
| 93 | + sethue(mcolor[color_idx]...) |
| 94 | + circle(Point(_x, _y), ms, action=:fill) |
| 95 | + end |
| 96 | + |
| 97 | + # Prediction interval: |
| 98 | + _order_lb = sortperm(x[test]) |
| 99 | + _order_ub = reverse(_order_lb) |
| 100 | + lb = [ |
| 101 | + Point((_scale .* (x, y))...) for (x, y) in zip(x[test][_order_lb], y_lb[_order_lb]) |
| 102 | + ] |
| 103 | + ub = [ |
| 104 | + Point((_scale .* (x, y))...) for (x, y) in zip(x[test][_order_ub], y_ub[_order_ub]) |
| 105 | + ] |
| 106 | + setcolor(sethue(interval_color)..., interval_alpha) |
| 107 | + poly(vcat(lb, ub), action=:fill) |
| 108 | + |
| 109 | +end |
| 110 | + |
| 111 | +function draw_small_logo(filename="docs/src/assets/logo.svg"; width=500) |
| 112 | + frame_size = width |
| 113 | + Drawing(frame_size, frame_size, filename) |
| 114 | + origin() |
| 115 | + logo_picture(frame_size=frame_size) |
| 116 | + finish() |
| 117 | + preview() |
| 118 | +end |
| 119 | + |
| 120 | +function draw_wide_logo_new( |
| 121 | + filename="docs/src/assets/wide_logo.png"; |
| 122 | + _pkg_name="Conformal Prediction", |
| 123 | + font_size=150, |
| 124 | + font_family="Tamil MN", |
| 125 | + font_fill="transparent", |
| 126 | + font_color=Luxor.julia_blue, |
| 127 | + bg_color="transparent", |
| 128 | + picture_kwargs... |
| 129 | +) |
| 130 | + |
| 131 | + # Setup: |
| 132 | + height = Int(round(font_size * 2.4)) |
| 133 | + fontsize(font_size) |
| 134 | + fontface(font_family) |
| 135 | + strs = split(_pkg_name) |
| 136 | + text_col_width = Int(round(maximum(map(str -> textextents(str)[3], strs)) * 1.05)) |
| 137 | + width = Int(round(height + text_col_width)) |
| 138 | + cw = [height, text_col_width] |
| 139 | + cells = Luxor.Table(height, cw) |
| 140 | + ms = Int(round(height / 10)) |
| 141 | + gt_stroke_size = Int(round(height / 50)) |
| 142 | + |
| 143 | + Drawing(width, height, filename) |
| 144 | + origin() |
| 145 | + background(bg_color) |
| 146 | + |
| 147 | + # Picture: |
| 148 | + @layer begin |
| 149 | + translate(cells[1]) |
| 150 | + logo_picture( |
| 151 | + frame_size=height, |
| 152 | + margin=0.1, |
| 153 | + ms=ms, |
| 154 | + gt_stroke_size=gt_stroke_size, |
| 155 | + picture_kwargs..., |
| 156 | + ) |
| 157 | + end |
| 158 | + |
| 159 | + # Text: |
| 160 | + @layer begin |
| 161 | + translate(cells[2]) |
| 162 | + fontsize(font_size) |
| 163 | + fontface(font_family) |
| 164 | + tiles = Tiler(cells.colwidths[2], height, length(strs), 1) |
| 165 | + for (pos, n) in tiles |
| 166 | + @layer begin |
| 167 | + translate(pos) |
| 168 | + setline(Int(round(gt_stroke_size / 5))) |
| 169 | + sethue(font_fill) |
| 170 | + textoutlines(strs[n], O, :path, valign=:middle, halign=:center) |
| 171 | + sethue(font_color) |
| 172 | + strokepath() |
| 173 | + end |
| 174 | + end |
| 175 | + end |
| 176 | + |
| 177 | + finish() |
| 178 | + preview() |
| 179 | +end |
| 180 | + |
| 181 | +draw_wide_logo_new() |
0 commit comments