I’ve just about exhausted every possible angle I can think of on this problem and I am not making any progress. I have a feeling it may be a compatibility issue but I am not sure? Anyway, I have run the code in the example linked below:
https://github.com/tensorflow/agents/blob/528cef7c4aedf54158a0564fdca446fe9942aa2a/docs/tutorials/1_dqn_tutorial.ipynb
more or less line for line. I ran this in VS code in a conda environment, having installed pip and then pip installed the appropriate packages within the environment. This has happened twice in 2 projects now. The code runs perfectly until it reaches the dataset in the reverb buffer and then it simply freezes and fails to progress any further without ever throwing an error or leaving any signs as to what might be happening. The exact same problem occurred when I ran my own version of this code in a project I am doing. Everything grinds to a halt at next(iterator)
My code can be seen below:
##RL - agent learns to perform actions in an environment so as to maximize a reward
##2 main components 1) env 2) agent
## agent + environment continuously interact. at each time_step agent takes an action based on its policy where s is the current observation from the environment and receives a reward rt+1 and the next observation st+1 from the env. Goal to improve policy so as to max sum of rewards
### Distinguish state of environment and the observation which is part of environment state that agent can see
from __future__ import absolute_import, division, print_function
import os
os.environ['TF_USE_LEGACY_KERAS']='1'
# os.environ['TF_ENABLE_ONEDNN_OPTS=0']
# os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
import base64
import imageio
import IPython
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import PIL.Image
import pyvirtualdisplay
import reverb
import tensorflow as tf
from tf_agents.agents.dqn import dqn_agent
from tf_agents.drivers import py_driver
from tf_agents.environments import suite_gym
from tf_agents.environments import tf_py_environment
from tf_agents.eval import metric_utils
from tf_agents.metrics import tf_metrics
from tf_agents.networks import sequential
from tf_agents.policies import py_tf_eager_policy
from tf_agents.policies import random_tf_policy
from tf_agents.replay_buffers import reverb_replay_buffer
from tf_agents.replay_buffers import reverb_utils
from tf_agents.trajectories import trajectory
from tf_agents.specs import tensor_spec
from tf_agents.utils import common
display = pyvirtualdisplay.Display(visible=0,size=(1400,900)).start()
print(tf.test.is_built_with_gpu_support())
###HYPERPARAMETERS
num_iterations = 20000
initial_collect_steps=100
collect_steps_per_iteration=1
replay_buffer_max_length=100000
batch_size=64
learning_rate=1e-3
log_interval=200
num_eval_episodes=10
eval_interval=1000
###ENVIRONMENT
env_name="CartPole-v0"
env = suite_gym.load(env_name)
env.reset()
image =PIL.Image.fromarray(env.render())
# image.show()
time_step=env.reset()
action=np.array(1,dtype=np.int32)
next_time_step=env.step(action)
train_py_env=suite_gym.load(env_name)
eval_py_env =suite_gym.load(env_name)
train_env = tf_py_environment.TFPyEnvironment(train_py_env)
eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)
fc_layer_params = (100, 50)
action_tensor_spec=tensor_spec.from_spec(env.action_spec())
num_actions=action_tensor_spec.maximum - action_tensor_spec.minimum+1
def dense_layer(num_units):
return tf.keras.layers.Dense(
num_units,
activation=tf.keras.activations.relu,
kernel_initializer=tf.keras.initializers.VarianceScaling(
scale=2.0,
mode='fan_in',
distribution='truncated_normal'
)
)
dense_layers=[dense_layer(num_units) for num_units in fc_layer_params]
q_values_layer=tf.keras.layers.Dense(
num_actions,
activation=None,
kernel_initializer=tf.keras.initializers.RandomUniform(
minval=-0.03,
maxval=0.03
),
bias_initializer=tf.keras.initializers.Constant(-0.2)
)
q_net=sequential.Sequential(dense_layers+[q_values_layer])
optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate)
train_step_counter=tf.Variable(0)
agent= dqn_agent.DqnAgent(
train_env.time_step_spec(),
train_env.action_spec(),
q_network=q_net,
optimizer=optimizer,
td_errors_loss_fn=common.element_wise_squared_loss,
train_step_counter=train_step_counter
)
agent.initialize()
eval_policy = agent.policy
collect_policy = agent.collect_policy
random_policy = random_tf_policy.RandomTFPolicy(train_env.time_step_spec(),train_env.action_spec())
example_environment = tf_py_environment.TFPyEnvironment(
suite_gym.load('CartPole-v0')
)
time_step=example_environment.reset()
random_policy.action(time_step)
def compute_avg_return(environment,policy,num_episodes=10):
total_return =0.0
for _ in range(num_episodes):
time_step=environment.reset()
episode_return =0.0
while not time_step.is_last():
action_step=policy.action(time_step)
time_step=environment.step(action_step.action)
episode_return += time_step.reward
total_return+=episode_return
avg_return=total_return/num_episodes
return avg_return.numpy()[0]
compute_avg_return(eval_env,random_policy,num_eval_episodes)
table_name='uniform_table'
replay_buffer_signature = tensor_spec.from_spec(
agent.collect_data_spec
)
replay_buffer_signature = tensor_spec.add_outer_dim(
replay_buffer_signature
)
table=reverb.Table(
table_name,
max_size=replay_buffer_max_length,
sampler=reverb.selectors.Uniform(),
remover=reverb.selectors.Fifo(),
rate_limiter=reverb.rate_limiters.MinSize(1),
signature=replay_buffer_signature
)
reverb_server=reverb.Server([table])
replay_buffer= reverb_replay_buffer.ReverbReplayBuffer(
agent.collect_data_spec,
table_name=table_name,
sequence_length=2,
local_server=reverb_server
)
rb_observer=reverb_utils.ReverbAddTrajectoryObserver(
replay_buffer.py_client,
table_name,
sequence_length=2
)
dataset=replay_buffer.as_dataset(
num_parallel_calls=3,
sample_batch_size=batch_size,
num_steps=2
).prefetch(3)
iterator=iter(dataset)
# try:
# %%time
# except:
# pass
agent.train=common.function(agent.train)
agent.train_step_counter.assign(0)
avg_return = compute_avg_return(eval_env,agent.policy,num_eval_episodes)
returns =[avg_return]
time_step=train_py_env.reset()
collect_driver = py_driver.PyDriver(
env,
py_tf_eager_policy.PyTFEagerPolicy(
agent.collect_policy,use_tf_function=True),
[rb_observer],
max_steps=collect_steps_per_iteration
)
for _ in range(num_iterations):
time_step,_=collect_driver.run(time_step)
print('iterator next')
experience,unused_info=next(iterator)###<---it gets stuck here!!
train_loss=agent.train(experience).loss
step=agent.train_step_counter.numpy()
if step % log_interval ==0:
print('step = {0}: loss={1}'.format(step,train_loss))
if step % eval_interval==0:
avg_return = compute_avg_return(eval_env,agent.policy,num_eval_episodes)
print('step = {0}: Average Return = {1}'.format(step,avg_return))
returns.append(avg_return)
iterations=range(0,num_iterations+1, eval_interval)
The terminal output
(tf_tutorial) harry@harry-Aspire-A315-58:~/Documents/Reinforcement Learning/tf$ python intro.py
2024-12-17 14:43:38.622260: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-12-17 14:43:38.624297: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2024-12-17 14:43:38.651939: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-12-17 14:43:38.651970: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-12-17 14:43:38.652896: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-12-17 14:43:38.657333: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2024-12-17 14:43:38.657493: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-12-17 14:43:39.165679: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
[reverb/cc/platform/tfrecord_checkpointer.cc:162] Initializing TFRecordCheckpointer in /tmp/tmpyir39eez.
[reverb/cc/platform/tfrecord_checkpointer.cc:565] Loading latest checkpoint from /tmp/tmpyir39eez
[reverb/cc/platform/default/server.cc:71] Started replay server on port 42883
iterator next
[reverb/cc/client.cc:165] Sampler and server are owned by the same process (21034) so Table uniform_table is accessed directly without gRPC.
[reverb/cc/client.cc:165] Sampler and server are owned by the same process (21034) so Table uniform_table is accessed directly without gRPC.
[reverb/cc/client.cc:165] Sampler and server are owned by the same process (21034) so Table uniform_table is accessed directly without gRPC.
[reverb/cc/client.cc:165] Sampler and server are owned by the same process (21034) so Table uniform_table is accessed directly without gRPC.
[reverb/cc/client.cc:165] Sampler and server are owned by the same process (21034) so Table uniform_table is accessed directly without gRPC.
[reverb/cc/client.cc:165] Sampler and server are owned by the same process (21034) so Table uniform_table is accessed directly without gRPC.
conda list
# Name Version Build Channel
_libgcc_mutex 0.1 conda_forge conda-forge
_openmp_mutex 4.5 2_gnu conda-forge
alsa-lib 1.2.13 hb9d3cd8_0 conda-forge
asttokens 3.0.0 pyhd8ed1ab_1 conda-forge
brotli 1.1.0 hb9d3cd8_2 conda-forge
brotli-bin 1.1.0 hb9d3cd8_2 conda-forge
bzip2 1.0.8 h4bc722e_7 conda-forge
ca-certificates 2024.12.14 hbcca054_0 conda-forge
cairo 1.18.2 h3394656_1 conda-forge
certifi 2024.8.30 pyhd8ed1ab_0 conda-forge
contourpy 1.3.0 py39h74842e3_2 conda-forge
cycler 0.12.1 pyhd8ed1ab_1 conda-forge
cyrus-sasl 2.1.27 h54b06d7_7 conda-forge
dbus 1.13.6 h5008d03_3 conda-forge
decorator 5.1.1 pyhd8ed1ab_1 conda-forge
double-conversion 3.3.0 h59595ed_0 conda-forge
exceptiongroup 1.2.2 pyhd8ed1ab_1 conda-forge
executing 2.1.0 pyhd8ed1ab_1 conda-forge
expat 2.6.4 h5888daf_0 conda-forge
font-ttf-dejavu-sans-mono 2.37 hab24e00_0 conda-forge
font-ttf-inconsolata 3.000 h77eed37_0 conda-forge
font-ttf-source-code-pro 2.038 h77eed37_0 conda-forge
font-ttf-ubuntu 0.83 h77eed37_3 conda-forge
fontconfig 2.15.0 h7e30c49_1 conda-forge
fonts-conda-ecosystem 1 0 conda-forge
fonts-conda-forge 1 0 conda-forge
fonttools 4.55.3 py39h9399b63_0 conda-forge
freetype 2.12.1 h267a509_2 conda-forge
graphite2 1.3.13 h59595ed_1003 conda-forge
harfbuzz 9.0.0 hda332d3_1 conda-forge
icu 75.1 he02047a_0 conda-forge
imageio 2.4.0 pypi_0 pypi
importlib-resources 6.4.5 pyhd8ed1ab_1 conda-forge
importlib_resources 6.4.5 pyhd8ed1ab_1 conda-forge
ipython 8.18.1 pyh707e725_3 conda-forge
jedi 0.19.2 pyhd8ed1ab_1 conda-forge
keras 2.15.0 pypi_0 pypi
keyutils 1.6.1 h166bdaf_0 conda-forge
kiwisolver 1.4.7 py39h74842e3_0 conda-forge
krb5 1.21.3 h659f571_0 conda-forge
lcms2 2.16 hb7c19ff_0 conda-forge
ld_impl_linux-64 2.43 h712a8e2_2 conda-forge
lerc 4.0.0 h27087fc_0 conda-forge
libblas 3.9.0 25_linux64_openblas conda-forge
libbrotlicommon 1.1.0 hb9d3cd8_2 conda-forge
libbrotlidec 1.1.0 hb9d3cd8_2 conda-forge
libbrotlienc 1.1.0 hb9d3cd8_2 conda-forge
libcblas 3.9.0 25_linux64_openblas conda-forge
libclang-cpp19.1 19.1.5 default_hb5137d0_0 conda-forge
libclang13 19.1.5 default_h9c6a7e4_0 conda-forge
libcups 2.3.3 h4637d8d_4 conda-forge
libdeflate 1.22 hb9d3cd8_0 conda-forge
libdrm 2.4.124 hb9d3cd8_0 conda-forge
libedit 3.1.20191231 he28a2e2_2 conda-forge
libegl 1.7.0 ha4b6fd6_2 conda-forge
libexpat 2.6.4 h5888daf_0 conda-forge
libffi 3.4.2 h7f98852_5 conda-forge
libgcc 14.2.0 h77fa898_1 conda-forge
libgcc-ng 14.2.0 h69a702a_1 conda-forge
libgfortran 14.2.0 h69a702a_1 conda-forge
libgfortran5 14.2.0 hd5240d6_1 conda-forge
libgl 1.7.0 ha4b6fd6_2 conda-forge
libglib 2.82.2 h2ff4ddf_0 conda-forge
libglvnd 1.7.0 ha4b6fd6_2 conda-forge
libglx 1.7.0 ha4b6fd6_2 conda-forge
libgomp 14.2.0 h77fa898_1 conda-forge
libiconv 1.17 hd590300_2 conda-forge
libjpeg-turbo 3.0.0 hd590300_1 conda-forge
liblapack 3.9.0 25_linux64_openblas conda-forge
libllvm19 19.1.5 ha7bfdaf_0 conda-forge
liblzma 5.6.3 hb9d3cd8_1 conda-forge
libnsl 2.0.1 hd590300_0 conda-forge
libntlm 1.4 h7f98852_1002 conda-forge
libopenblas 0.3.28 pthreads_h94d23a6_1 conda-forge
libopengl 1.7.0 ha4b6fd6_2 conda-forge
libpciaccess 0.18 hd590300_0 conda-forge
libpng 1.6.44 hadc24fc_0 conda-forge
libpq 17.2 h3b95a9b_1 conda-forge
libsqlite 3.47.2 hee588c1_0 conda-forge
libstdcxx 14.2.0 hc0a3c3a_1 conda-forge
libstdcxx-ng 14.2.0 h4852527_1 conda-forge
libtiff 4.7.0 hc4654cb_2 conda-forge
libuuid 2.38.1 h0b41bf4_0 conda-forge
libwebp-base 1.4.0 hd590300_0 conda-forge
libxcb 1.17.0 h8a09558_0 conda-forge
libxcrypt 4.4.36 hd590300_1 conda-forge
libxkbcommon 1.7.0 h2c5496b_1 conda-forge
libxml2 2.13.5 h8d12d68_1 conda-forge
libxslt 1.1.39 h76b75d6_0 conda-forge
libzlib 1.3.1 hb9d3cd8_2 conda-forge
matplotlib 3.9.4 py39hf3d152e_0 conda-forge
matplotlib-base 3.9.4 py39h16632d1_0 conda-forge
matplotlib-inline 0.1.7 pyhd8ed1ab_1 conda-forge
ml-dtypes 0.3.2 pypi_0 pypi
munkres 1.1.4 pyh9f0ad1d_0 conda-forge
mysql-common 9.0.1 h266115a_3 conda-forge
mysql-libs 9.0.1 he0572af_3 conda-forge
ncurses 6.5 he02047a_1 conda-forge
numpy 2.0.2 py39h9cb892a_1 conda-forge
openjpeg 2.5.3 h5fbd93e_0 conda-forge
openldap 2.6.9 he970967_0 conda-forge
openssl 3.4.0 hb9d3cd8_0 conda-forge
packaging 24.2 pyhd8ed1ab_2 conda-forge
parso 0.8.4 pyhd8ed1ab_1 conda-forge
pcre2 10.44 hba22ea6_2 conda-forge
pexpect 4.9.0 pyhd8ed1ab_1 conda-forge
pickleshare 0.7.5 pyhd8ed1ab_1004 conda-forge
pillow 11.0.0 py39h538c539_0 conda-forge
pip 24.3.1 pyh8b19718_0 conda-forge
pixman 0.44.2 h29eaf8c_0 conda-forge
prompt-toolkit 3.0.48 pyha770c72_1 conda-forge
pthread-stubs 0.4 hb9d3cd8_1002 conda-forge
ptyprocess 0.7.0 pyhd8ed1ab_1 conda-forge
pure_eval 0.2.3 pyhd8ed1ab_1 conda-forge
pyglet 2.0.20 pypi_0 pypi
pygments 2.18.0 pyhd8ed1ab_1 conda-forge
pyparsing 3.2.0 pyhd8ed1ab_2 conda-forge
pyside6 6.8.1 py39h0383914_0 conda-forge
python 3.9.21 h9c0c6dc_1_cpython conda-forge
python-dateutil 2.9.0.post0 pyhff2d567_1 conda-forge
python_abi 3.9 5_cp39 conda-forge
pyvirtualdisplay 3.0 pypi_0 pypi
qhull 2020.2 h434a139_5 conda-forge
qt6-main 6.8.1 h9d28a51_0 conda-forge
readline 8.2 h8228510_1 conda-forge
setuptools 75.6.0 pyhff2d567_1 conda-forge
six 1.17.0 pyhd8ed1ab_0 conda-forge
stack_data 0.6.3 pyhd8ed1ab_1 conda-forge
tensorboard 2.15.2 pypi_0 pypi
tensorflow 2.15.1 pypi_0 pypi
tf-agents 0.19.0 pypi_0 pypi
tf-keras 2.18.0 pypi_0 pypi
tk 8.6.13 noxft_h4845f30_101 conda-forge
tornado 6.4.2 py39h8cd3c5a_0 conda-forge
traitlets 5.14.3 pyhd8ed1ab_1 conda-forge
typing_extensions 4.12.2 pyha770c72_1 conda-forge
tzdata 2024b hc8b5060_0 conda-forge
unicodedata2 15.1.0 py39h8cd3c5a_1 conda-forge
wayland 1.23.1 h3e06ad9_0 conda-forge
wcwidth 0.2.13 pyhd8ed1ab_1 conda-forge
wheel 0.45.1 pyhd8ed1ab_1 conda-forge
xcb-util 0.4.1 hb711507_2 conda-forge
xcb-util-cursor 0.1.5 hb9d3cd8_0 conda-forge
xcb-util-image 0.4.0 hb711507_2 conda-forge
xcb-util-keysyms 0.4.1 hb711507_0 conda-forge
xcb-util-renderutil 0.3.10 hb711507_0 conda-forge
xcb-util-wm 0.4.2 hb711507_0 conda-forge
xkeyboard-config 2.43 hb9d3cd8_0 conda-forge
xorg-libice 1.1.2 hb9d3cd8_0 conda-forge
xorg-libsm 1.2.5 he73a12e_0 conda-forge
xorg-libx11 1.8.10 h4f16b4b_1 conda-forge
xorg-libxau 1.0.12 hb9d3cd8_0 conda-forge
xorg-libxcomposite 0.4.6 hb9d3cd8_2 conda-forge
xorg-libxcursor 1.2.3 hb9d3cd8_0 conda-forge
xorg-libxdamage 1.1.6 hb9d3cd8_0 conda-forge
xorg-libxdmcp 1.1.5 hb9d3cd8_0 conda-forge
xorg-libxext 1.3.6 hb9d3cd8_0 conda-forge
xorg-libxfixes 6.0.1 hb9d3cd8_0 conda-forge
xorg-libxi 1.8.2 hb9d3cd8_0 conda-forge
xorg-libxrandr 1.5.4 hb9d3cd8_0 conda-forge
xorg-libxrender 0.9.12 hb9d3cd8_0 conda-forge
xorg-libxtst 1.2.5 hb9d3cd8_3 conda-forge
xorg-libxxf86vm 1.1.6 hb9d3cd8_0 conda-forge
zipp 3.21.0 pyhd8ed1ab_1 conda-forge
zstd 1.5.6 ha6fb4c9_0 conda-forge
Thanks!