
Commit 6e850ca

committed
first commit
1 parent 467b60a commit 6e850ca

12 files changed: 1,116 additions and 0 deletions

.gitignore

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
*.pyc
.DS_Store
__init__.py

data.py

Lines changed: 182 additions & 0 deletions
@@ -0,0 +1,182 @@
import hyperparams as hp
import pandas as pd
from torch.utils.data import Dataset, DataLoader
import os
import librosa
import numpy as np
from Tacotron.text import text_to_sequence
import collections
from scipy import signal

class LJDatasets(Dataset):
    """LJSpeech dataset."""

    def __init__(self, csv_file, root_dir):
        """
        Args:
            csv_file (string): Path to the csv file with annotations.
            root_dir (string): Directory with all the wavs.

        """
        self.landmarks_frame = pd.read_csv(csv_file, sep='|', header=None)
        self.root_dir = root_dir

    def load_wav(self, filename):
        return librosa.load(filename, sr=hp.sample_rate)

    def __len__(self):
        return len(self.landmarks_frame)

    def __getitem__(self, idx):
        wav_name = os.path.join(self.root_dir, self.landmarks_frame.ix[idx, 0]) + '.wav'
        text = self.landmarks_frame.ix[idx, 1]
        text = np.asarray(text_to_sequence(text, [hp.cleaners]), dtype=np.int32)
        wav = np.asarray(self.load_wav(wav_name)[0], dtype=np.float32)
        sample = {'text': text, 'wav': wav}

        return sample

def collate_fn(batch):

    # Puts each data field into a tensor with outer dimension batch size
    if isinstance(batch[0], collections.Mapping):
        keys = list()

        text = [d['text'] for d in batch]
        wav = [d['wav'] for d in batch]

        # PAD sequences with largest length of the batch
        text = _prepare_data(text).astype(np.int32)
        wav = _prepare_data(wav)

        magnitude = np.array([spectrogram(w) for w in wav])
        mel = np.array([melspectrogram(w) for w in wav])
        timesteps = mel.shape[-1]

        # PAD with zeros that can be divided by outputs per step
        if timesteps % hp.outputs_per_step != 0:
            magnitude = _pad_per_step(magnitude)
            mel = _pad_per_step(mel)

        return text, magnitude, mel

    raise TypeError(("batch must contain tensors, numbers, dicts or lists; found {}"
                     .format(type(batch[0]))))
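As a minimal usage sketch (not part of this commit, and assuming data.py is importable as data): get_dataset(), defined at the bottom of this file, builds the LJSpeech dataset, and collate_fn pads each batch and computes the spectrogram targets. The batch size and worker count below are illustrative.

from torch.utils.data import DataLoader
from data import get_dataset, collate_fn

dataset = get_dataset()   # expects hp.data_path/metadata.csv and wavs under hp.data_path/wavs
loader = DataLoader(dataset, batch_size=32, shuffle=True,
                    collate_fn=collate_fn, drop_last=True, num_workers=4)

for text, magnitude, mel in loader:
    # collate_fn returns numpy arrays, not torch tensors:
    #   text:      (batch, max_text_len)        int32 character ids
    #   magnitude: (batch, hp.num_freq, frames) linear spectrograms
    #   mel:       (batch, hp.num_mels, frames) mel spectrograms
    # frames is padded up to a multiple of hp.outputs_per_step (5), matching a
    # Tacotron-style decoder that predicts several frames per step.
    break

The file continues with the audio pre-processing helpers: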
# These pre-processing functions are referred from https://github.com/keithito/tacotron

_mel_basis = None

def save_wav(wav, path):
    wav *= 32767 / max(0.01, np.max(np.abs(wav)))
    librosa.output.write_wav(path, wav.astype(np.int16), hp.sample_rate)


def _linear_to_mel(spectrogram):
    global _mel_basis
    if _mel_basis is None:
        _mel_basis = _build_mel_basis()
    return np.dot(_mel_basis, spectrogram)

def _build_mel_basis():
    n_fft = (hp.num_freq - 1) * 2
    return librosa.filters.mel(hp.sample_rate, n_fft, n_mels=hp.num_mels)

def _normalize(S):
    return np.clip((S - hp.min_level_db) / -hp.min_level_db, 0, 1)

def _denormalize(S):
    return (np.clip(S, 0, 1) * -hp.min_level_db) + hp.min_level_db

def _stft_parameters():
    n_fft = (hp.num_freq - 1) * 2
    hop_length = int(hp.frame_shift_ms / 1000 * hp.sample_rate)
    win_length = int(hp.frame_length_ms / 1000 * hp.sample_rate)
    return n_fft, hop_length, win_length

def _amp_to_db(x):
    return 20 * np.log10(np.maximum(1e-5, x))

def _db_to_amp(x):
    return np.power(10.0, x * 0.05)

def preemphasis(x):
    return signal.lfilter([1, -hp.preemphasis], [1], x)


def inv_preemphasis(x):
    return signal.lfilter([1], [1, -hp.preemphasis], x)


def spectrogram(y):
    D = _stft(preemphasis(y))
    S = _amp_to_db(np.abs(D)) - hp.ref_level_db
    return _normalize(S)


def inv_spectrogram(spectrogram):
    '''Converts spectrogram to waveform using librosa'''

    S = _denormalize(spectrogram)
    S = _db_to_amp(S + hp.ref_level_db)  # Convert back to linear

    return inv_preemphasis(_griffin_lim(S ** hp.power))  # Reconstruct phase

def _griffin_lim(S):
    '''librosa implementation of Griffin-Lim
    Based on https://github.com/librosa/librosa/issues/434
    '''
    angles = np.exp(2j * np.pi * np.random.rand(*S.shape))
    S_complex = np.abs(S).astype(np.complex)
    y = _istft(S_complex * angles)
    for i in range(hp.griffin_lim_iters):
        angles = np.exp(1j * np.angle(_stft(y)))
        y = _istft(S_complex * angles)
    return y

def _istft(y):
    _, hop_length, win_length = _stft_parameters()
    return librosa.istft(y, hop_length=hop_length, win_length=win_length)


def melspectrogram(y):
    D = _stft(preemphasis(y))
    S = _amp_to_db(_linear_to_mel(np.abs(D)))
    return _normalize(S)

def _stft(y):
    n_fft, hop_length, win_length = _stft_parameters()
    return librosa.stft(y=y, n_fft=n_fft, hop_length=hop_length, win_length=win_length)

def find_endpoint(wav, threshold_db=-40, min_silence_sec=0.8):
    window_length = int(hp.sample_rate * min_silence_sec)
    hop_length = int(window_length / 4)
    threshold = _db_to_amp(threshold_db)
    for x in range(hop_length, len(wav) - window_length, hop_length):
        if np.max(wav[x:x+window_length]) < threshold:
            return x + hop_length
    return len(wav)

def _pad_data(x, length):
    _pad = 0
    return np.pad(x, (0, length - x.shape[0]), mode='constant', constant_values=_pad)

def _prepare_data(inputs):
    max_len = max((len(x) for x in inputs))
    return np.stack([_pad_data(x, max_len) for x in inputs])

def _pad_per_step(inputs):
    timesteps = inputs.shape[-1]
    return np.pad(inputs, [[0,0],[0,0],[0, hp.outputs_per_step - (timesteps % hp.outputs_per_step)]], mode='constant', constant_values=0.0)

def get_param_size(model):
    params = 0
    for p in model.parameters():
        tmp = 1
        for x in p.size():
            tmp *= x
        params += tmp
    return params

def get_dataset():
    return LJDatasets(os.path.join(hp.data_path,'metadata.csv'), os.path.join(hp.data_path,'wavs'))
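A rough round-trip sketch of the audio helpers above (not from the commit): spectrogram() yields a normalized log-magnitude spectrogram, and inv_spectrogram() runs hp.griffin_lim_iters Griffin-Lim iterations to estimate phase. The wav filename is hypothetical, and save_wav() relies on librosa.output.write_wav, which only exists in librosa versions before 0.8.

import librosa
import hyperparams as hp
from data import spectrogram, inv_spectrogram, save_wav

wav, _ = librosa.load('LJ001-0001.wav', sr=hp.sample_rate)   # hypothetical LJSpeech clip

S = spectrogram(wav)          # (hp.num_freq, frames) normalized log-magnitude spectrogram
recon = inv_spectrogram(S)    # Griffin-Lim phase estimate + inverse pre-emphasis
save_wav(recon, 'reconstructed.wav')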

hyperparams.py

Lines changed: 30 additions & 0 deletions
@@ -0,0 +1,30 @@
# Audio

num_mels = 80
num_freq = 1024
sample_rate = 20000
frame_length_ms = 50.
frame_shift_ms = 12.5
preemphasis = 0.97
min_level_db = -100
ref_level_db = 20
hidden_size = 128
embedding_size = 256

max_iters = 200
griffin_lim_iters = 60
power = 1.5
outputs_per_step = 5
teacher_forcing_ratio = 1.0

epochs = 10000
lr = 0.001
decay_step = [500000, 1000000, 2000000]
log_step = 100
save_step = 2000

cleaners='english_cleaners'

data_path = '../data'
output_path = './result'
checkpoint_path = './model_new'
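For reference, the STFT settings that data.py derives from these values work out as follows (a quick hand calculation, not part of the commit):

sample_rate = 20000
frame_shift_ms = 12.5
frame_length_ms = 50.
num_freq = 1024

n_fft = (num_freq - 1) * 2                              # 2046-point FFT -> 1024 frequency bins
hop_length = int(frame_shift_ms / 1000 * sample_rate)   # 250 samples  (12.5 ms)
win_length = int(frame_length_ms / 1000 * sample_rate)  # 1000 samples (50 ms)
print(n_fft, hop_length, win_length)                    # 2046 250 1000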
