@@ -8,16 +8,15 @@ function RLCore.Experiment(
 )
     rng = Random.MersenneTwister(seed)

-    SHAPE = (8,8)
-    inner_env = SnakeGameEnv(;action_style=FULL_ACTION_SET, shape=SHAPE, rng=rng)
+    SHAPE = (8, 8)
+    inner_env = SnakeGameEnv(; action_style = FULL_ACTION_SET, shape = SHAPE, rng = rng)

     board_size = size(get_state(inner_env))
     N_FRAMES = 4

-    env = inner_env |>
-        StateOverriddenEnv(
-            StackFrames(board_size..., N_FRAMES),
-        ) |>
+    env =
+        inner_env |>
+        StateOverriddenEnv(StackFrames(board_size..., N_FRAMES),) |>
         StateCachedEnv

     N_ACTIONS = length(get_actions(env))
@@ -36,7 +35,14 @@ function RLCore.Experiment(
     create_model() =
         Chain(
             x -> reshape(x, SHAPE..., :, size(x, ndims(x))),
-            CrossCor((3, 3), board_size[end] * N_FRAMES => 16, relu; stride = 1, pad = 1, init = init),
+            CrossCor(
+                (3, 3),
+                board_size[end] * N_FRAMES => 16,
+                relu;
+                stride = 1,
+                pad = 1,
+                init = init,
+            ),
             CrossCor((3, 3), 16 => 32, relu; stride = 1, pad = 1, init = init),
             x -> reshape(x, :, size(x, ndims(x))),
             Dense(8 * 8 * 32, 256, relu; initW = init),
@@ -45,35 +51,35 @@ function RLCore.Experiment(

     agent = Agent(
         policy = QBasedPolicy(
-                learner = DQNLearner(
-                    approximator = NeuralNetworkApproximator(
-                        model = create_model(),
-                        optimizer = ADAM(0.001),
-                    ),
-                    target_approximator = NeuralNetworkApproximator(model = create_model()),
-                    update_freq = update_freq,
-                    γ = 0.99f0,
-                    update_horizon = 1,
-                    batch_size = 32,
-                    stack_size = nothing,
-                    min_replay_history = 20_000,
-                    loss_func = huber_loss,
-                    target_update_freq = 8_000,
-                    rng = rng,
-                ),
-                explorer = EpsilonGreedyExplorer(
-                    ϵ_init = 1.0,
-                    ϵ_stable = 0.01,
-                    decay_steps = 250_000,
-                    kind = :linear,
-                    rng = rng,
+            learner = DQNLearner(
+                approximator = NeuralNetworkApproximator(
+                    model = create_model(),
+                    optimizer = ADAM(0.001),
                 ),
+                target_approximator = NeuralNetworkApproximator(model = create_model()),
+                update_freq = update_freq,
+                γ = 0.99f0,
+                update_horizon = 1,
+                batch_size = 32,
+                stack_size = nothing,
+                min_replay_history = 20_000,
+                loss_func = huber_loss,
+                target_update_freq = 8_000,
+                rng = rng,
             ),
+            explorer = EpsilonGreedyExplorer(
+                ϵ_init = 1.0,
+                ϵ_stable = 0.01,
+                decay_steps = 250_000,
+                kind = :linear,
+                rng = rng,
+            ),
+        ),
         trajectory = CircularCompactSALRTSALTrajectory(
             capacity = 500_000,
             state_type = Float32,
             state_size = (board_size..., N_FRAMES),
-            legal_actions_mask_size=(N_ACTIONS,)
+            legal_actions_mask_size = (N_ACTIONS,),
         ),
     )

@@ -90,7 +96,8 @@ function RLCore.Experiment(
         steps_per_episode,
         DoEveryNStep(update_freq) do t, agent, env
             with_logger(lg) do
-                @info "training" loss = agent.policy.learner.loss log_step_increment = update_freq
+                @info "training" loss = agent.policy.learner.loss log_step_increment =
+                    update_freq
             end
         end,
         DoEveryNEpisode() do t, agent, env
@@ -109,4 +116,4 @@ function RLCore.Experiment(
    You can view the tensorboard logs with `tensorboard --logdir $(joinpath(save_dir, "tb_log"))`
    """
    Experiment(agent, env, stop_condition, hook, description)
-end
+end