티스토리 뷰

반응형

전에 SCE-TTS로 친구의 TTS를 만든 적이 있었다.

https://lemongreen.tistory.com/13

 

SCE-TTS 버그 수정판과 TTS 제작 경험

https://sce-tts.github.io/#/v2/index SCE-TTS: 내 목소리로 TTS 만들기 문서를 불러오고 있습니다... sce-tts.github.io 심심해서 SCE-TTS를 만져보고 있었는데, 여러가지 버그가 많았습니다. 근데 프로젝트 개발도

lemongreen.tistory.com

 

살짝 아쉬웠던 부분은, 버그를 떠나서 TTS의 퀄리티가 조금 아쉬웠다. train 양이 부족한 건지, 모델이 안 좋은건지, 데이터셋이 모자란건지 노이즈도 있었고, 살짝 기계의 느낌이 많이 났다.

 

그러던 도중, SCE-TTS의 repo를 뒤져보다 branch에 v3의 흔적이 있길래, 좀 찾아봤는데 Piper라는 TTS 라이브러리를 사용하고 있었다. 상당히 설정도 간단하고, 성능도 좋은 것 같아서 시도하게 되었다.

 

앞으로는 삽질의 향연이다.

 

중요: 코딩 지식이 없는 사람은 안하는 걸 추천합니다. 너무 힘들거든요....

 

문제 1. Piper는 한국어 지원이 없었다.

https://github.com/rhasspy/piper

 

GitHub - rhasspy/piper: A fast, local neural text to speech system

A fast, local neural text to speech system. Contribute to rhasspy/piper development by creating an account on GitHub.

github.com

piper를 그대로 훈련에 쓰려고 했는데, 한국어 지원이 없었다.. 하지만 Piper가 VITS라는 모델에 기반하고 있다고 README에 적혀있길래, VITS의 한국어 트레이닝이 가능한 버전을 찾아 나섰다.

 

다행히도 있었다. README도 친절하고 example도 있어서 기반 지식이 있는 분들은 쉽게 따라할 수 있을 것이다.

https://github.com/ouor/vits

 

GitHub - ouor/vits: VITS implementation of Japanese, Chinese, Korean, Sanskrit and Thai

VITS implementation of Japanese, Chinese, Korean, Sanskrit and Thai - ouor/vits

github.com

 

다행히 예전에 녹음한 TTS의 데이터셋과 완벽하게 호환되는 듯 했다. (LJSpeech) 전처리만 거치고 바로 트레이닝에 들어갔다.

 

문제 2. 트레이닝이 왜 안돼니

 

전에 콜랩으로 훈련하다가, 콜랩 특유의 불안정성 때문에 콜랩을 갖다버리고 runpod으로 갈아타게 되었다. 훨씬 빠르고, 좋고, 싸다. 여러분도 꼭 써보시라.

 

https://runpod.io 

 

Rent Cloud GPUs from $0.2/hour

Global Interoperability Select from 30+ regions across North America, Europe, and South America.

www.runpod.io

 

문제는 다른 곳에 있었다. runpod에 있는 최신 사양의 GPU인 4090과 pytorch 1.13.1 이미지로 작업할려니, CUDA 11.8을 깔아야한다더라. 문제는 CUDA 11.8을 지원하는 torch는 최소 ^2.0.0 버전부터였다. 우리가 사용하는 VITS repo는 2.0 지원이 없는 듯 했다.

 

결국 3090으로 다운그레이드 하고 다시 시작했다. 아니 근데 콜랩에선 잘만 되던게 에러를 뿜는 것이었다.

illegal memory error... 아마 메모리 부족으로 터지는 것이 확실했다. 3090이 메모리가 부족해서 터지는 것이 어이가 없었지만 대안을 찾아나섰다. 아마 체크포인트 로딩이나 세이브 중에 load_state_dict 라는 함수를 호출하는 중에 메모리 부족으로 터지는 듯 했다.

 

https://discuss.pytorch.org/t/load-state-dict-causes-memory-leak/36189/3

 

Load_state_dict causes memory leak

I’ll try that now and then post results here.

discuss.pytorch.org

 

다행히도 파이토치 커뮤니티에 해결 방법이 있어, 모든 load_state_dict를 CPU로 옮겨서 로딩하고 GPU로 옮기고, 불필요한 부분은 del도 하고, cache도 지우고.. 별 짓을 다했다.

 

근데 결국 해결하기 가장 좋은 법은 config.json에서 batch size를 줄이는 것이었다. 바로 해결되더라...아나..

 

3. 마지막 단계 ONNX로 최적화

제일 큰 문제가 다가왔다. 체크포인트를 그대로 쓸 순 없으니 ONNX로 export를 해야하는데, 지금 쓰고 있는 VITS repo와 piper repo의 onnx export 스크립트가 호환되지 않았다.

 

애초에 트레이닝 코드가 조금 다르니까 안되는 게 당연하지만, 머신러닝 관련 코드를 직접 짜보지 않은 나에게는 적잖이 당황할 수 밖에 없었다.

 

모든 지식을 총동원해서 ONNX를 export 하는 법을 찾고 있었다. piper에서 몇 줄 가져오고... 기존 리포 infer 예제에서 몇 줄 가져오고.. 파이토치 문서에서 가져오고... 참고한 문서는 아래에 추가해놓겠습니다..

 

https://github.com/rhasspy/piper/blob/master/src/python/piper_train/export_onnx.py

 

https://github.com/ouor/vits/blob/main/infer.ipynb

 

https://tutorials.pytorch.kr/advanced/super_resolution_with_onnxruntime.html

 

(선택) PyTorch 모델을 ONNX으로 변환하고 ONNX 런타임에서 실행하기

