# Dependencies assumed from the identifiers used below (RLCore/RLBase glue
# re-exported by ReinforcementLearning, the SnakeGame environment, Flux layers,
# and TensorBoard logging); adjust to match your project environment.
using ReinforcementLearning
using SnakeGames
using Flux
using TensorBoardLogger
using Logging
using Dates
using Random

function RLCore.Experiment(
    ::Val{:JuliaRL},
    ::Val{:DQN},
    ::Val{:SnakeGame},
    ::Nothing;
    save_dir = nothing,
    seed = 123,
)
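    # A dedicated, seeded RNG is threaded through the environment, the
    # exploration policy, and weight initialization for reproducibility.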
    rng = Random.MersenneTwister(seed)

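    # An 8x8 Snake board. FULL_ACTION_SET means the env also exposes a legal
    # action mask, which the trajectory below stores alongside each state.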
    SHAPE = (8, 8)
    inner_env = SnakeGameEnv(; action_style = FULL_ACTION_SET, shape = SHAPE, rng = rng)

    board_size = size(get_state(inner_env))
    N_FRAMES = 4

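    # Stack the last N_FRAMES observations into one state so the agent can see
    # motion; StateCachedEnv memoizes the computed state within a step so
    # repeated get_state calls don't rebuild the stack.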
    env = inner_env |>
        StateOverriddenEnv(
            StackFrames(board_size..., N_FRAMES),
        ) |>
        StateCachedEnv

    N_ACTIONS = length(get_actions(env))

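    # Default to a timestamped checkpoint directory so repeated runs don't
    # overwrite each other's TensorBoard logs.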
    if isnothing(save_dir)
        t = Dates.format(now(), "yyyy_mm_dd_HH_MM_SS")
        save_dir = joinpath(pwd(), "checkpoints", "SnakeGame_$(t)")
    end

    lg = TBLogger(joinpath(save_dir, "tb_log"), min_level = Logging.Info)

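    # Seeded Glorot initialization, shared by all layers of both Q-networks.
    # update_freq is reused below: the learner updates every 4 steps and the
    # loss-logging hook fires at the same cadence.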
    init = glorot_uniform(rng)

    update_freq = 4

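    # Q-network: the leading reshape folds the (channels × stacked frames)
    # dimensions into a single channel axis, two 3x3 convolutions keep the 8x8
    # spatial size (stride 1, pad 1), and a two-layer MLP maps the flattened
    # 8*8*32 features to one Q-value per action. `initW` is the pre-Flux-0.12
    # keyword for dense-layer weight initialization.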
    create_model() =
        Chain(
            x -> reshape(x, SHAPE..., :, size(x, ndims(x))),
            CrossCor((3, 3), board_size[end] * N_FRAMES => 16, relu; stride = 1, pad = 1, init = init),
            CrossCor((3, 3), 16 => 32, relu; stride = 1, pad = 1, init = init),
            x -> reshape(x, :, size(x, ndims(x))),
            Dense(8 * 8 * 32, 256, relu; initW = init),
            Dense(256, N_ACTIONS; initW = init),
        ) |> cpu

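    # Shape sanity check, commented out because the board's channel count
    # depends on the SnakeGame encoding (3 channels is an assumption here):
    #   x = rand(Float32, SHAPE..., 3, N_FRAMES, 32)  # batch of 32 stacked states
    #   @assert size(create_model()(x)) == (N_ACTIONS, 32)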
    agent = Agent(
        policy = QBasedPolicy(
            learner = DQNLearner(
                approximator = NeuralNetworkApproximator(
                    model = create_model(),
                    optimizer = ADAM(0.001),
                ),
                target_approximator = NeuralNetworkApproximator(model = create_model()),
                update_freq = update_freq,
                γ = 0.99f0,
                update_horizon = 1,
                batch_size = 32,
                stack_size = nothing,
                min_replay_history = 20_000,
                loss_func = huber_loss,
                target_update_freq = 8_000,
                rng = rng,
            ),
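            # Linear ϵ-greedy schedule: fully random at first, annealed to 1%
            # random actions over the first 250_000 steps.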
            explorer = EpsilonGreedyExplorer(
                ϵ_init = 1.0,
                ϵ_stable = 0.01,
                decay_steps = 250_000,
                kind = :linear,
                rng = rng,
            ),
        ),
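        # Circular replay buffer; the SALRTSAL name reflects the stored fields
        # (State, Action, Legal-action mask, Reward, Terminal, and the next
        # S/A/L). Old transitions are evicted once capacity is reached.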
        trajectory = CircularCompactSALRTSALTrajectory(
            capacity = 500_000,
            state_type = Float32,
            state_size = (board_size..., N_FRAMES),
            legal_actions_mask_size = (N_ACTIONS,),
        ),
    )

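    # Evaluation bookkeeping; these constants are declared here but not wired
    # into the hooks below in this snippet.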
    evaluation_result = []
    EVALUATION_FREQ = 100_000
    N_CHECKPOINTS = 3

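    # Hooks: track episode rewards/lengths and wall-clock time per step, and
    # stream the training loss and per-episode stats to TensorBoard.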
    total_reward_per_episode = TotalRewardPerEpisode()
    time_per_step = TimePerStep()
    steps_per_episode = StepsPerEpisode()
    hook = ComposedHook(
        total_reward_per_episode,
        time_per_step,
        steps_per_episode,
        DoEveryNStep(update_freq) do t, agent, env
            with_logger(lg) do
                @info "training" loss = agent.policy.learner.loss log_step_increment = update_freq
            end
        end,
        DoEveryNEpisode() do t, agent, env
            with_logger(lg) do
                @info "training" episode_length = steps_per_episode.steps[end] reward =
                    total_reward_per_episode.rewards[end] log_step_increment = 0
            end
        end,
    )

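    # Train for one million environment steps, then return the assembled
    # Experiment.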
    N_TRAINING_STEPS = 1_000_000
    stop_condition = StopAfterStep(N_TRAINING_STEPS)
    description = """
    # Play Single Agent SnakeGame with DQN

    You can view the tensorboard logs with `tensorboard --logdir $(joinpath(save_dir, "tb_log"))`
    """
    Experiment(agent, env, stop_condition, hook, description)
end
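
# Usage sketch (an assumption about the surrounding experiment machinery;
# ReinforcementLearning.jl executes an Experiment with `run`):
#   ex = RLCore.Experiment(Val(:JuliaRL), Val(:DQN), Val(:SnakeGame), nothing)
#   run(ex)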