Commit e443ceac authored by Tomas Kasl

prepare for submission

parent e03555c1
Branch: main
@@ -9,7 +9,7 @@ model:
  prenorm: true
  embedding: false
  layer:
-    N: 64 # 96 # 64
+    N: 64 # 96
train:
  epochs: 10
@@ -21,11 +21,7 @@ train:
  robust_alg: 2
  robust_iters: 10
  robust_step_size: 0.05
-  checkpoint: false
-  suffix: null #".checkpoint" # String to use for checkpoint suffix
-  sample: null # Sample during validation with desired prefix length
-# Pass in 'wandb.mode=online' to turn on wandb logging
wandb:
  mode: disabled
  project: s4
......
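Note: the robust_alg / robust_iters / robust_step_size options above suggest an iterative adversarial-training inner loop. The algorithms selected by robust_alg are defined elsewhere in the repository and are not shown in this commit; the following is only a minimal PGD-style sketch of how ten steps of size 0.05 could perturb an input batch, assuming a JAX loss function loss_fn(params, x, y) returning a scalar and an L-infinity budget eps (both hypothetical names).

import jax
import jax.numpy as jnp

def pgd_perturb(loss_fn, params, x, y, eps=0.1, step_size=0.05, iters=10, rng=None):
    # Illustrative sketch only: ascend the loss w.r.t. the input for `iters` steps
    # of size `step_size` (cf. robust_iters / robust_step_size), projecting the
    # perturbation back into an L-infinity ball of radius `eps` after every step.
    delta = (jax.random.uniform(rng, x.shape, minval=-eps, maxval=eps)
             if rng is not None else jnp.zeros_like(x))
    grad_fn = jax.grad(lambda d: loss_fn(params, x + d, y))
    for _ in range(iters):
        delta = delta + step_size * jnp.sign(grad_fn(delta))
        delta = jnp.clip(delta, -eps, eps)
    return x + delta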
@@ -131,9 +131,6 @@ def create_cifar_classification_dataset(bsz=128):
    tf = transforms.Compose(
        [
            transforms.ToTensor(),
-            # transforms.Normalize(
-            # (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
-            # ),
            transforms.Normalize(mean=0.5, std=0.5),
            transforms.Lambda(lambda x: x.view(IN_DIM, SEQ_LENGTH).t()),
        ]
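Note: with Normalize(mean=0.5, std=0.5) the pixel values produced by ToTensor are mapped from [0, 1] to [-1, 1], and the final Lambda flattens each image into a sequence. A quick shape check, assuming IN_DIM = 3 and SEQ_LENGTH = 32 * 32 = 1024 as in the usual sequential-CIFAR setting (these constants are defined elsewhere in the file):

import numpy as np
from PIL import Image
from torchvision import transforms

IN_DIM, SEQ_LENGTH = 3, 32 * 32   # assumed values; defined elsewhere in data.py
tf = transforms.Compose([
    transforms.ToTensor(),                                        # HWC uint8 -> CHW float in [0, 1]
    transforms.Normalize(mean=0.5, std=0.5),                      # [0, 1] -> [-1, 1]
    transforms.Lambda(lambda x: x.view(IN_DIM, SEQ_LENGTH).t()),  # (3, 32, 32) -> (1024, 3)
])
img = Image.fromarray(np.zeros((32, 32, 3), dtype=np.uint8))
print(tf(img).shape)   # torch.Size([1024, 3])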
@@ -229,7 +226,7 @@ def normalize(features):
    f_min = np.min(features)
    f_max = np.max(features)
    features = (features-f_min) / (f_max - f_min)
-    features = 2* features - 1
+    features = 2*features - 1
    return features
data_path = "/data/UrbanSound8K/audio"
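Note: normalize() min-max scales the features to [0, 1] and then maps them affinely to [-1, 1]; a constant input would divide by zero. A quick worked example:

import numpy as np
x = np.array([0.0, 5.0, 10.0])
scaled = (x - x.min()) / (x.max() - x.min())   # [0.0, 0.5, 1.0]
print(2 * scaled - 1)                          # [-1.  0.  1.]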
@@ -240,6 +237,7 @@ def create_sound_samples_classification_dataset(bsz=4, series=True, soundlen=40_
    metadata = pd.read_csv(metadata_path)
    for index, row in metadata.iterrows():
        file_path = os.path.join(data_path, f"fold{row['fold']}", f"{row['slice_file_name']}")
        try:
            array = wav_to_array(file_path)
        except:
@@ -304,15 +302,12 @@ def create_sound_samples_classification_dataset(bsz=4, series=True, soundlen=40_
    )
    return trainloader, testloader, n_classes, SEQ_LENGTH, IN_DIM, X_bounds
-def create_sound_samples_classification_dataset_spectrogram(bsz=4, series=True, soundlen=44_000, regularization=False, binary=True):
+def create_sound_samples_classification_dataset_spectrogram(bsz=4, series=True, soundlen=40_000, regularization=False, binary=True):
    features = []
    labels = []
    metadata = pd.read_csv(metadata_path)
    for index, row in metadata.iterrows():
-        #print("loading file", index)
-        if 5900 > index > 5100:
-            continue
        file_path = os.path.join(data_path, f"fold{row['fold']}", f"{row['slice_file_name']}")
        try:
            array = wav_to_array(file_path)
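Note: wav_to_array is defined elsewhere in this file and its implementation is not shown in the diff. A hypothetical stand-in that reads a clip into a mono float array in [-1, 1] (an assumption, not the repository's code) could look like:

import numpy as np
from scipy.io import wavfile

def wav_to_array(path):
    # Hypothetical stand-in; the real helper lives elsewhere in data.py.
    _sr, data = wavfile.read(path)       # samples, shape (T,) or (T, channels)
    data = data.astype(np.float32)
    if data.ndim > 1:                    # down-mix multi-channel audio to mono
        data = data.mean(axis=1)
    peak = np.max(np.abs(data))
    return data / peak if peak > 0 else data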
@@ -451,8 +446,8 @@ class Conv1DTransform:
        convolved_sample = convolved_sample.swapaxes(1,2).squeeze(0)
        return convolved_sample
-#, labels=["left", "right"]
-def create_speech_command_classification(bsz=4, series=True, soundlen=16_000, regularization=False, labels=["zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"]):
+def create_speech_command_classification(bsz=4, series=True, soundlen=16_000, regularization=False,
+                                          labels=["zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"]):
    train_set = SubsetSC("training")
    test_set = SubsetSC("testing")
    train_X, train_Y = gather_dataset(train_set, classes_to_keep=labels, len_limit=soundlen, sample_limit=108000)
......
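Note: SubsetSC and gather_dataset are defined elsewhere in the file. SubsetSC is presumably the standard SPEECHCOMMANDS subset wrapper from the torchaudio speech-commands tutorial; a sketch of that wrapper (an assumption about this repository's exact version) is:

import os
from torchaudio.datasets import SPEECHCOMMANDS

class SubsetSC(SPEECHCOMMANDS):
    # Torchaudio-tutorial wrapper that restricts the dataset walker to the
    # official training / validation / testing split files.
    def __init__(self, subset):
        super().__init__("./", download=True)

        def load_list(filename):
            filepath = os.path.join(self._path, filename)
            with open(filepath) as f:
                return [os.path.normpath(os.path.join(self._path, line.strip())) for line in f]

        if subset == "validation":
            self._walker = load_list("validation_list.txt")
        elif subset == "testing":
            self._walker = load_list("testing_list.txt")
        elif subset == "training":
            excludes = set(load_list("validation_list.txt") + load_list("testing_list.txt"))
            self._walker = [w for w in self._walker if w not in excludes]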
@@ -389,23 +389,3 @@ def init(x):
def hippo_initializer(N):
    Lambda, P, B, _ = make_DPLR_HiPPO(N)
    return init(Lambda.real), init(Lambda.imag), init(P), init(B)
-def sample(model, params, prime, cache, x, start, end, rng):
-    def loop(i, cur):
-        x, rng, cache = cur
-        r, rng = jax.random.split(rng)
-        out, vars = model.apply(
-            {"params": params, "prime": prime, "cache": cache},
-            x[:, np.arange(1, 2) * i],
-            mutable=["cache"],
-        )
-        def update(x, out):
-            p = jax.random.categorical(r, out[0])
-            x = x.at[i + 1, 0].set(p)
-            return x
-        x = jax.vmap(update)(x, out)
-        return x, rng, vars["cache"].unfreeze()
-    return jax.lax.fori_loop(start, end, jax.jit(loop), (x, rng, cache))[0]
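Note: make_DPLR_HiPPO(N), used by hippo_initializer above, builds the diagonal-plus-low-rank form of the HiPPO-LegS state matrix; its definition sits outside this hunk. A sketch following the annotated-S4 reference implementation, which this code appears to be derived from (an assumption, not necessarily this repository's exact version):

from jax import numpy as np
from jax.numpy.linalg import eigh

def make_HiPPO(N):
    # Negated HiPPO-LegS matrix.
    P = np.sqrt(1 + 2 * np.arange(N))
    A = P[:, np.newaxis] * P[np.newaxis, :]
    A = np.tril(A) - np.diag(np.arange(N))
    return -A

def make_DPLR_HiPPO(N):
    # Normal-plus-low-rank form of HiPPO, then diagonalize the normal part.
    A = make_HiPPO(N)
    P = np.sqrt(np.arange(N) + 0.5)      # low-rank correction term
    B = np.sqrt(2 * np.arange(N) + 1.0)  # HiPPO input matrix
    S = A + P[:, np.newaxis] * P[np.newaxis, :]
    Lambda_real = np.mean(np.diagonal(S)) * np.ones(N)
    Lambda_imag, V = eigh(S * -1j)       # S is skew-symmetric, so -1j*S is Hermitian
    P = V.conj().T @ P
    B = V.conj().T @ B
    return Lambda_real + 1j * Lambda_imag, P, B, V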
@@ -40,6 +40,7 @@ def generate_model_cls(
    )
    return model_cls
#
# Following is the implementation of the necessary NN training functions
#
@partial(np.vectorize, signature="(c),()->()")
......
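Note: the @partial(np.vectorize, signature="(c),()->()") decorator broadcasts a per-example function over leading batch dimensions: it consumes a length-c vector of class scores and a scalar label and returns a scalar. In the annotated-S4 reference the decorated function is the per-example cross-entropy loss; this repository's version is not shown in the diff, but it presumably looks like:

from functools import partial
import jax
from jax import numpy as np

@partial(np.vectorize, signature="(c),()->()")
def cross_entropy_loss(logits, label):
    # Assumes `logits` are log-probabilities over the c classes.
    one_hot_label = jax.nn.one_hot(label, num_classes=logits.shape[0])
    return -np.sum(one_hot_label * logits)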