Create local_response_norm.py
Browse files- local_response_norm.py +7 -0
local_response_norm.py
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
class LocalResponseNorm(nn.Module):
|
2 |
+
@nn.compact
|
3 |
+
def __call__(
|
4 |
+
self,
|
5 |
+
value: jax.Array
|
6 |
+
) -> jax.Array:
|
7 |
+
return value / jnp.repeat(jnp.expand_dims((1e-8 + (value**2).mean(axis=-1))**0.5, axis=-1), repeats=value.shape[-1], axis=-1)
|