I’m a student in the natural sciences, working on a project that uses machine learning. The project builds on a neural network model created by a different research group and implemented in haiku.
In their original article, when describing the model hyperparameters, they didn’t include the size of the input layer in the MLP shape, only the hidden and output layer sizes. For instance, if the input is N=2000 and the subsequent layers are 256, 64, 4 (with (4,) being the output shape), then they report their MLP shape as (256, 64, 4), even in the paper.
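Concretely, such a network is declared with exactly those three sizes. The following is my own minimal sketch using the generic hk.nets.MLP, not the group’s actual code:

import haiku as hk

def forward(x):
    # Only the hidden and output sizes are specified here; the shape of the
    # first weight matrix is determined later from x.shape[-1].
    mlp = hk.nets.MLP(output_sizes=[256, 64, 4])
    return mlp(x)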
To me this was confusing, since the mapping from 2000 to 256 requires a weight matrix whose size is proportional to the input size. Later, I realized that in haiku modules the input layer is adapted to the size of the input, and you typically never pass the input size when defining the MLP itself. This is different from what I experienced with tensorflow/keras in my ML courses. I was surprised, and I even tested that different input shapes yield differently shaped trainable parameters. Please correct me if I misunderstand: isn’t that going to give a totally different model if you had an N=1000 input vs. an N=2000 input?
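For example, counting the trainable parameters for the two input sizes shows the difference directly. This is a minimal sketch using the generic hk.nets.MLP, not the group’s actual model:

import haiku as hk
import jax
import jax.numpy as jnp

def forward(x):
    return hk.nets.MLP(output_sizes=[256, 64, 4])(x)

model = hk.transform(forward)
rng = jax.random.PRNGKey(0)

for n in (1000, 2000):
    # The first weight matrix is created with shape (n, 256) at init time,
    # so the total number of trainable parameters changes with n:
    # 272,964 parameters for n=1000 vs. 528,964 for n=2000 with these sizes.
    params = model.init(rng, jnp.ones([1, n]))
    n_params = sum(p.size for p in jax.tree_util.tree_leaves(params))
    print(n, n_params)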
My question to the informed machine learning experts here: when reporting the layers, should one write “four layers of shape (2000, 256, 64, 4)” or “three layers of shape (256, 64, 4)”?
P.S. If anyone is curious, I’ve tested that using different input sizes actually leads to different weight parameters in haiku. Here is a minimal working example (MWE):
import haiku as hk
import jax.numpy as jnp
import numpy as np
import jax

class MyLinear1(hk.Module):
    def __init__(self, output_size, name=None):
        super().__init__(name=name)
        self.output_size = output_size

    def __call__(self, x):
        # The input dimension j is read off the incoming array here,
        # not fixed in __init__.
        j, k = x.shape[-1], self.output_size
        w_init = hk.initializers.TruncatedNormal(1. / np.sqrt(j))
        w = hk.get_parameter("w", shape=[j, k], dtype=x.dtype, init=w_init)
        b = hk.get_parameter("b", shape=[k], dtype=x.dtype, init=jnp.ones)
        return jnp.dot(x, w) + b

def _forward_fn_linear1(x):
    module = MyLinear1(output_size=2)
    return module(x)

forward_linear1 = hk.transform(_forward_fn_linear1)
rng_key = jax.random.PRNGKey(42)
dummy_x = jnp.ones([8])
params = forward_linear1.init(rng=rng_key, x=dummy_x)
print(params)
With dummy_x = jnp.ones([8]) it gives an (8, 2) weight matrix; the opening of the printed dict and the first row of w are missing from this excerpt:

       [-0.4785112 , -0.38034892],
       [-0.41137823, -0.22265594],
       [-0.43343404,  0.21691099],
       [-0.18514387, -0.10827615],
       [ 0.3682926 , -0.1418969 ],
       [ 0.10915945,  0.4389233 ],
       [-0.07725035,  0.08247987]], dtype=float32), 'b': Array([1., 1.], dtype=float32)}}
while dummy_x = jnp.ones([1]) gives
{'my_linear1': {'w': Array([[ 1.51595 , -0.23353337]], dtype=float32), 'b': Array([1., 1.], dtype=float32)}}
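The dependence on the input size is easier to see if you print only the parameter shapes instead of the values (a small addition to the MWE above):

import jax

# Print just the shapes of the parameters created by forward_linear1.init above.
print(jax.tree_util.tree_map(lambda p: p.shape, params))
# dummy_x = jnp.ones([8]) -> 'w': (8, 2); dummy_x = jnp.ones([1]) -> 'w': (1, 2)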