이 튜토리얼에서는 어떻게 PyTorch에서 정의된 모델을 ONNX 형식으로 변환하고 또 어떻게 그 변환된 모델을 ONNX 런타임에서 실행할 수 있는지에 대해 알아보도록 하겠습니다. ONNX 런타임은 ONNX 모델

tutorials.pytorch.kr

 

 

몇시간의 삽질 끝에.. 겨우 해냈다. 내가 스스로 ONNX 변환 코드를 짜다니. 내가 다 뿌듯하더라 ㅋㅋㅋ

 

설마 필요하실 분들을 위해 남겨놓겠습니다.

import IPython.display as ipd
import torch
import commons
import utils
from models import SynthesizerTrn
from text.symbols import symbols
from text import text_to_sequence
from scipy.io.wavfile import write


checkpoint_path = ''
config_path = ''
destination_path = 'wow.onnx'

hps = utils.get_hparams_from_file(config_path)
spk_count = hps.data.n_speakers

model_g = SynthesizerTrn(
            len(symbols),
            hps.data.filter_length // 2 + 1,
            hps.train.segment_size // hps.data.hop_length,
            n_speakers=hps.data.n_speakers,
            **hps.model)

model_g.eval()
utils.load_checkpoint(checkpoint_path, model_g, None)

def get_text(text):
    text_norm = text_to_sequence(text, hps.data.text_cleaners)
    if hps.data.add_blank:
        text_norm = commons.intersperse(text_norm, 0)
    text_norm = torch.LongTensor(text_norm)
    return text_norm

with torch.no_grad():
    model_g.dec.remove_weight_norm()
    
txt = get_text("학습은 잘 마치셨나요? 좋은 결과가 있길 바래요.")
input_text = txt.unsqueeze(0)
input_text_len = torch.LongTensor([txt.size(0)])
sid = torch.LongTensor([0])

arg = (output_text, output_text_len)

def infer_forward(text, text_lengths):
    audio = model_g.infer(
        text,
        text_lengths,
        noise_scale=.667,
        length_scale=1,
        noise_scale_w=0.8,
        sid=torch.LongTensor([0]),
    )[0][0,0]

    return audio

model_g.forward = infer_forward
model_g.cpu()

output = infer_forward(input_text, input_text_len).data.float().numpy()
print(output)
write(f'infer/onnx.wav', hps.data.sampling_rate, output)
ipd.display(ipd.Audio(output, rate=hps.data.sampling_rate, normalize=False))

torch.onnx.export(model_g, arg, destination_path,
    verbose=False,
    opset_version=15,
    input_names=["text", "text_lengths"],
    output_names=["output"],
    dynamic_axes={
        "text": {0: "batch_size", 1: "phonemes"},
        "text_lengths": {0: "batch_size"},
        "output": {0: "batch_size", 1: "time"},
    },
)

 

ouor/vits 리포에 의존적이고, 하드코딩도 많으니 알아서 조금씩 수정하셔서 쓰시면 될 것 같습니다.

 

ONNX 사용은 아래를 참고해주세용

 

import onnxruntime
import IPython.display as ipd
import torch
import commons
import utils
from text.symbols import symbols
from text import text_to_sequence
import numpy as np
from scipy.io.wavfile import write

ort_session = onnxruntime.InferenceSession("wow.onnx")
checkpoint_path = ''
config_path = ''
hps = utils.get_hparams_from_file(config_path)

def get_text(text):
    text_norm = text_to_sequence(text, hps.data.text_cleaners)
    if hps.data.add_blank:
        text_norm = commons.intersperse(text_norm, 0)
    # text_norm = torch.LongTensor(text_norm)
    text_norm = np.array(text_norm, dtype=np.int64)
    return text_norm

txt = get_text("학습은 잘 마치셨나요? 좋은 결과가 있길 바래요.")
input_text = np.expand_dims(txt, axis=0)
input_text_len = np.array([txt.shape[0]], dtype=np.int64)

arg = {
    "text": input_text,
    "text_lengths": input_text_len,
}

# ONNX 런타임에서 계산된 결과값
ort_outs = ort_session.run(["output"], arg)
output = ort_outs[0]

write(f'infer/onnx.wav', hps.data.sampling_rate, output)
ipd.display(ipd.Audio(output, rate=hps.data.sampling_rate, normalize=False))

 

onnx-simpilfier 사용하시면 용량도 줄고, 성능도 좋아집니다. 저의 경우에는 onnx 로딩에 0.9~초, inference에 0.7~초 걸리는 것 같더라구요.

 

https://github.com/daquexian/onnx-simplifier

 

GitHub - daquexian/onnx-simplifier: Simplify your onnx model

Simplify your onnx model. Contribute to daquexian/onnx-simplifier development by creating an account on GitHub.

github.com

 

 

코드들이 아주 더럽습니다..ㅋㅋㅋ 급해서 메모장에서 바로바로 갈겨 쓴거라 그래요.

 

쓰다보니 생각보다 커스텀한 구간이 많더라구요. inference 부분만 구현해놨긴 한데, 나중에 제대로 짜고 확장해서 tts-server도 만들어야겠습니다.. 자바스크립트로도 onnxruntime이 있다면 훨씬 좋았을텐데..ㅋㅋ 차피 numpy가 필요해서 안될 것 같네요 ㅋㅋ

 

 

 

+ 테스트 파일들은 example을 참고해주세요. 테스트 결과 파일은 제 목소리도 아닌 친구의 목소리기 때문에 공유해드릴 수가 없습니다.

반응형
댓글
공지사항
최근에 올라온 글
최근에 달린 댓글
Total
Today
Yesterday
링크
«   2024/11   »
1 2
3 4 5 6 7 8 9
10 11 12 13 14 15 16
17 18 19 20 21 22 23
24 25 26 27 28 29 30
글 보관함