개요
AudioCraft Custom은 Facebook Research의 AudioCraft 프레임워크를 기반으로 개발된 고급 AI 오디오 생성 플랫폼입니다. 이 프로젝트는 텍스트-투-뮤직, 텍스트-투-오디오 생성부터 고급 오디오 분석까지 포괄하는 완전한 솔루션을 제공합니다. 본 포스트에서는 이 프로젝트의 아키텍처, 핵심 구성 요소, 그리고 실제 구현을 상세히 분석하겠습니다.
1. 프로젝트 아키텍처 개요
1.1 전체 구조
audiocraft-custom/
├── audiocraft/ # 핵심 AudioCraft 라이브러리
│ ├── models/ # AI 모델 구현
│ │ ├── musicgen.py # 음악 생성 모델
│ │ ├── audiogen.py # 일반 오디오 생성
│ │ ├── encodec.py # 오디오 코덱
│ │ └── multibanddiffusion.py # 다중 밴드 확산
│ ├── adversarial/ # 적대적 네트워크
│ │ └── discriminators/ # 판별자 모델들
│ ├── data/ # 데이터 처리
│ ├── modules/ # 공통 모듈
│ └── solvers/ # 훈련 솔버
├── api/ # REST API 서버
├── demos/ # Jupyter 노트북 데모
├── config/ # 설정 파일들
└── docker/ # 컨테이너화
1.2 핵심 기능 영역
- 음악 생성 (MusicGen): 텍스트 프롬프트로 음악 생성
- 오디오 생성 (AudioGen): 일반 사운드 이펙트 생성
- 오디오 코덱 (EnCodec): 고품질 오디오 압축/복원
- 적대적 분석: 실제/생성 오디오 판별
- REST API: 웹 서비스 인터페이스
- Docker 배포: 컨테이너 기반 배포
2. 핵심 모델 구현 분석
2.1 MusicGen - 음악 생성 모델
# audiocraft/models/musicgen.py
class MusicGen(BaseGenModel):
"""MusicGen main model with convenient generation API.
Args:
name (str): name of the model.
compression_model (CompressionModel): Compression model
used to map audio to invertible discrete representations.
lm (LMModel): Language model over discrete representations.
max_duration (float, optional): maximum duration the model can produce,
otherwise, inferred from the training params.
"""
def __init__(self, name: str, compression_model: CompressionModel, lm: LMModel,
max_duration: tp.Optional[float] = None, **kwargs):
super().__init__(name, compression_model, lm, max_duration, **kwargs)
@staticmethod
def get_pretrained(name: str = 'facebook/musicgen-medium', device=None):
"""Return pretrained model, we provide four models:
- facebook/musicgen-small (300M), text to music,
# see: https://huggingface.co/facebook/musicgen-small
- facebook/musicgen-medium (1.5B), text to music,
# see: https://huggingface.co/facebook/musicgen-medium
- facebook/musicgen-melody (1.5B) text to music and text+melody to music,
# see: https://huggingface.co/facebook/musicgen-melody
- facebook/musicgen-large (3.3B), text to music,
# see: https://huggingface.co/facebook/musicgen-large
"""
MusicGen의 핵심 특징:
- Transformer 기반: 1.5B~3.3B 파라미터 규모
- 조건부 생성: 텍스트 및 멜로디 조건 지원
- 고품질 출력: 32kHz 샘플링 레이트
- 제어 가능: 온도, top-k, CFG 등 다양한 생성 파라미터
2.2 AudioGen - 일반 오디오 생성
# audiocraft/models/audiogen.py
class AudioGen(MusicGen):
"""AudioGen model for text-to-sound generation.
This is a thin wrapper around MusicGen as both models have the same architecture.
"""
def __init__(self, name: str, compression_model: CompressionModel, lm: LMModel,
max_duration: tp.Optional[float] = None, **kwargs):
super().__init__(name, compression_model, lm, max_duration, **kwargs)
@staticmethod
def get_pretrained(name: str = 'facebook/audiogen-medium', device=None):
"""Return pretrained AudioGen model."""
AudioGen vs MusicGen 차이점:
- 훈련 데이터: 음악 대신 일반 사운드 이펙트
- 용도: 환경음, 효과음, 자연음 등
- 모델 크기: Medium (1.5B) 모델 제공
- 아키텍처: MusicGen과 동일한 구조
2.3 EnCodec - 오디오 압축 코덱
# audiocraft/models/encodec.py
class CompressionModel(nn.Module):
"""Base class for all compression model (e.g, EnCodec, AudioMAE, DAC etc.).
Args:
sample_rate (int): Sample rate of the audio.
channels (int): Number of audio channels.
normalize (bool): Whether to normalize the audio.
segment (float, optional): Segment length for processing.
overlap (float, optional): Overlap between segments.
"""
def encode(self, x: torch.Tensor) -> tp.List[EncodedFrame]:
"""Encode audio into discrete tokens."""
def decode(self, encoded_frames: tp.List[EncodedFrame]) -> torch.Tensor:
"""Decode tokens back to audio."""
EnCodec의 핵심 기능:
- 벡터 양자화: 연속 오디오를 이산 토큰으로 변환
- 고품질 복원: 높은 품질의 오디오 재구성
- 다중 해상도: 다양한 비트레이트 지원
- 실시간 처리: 스트리밍 가능한 처리 속도
3. 적대적 네트워크 시스템
3.1 Multi-Period Discriminator (MPD)
# audiocraft/adversarial/discriminators/mpd.py
class PeriodDiscriminator(nn.Module):
"""Period sub-discriminator.
Args:
period (int): Period between samples of audio.
in_channels (int): Number of input channels.
out_channels (int): Number of output channels.
n_layers (int): Number of convolutional layers.
kernel_sizes (list of int): Kernel sizes for convolutions.
stride (int): Stride for convolutions.
filters (int): Initial number of filters in convolutions.
filters_scale (int): Multiplier of number of filters as we increase depth.
max_filters (int): Maximum number of filters.
"""
def __init__(self, period: int, in_channels: int = 1, out_channels: int = 1,
n_layers: int = 5, kernel_sizes: tp.List[int] = [5, 3], stride: int = 3,
filters: int = 8, filters_scale: int = 4, max_filters: int = 1024,
norm: str = 'weight_norm', activation: str = 'LeakyReLU',
activation_params: dict = {'negative_slope': 0.2}):
super().__init__()
self.period = period
self.n_layers = n_layers
self.activation = getattr(torch.nn, activation)(**activation_params)
self.convs = nn.ModuleList()
def forward(self, x):
"""Forward pass through period discriminator."""
# Reshape input according to period
b, c, t = x.shape
if t % self.period != 0:
n_pad = self.period - (t % self.period)
x = F.pad(x, (0, n_pad), "reflect")
t = t + n_pad
x = x.view(b, c, t // self.period, self.period)
fmap = []
for conv in self.convs:
x = conv(x)
x = self.activation(x)
fmap.append(x)
return x, fmap
class MultiPeriodDiscriminator(MultiDiscriminator):
"""Multi-Period Discriminator (MPD) from HiFi-GAN."""
def __init__(self, periods: tp.List[int] = [2, 3, 5, 7, 11], **kwargs):
discriminators = [PeriodDiscriminator(p, **kwargs) for p in periods]
super().__init__(discriminators)
3.2 Multi-Scale Discriminator (MSD)
# audiocraft/adversarial/discriminators/msd.py
class ScaleDiscriminator(nn.Module):
"""Scale sub-discriminator."""
def __init__(self, norm: str = 'spectral_norm', **kwargs):
super().__init__()
self.norm = norm
self.convs = nn.ModuleList([
NormConv1d(1, 128, 15, 1, padding=7),
NormConv1d(128, 128, 41, 2, groups=4, padding=20),
NormConv1d(128, 256, 41, 2, groups=16, padding=20),
NormConv1d(256, 512, 41, 4, groups=16, padding=20),
NormConv1d(512, 1024, 41, 4, groups=16, padding=20),
NormConv1d(1024, 1024, 41, 1, groups=16, padding=20),
NormConv1d(1024, 1024, 5, 1, padding=2),
])
self.conv_post = NormConv1d(1024, 1, 3, 1, padding=1)
class MultiScaleDiscriminator(MultiDiscriminator):
"""Multi-Scale Discriminator (MSD) from HiFi-GAN."""
def __init__(self, scales: tp.List[int] = [1, 2, 4], **kwargs):
discriminators = []
for scale in scales:
discriminators.append(ScaleDiscriminator(**kwargs))
super().__init__(discriminators, pools=[nn.AvgPool1d(scale * 2, scale, padding=scale)
if scale > 1 else nn.Identity() for scale in scales])
3.3 Multi-Scale STFT Discriminator (MS-STFT-D)
# audiocraft/adversarial/discriminators/msstftd.py
class STFTDiscriminator(nn.Module):
"""STFT sub-discriminator."""
def __init__(self, n_fft: int = 1024, hop_length: int = 256, win_length: int = 1024, **kwargs):
super().__init__()
self.n_fft = n_fft
self.hop_length = hop_length
self.win_length = win_length
self.register_buffer('window', torch.hann_window(win_length))
def forward(self, x):
"""Apply STFT and discriminate in frequency domain."""
x = torch.stft(x.squeeze(1), n_fft=self.n_fft, hop_length=self.hop_length,
win_length=self.win_length, window=self.window, return_complex=True)
x = torch.view_as_real(x)
x = rearrange(x, 'b f t c -> b c f t')
# Apply 2D convolutions in time-frequency domain
fmap = []
for conv in self.convs:
x = conv(x)
x = self.activation(x)
fmap.append(x)
return x, fmap
class MultiScaleSTFTDiscriminator(MultiDiscriminator):
"""Multi-Scale STFT Discriminator for frequency domain analysis."""
def __init__(self, n_ffts: tp.List[int] = [1024, 2048, 4096],
hop_lengths: tp.List[int] = [120, 240, 480], **kwargs):
discriminators = [STFTDiscriminator(n_fft, hop_length, **kwargs)
for n_fft, hop_length in zip(n_ffts, hop_lengths)]
super().__init__(discriminators)
적대적 시스템의 특징:
- 다중 관점 분석: 시간, 주파수, 주기 도메인
- 계층적 특징: 다양한 스케일의 특징 추출
- 안정적 훈련: 다중 판별자로 모드 붕괴 방지
- 품질 보장: 실제 오디오와 유사한 품질 달성
4. REST API 서버 구현
4.1 FastAPI 기반 서버
# api/main.py
from fastapi import FastAPI, HTTPException, UploadFile, File, Form
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from pydantic import BaseModel
app = FastAPI(
title="AudioCraft API",
description="AudioCraft의 모든 모델을 REST API로 제공하는 서비스",
version="1.0.0"
)
# 모델 초기화
models = {
"musicgen": MusicGen.get_pretrained("facebook/musicgen-small"),
"audiogen": AudioGen.get_pretrained("facebook/audiogen-medium"),
"encodec": EncodecModel.get_pretrained("facebook/encodec_24khz"),
"multiband": MultiBandDiffusion.get_pretrained("facebook/multiband-diffusion")
}
# 판별자 초기화
discriminators = {
"mpd": MultiPeriodDiscriminator(periods=[2, 3, 5, 7, 11], channels=32, kernel_size=5),
"msd": MultiScaleDiscriminator(scales=[1, 2, 4], channels=32, kernel_size=5),
"msstftd": MultiScaleSTFTDiscriminator(n_ffts=[1024, 2048, 4096], hop_lengths=[120, 240, 480], channels=32)
}
4.2 음악 생성 API
class TextToAudioRequest(BaseModel):
"""텍스트-오디오 생성 요청 모델"""
text: str
duration: float = 10.0
temperature: float = 1.0
top_k: int = 250
top_p: float = 0.0
cfg_coef: float = 3.0
@app.post("/generate/music", response_class=FileResponse)
async def generate_music(request: TextToAudioRequest):
"""
텍스트 프롬프트를 사용하여 음악을 생성합니다.
"""
try:
model = models["musicgen"]
model.set_generation_params(
duration=request.duration,
temperature=request.temperature,
top_k=request.top_k,
top_p=request.top_p,
cfg_coef=request.cfg_coef
)
wav = model.generate([request.text])
# 임시 파일로 저장
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
torchaudio.save(tmp.name, wav.cpu(), 32000)
return tmp.name
except Exception as e:
raise HTTPException(status_code=500, detail=f"음악 생성 중 오류 발생: {str(e)}")
4.3 오디오 분석 API
class AudioAnalysisResponse(BaseModel):
"""오디오 분석 결과를 위한 응답 모델"""
mpd_score: float
msd_score: float
msstftd_score: float
feature_maps: List[List[float]]
is_real: bool
@app.post("/analyze", response_model=AudioAnalysisResponse)
async def analyze_audio(
audio_file: UploadFile = File(...),
threshold: float = 0.5
):
"""
오디오 파일을 분석하여 각 판별자의 결과를 반환합니다.
"""
try:
audio_data = await audio_file.read()
waveform = process_audio(audio_data)
with torch.no_grad():
# MPD 분석
mpd_logits, mpd_features = discriminators["mpd"](waveform)
mpd_score = torch.mean(torch.sigmoid(mpd_logits[0])).item()
# MSD 분석
msd_logits, msd_features = discriminators["msd"](waveform)
msd_score = torch.mean(torch.sigmoid(msd_logits[0])).item()
# MS-STFT-D 분석
msstftd_logits, msstftd_features = discriminators["msstftd"](waveform)
msstftd_score = torch.mean(torch.sigmoid(msstftd_logits[0])).item()
# 특징 맵 추출
feature_maps = []
for features in [mpd_features, msd_features, msstftd_features]:
for feat in features:
feature_maps.append(feat.mean(dim=1).cpu().numpy().tolist())
is_real = (mpd_score + msd_score + msstftd_score) / 3 > threshold
return AudioAnalysisResponse(
mpd_score=mpd_score,
msd_score=msd_score,
msstftd_score=msstftd_score,
feature_maps=feature_maps,
is_real=is_real
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"분석 중 오류 발생: {str(e)}")
5. Docker 컨테이너화
5.1 Dockerfile 분석
FROM pytorch/pytorch:2.1.0-cuda12.1-cudnn8-runtime
# 작업 디렉토리 설정
WORKDIR /workspace
# 시스템 의존성 설치
ENV DEBIAN_FRONTEND=noninteractive
RUN apt-get update && apt-get install -y \
ffmpeg \
libsndfile1 \
build-essential \
&& rm -rf /var/lib/apt/lists/*
# Python 가상환경 생성 및 활성화
RUN python -m venv /opt/venv
ENV PATH="/opt/venv/bin:$PATH"
# 애플리케이션 파일 복사
COPY . .
# Python 의존성 설치
RUN pip install --no-cache-dir -r requirements.txt
# audiocraft 패키지 설치
RUN pip install -e .
# 환경 변수 설정
ENV PYTHONPATH=/workspace
ENV HUGGING_FACE_HUB_TOKEN=${HUGGING_FACE_HUB_TOKEN}
# FastAPI 서버 포트 노출
EXPOSE 8000
# FastAPI 서버 실행
CMD ["uvicorn", "api.main:app", "--host", "0.0.0.0", "--port", "8000"]
Docker 설정의 특징:
- PyTorch 기반: CUDA 12.1 지원
- 시스템 의존성: FFmpeg, libsndfile 포함
- 가상환경: 격리된 Python 환경
- 환경 변수: Hugging Face 토큰 지원
- 자동 시작: uvicorn 서버 자동 실행
6. 데모 및 활용 예제
6.1 Jupyter 노트북 데모
# demos/musicgen_demo.ipynb
from audiocraft.models import MusicGen
from audiocraft.models import MultiBandDiffusion
USE_DIFFUSION_DECODER = False
# Using small model, better results would be obtained with `medium` or `large`.
model = MusicGen.get_pretrained('facebook/musicgen-small')
if USE_DIFFUSION_DECODER:
mbd = MultiBandDiffusion.get_mbd_musicgen()
# 생성 파라미터 설정
model.set_generation_params(
use_sampling=True,
top_k=250,
duration=30
)
# 텍스트 조건부 생성
output = model.generate(
descriptions=[
'80s pop track with bassy drums and synth',
'90s rock song with loud guitars and heavy drums',
'Progressive rock drum and bass solo',
'Punk Rock song with loud drum and power guitar',
'Bluesy guitar instrumental with soulful licks and a driving rhythm section',
'Jazz Funk song with slap bass and powerful saxophone',
],
progress=True
)
display_audio(output, sample_rate=32000)
6.2 음악 연속 생성
# 기존 오디오를 기반으로 연속 생성
import torchaudio
prompt_waveform, prompt_sr = torchaudio.load("../assets/bach.mp3")
prompt_duration = 2
prompt_waveform = prompt_waveform[..., :int(prompt_duration * prompt_sr)]
output = model.generate_continuation(
prompt_waveform,
prompt_sample_rate=prompt_sr,
progress=True,
return_tokens=True
)
display_audio(output[0], sample_rate=32000)
if USE_DIFFUSION_DECODER:
out_diffusion = mbd.tokens_to_wav(output[1])
display_audio(out_diffusion, sample_rate=32000)
7. 설정 시스템
7.1 기본 설정
# config/config.yaml
defaults:
- _self_
- dset: default
- solver: default
device: cuda
dtype: float32
autocast: false
autocast_dtype: bfloat16
seed: 2036
show: false
continue_from:
execute_only:
execute_inplace: false
benchmark_no_load: false
efficient_attention_backend: torch
num_threads: 1
mp_start_method: forkserver
label:
# logging parameters
logging:
level: info
log_file: null
log_tensorboard: true
log_wandb: false
7.2 모델별 설정
MusicGen 설정:
- 모델 크기: Small (300M), Medium (1.5B), Large (3.3B)
- 조건부 입력: 텍스트, 멜로디
- 생성 길이: 최대 30초
- 샘플링: Top-k, Top-p, Temperature 제어
AudioGen 설정:
- 특화 분야: 환경음, 효과음
- 품질: 32kHz 고품질 오디오
- 지속 시간: 다양한 길이 지원
8. 성능 최적화 및 확장성
8.1 메모리 최적화
# 모델 로딩 최적화
@lru_cache(maxsize=None)
def load_model_cached(model_name: str):
"""캐시된 모델 로딩으로 메모리 효율성 향상"""
if model_name == "musicgen":
return MusicGen.get_pretrained("facebook/musicgen-small")
elif model_name == "audiogen":
return AudioGen.get_pretrained("facebook/audiogen-medium")
# ...
# GPU 메모리 관리
def manage_gpu_memory():
"""GPU 메모리 정리"""
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.synchronize()
8.2 배치 처리
def batch_generate_music(prompts: List[str], batch_size: int = 4):
"""배치 단위 음악 생성으로 처리량 향상"""
results = []
for i in range(0, len(prompts), batch_size):
batch = prompts[i:i+batch_size]
with torch.no_grad():
output = model.generate(batch)
results.extend(output)
return results
8.3 비동기 처리
import asyncio
from concurrent.futures import ThreadPoolExecutor
async def async_generate_audio(request: TextToAudioRequest):
"""비동기 오디오 생성"""
loop = asyncio.get_event_loop()
with ThreadPoolExecutor() as executor:
result = await loop.run_in_executor(
executor,
lambda: model.generate([request.text])
)
return result
9. 실무 활용 시나리오
9.1 음악 제작 스튜디오
# 음악 제작을 위한 고급 워크플로우
class MusicProductionPipeline:
def __init__(self):
self.musicgen = MusicGen.get_pretrained('facebook/musicgen-large')
self.audiogen = AudioGen.get_pretrained('facebook/audiogen-medium')
self.encodec = EncodecModel.get_pretrained('facebook/encodec_24khz')
def create_song_structure(self, sections: Dict[str, str]):
"""섹션별 음악 생성"""
results = {}
for section, prompt in sections.items():
results[section] = self.musicgen.generate([prompt])
return results
def add_sound_effects(self, base_audio: torch.Tensor, effects: List[str]):
"""사운드 이펙트 추가"""
fx_audio = []
for effect in effects:
fx = self.audiogen.generate([effect])
fx_audio.append(fx)
return self.mix_audio(base_audio, fx_audio)
def compress_for_distribution(self, audio: torch.Tensor):
"""배포용 압축"""
codes = self.encodec.encode(audio)
compressed = self.encodec.decode(codes)
return compressed
9.2 게임 오디오 시스템
class GameAudioSystem:
"""게임용 동적 오디오 생성 시스템"""
def __init__(self):
self.audiogen = AudioGen.get_pretrained('facebook/audiogen-medium')
self.sound_cache = {}
def generate_ambient_sound(self, environment: str):
"""환경별 배경음 생성"""
prompts = {
"forest": "gentle forest ambiance with birds chirping and leaves rustling",
"ocean": "calm ocean waves with seagull sounds",
"city": "urban city ambiance with distant traffic and footsteps",
"dungeon": "dark dungeon atmosphere with water drops and wind"
}
if environment not in self.sound_cache:
audio = self.audiogen.generate([prompts[environment]])
self.sound_cache[environment] = audio
return self.sound_cache[environment]
def generate_dynamic_music(self, game_state: Dict):
"""게임 상태에 따른 동적 음악"""
tension = game_state.get('tension', 0.5)
location = game_state.get('location', 'neutral')
if tension > 0.8:
prompt = f"intense battle music for {location} with dramatic orchestration"
elif tension > 0.5:
prompt = f"suspenseful {location} music with building tension"
else:
prompt = f"peaceful {location} ambient music"
return self.musicgen.generate([prompt])
9.3 팟캐스트 자동화
class PodcastAudioProcessor:
"""팟캐스트 제작 자동화"""
def create_intro_music(self, podcast_theme: str):
"""팟캐스트 인트로 음악 생성"""
prompt = f"upbeat podcast intro music for {podcast_theme} show, 15 seconds"
return self.musicgen.generate([prompt])
def generate_transition_sounds(self, count: int = 5):
"""전환 사운드 생성"""
transitions = []
prompts = [
"smooth podcast transition sound",
"gentle chime for section break",
"soft whoosh transition effect",
"minimalist transition tone",
"clean section divider sound"
]
for prompt in prompts[:count]:
sound = self.audiogen.generate([prompt])
transitions.append(sound)
return transitions
10. 한계점 및 개선 방향
10.1 현재 한계점
- 계산 복잡도: 고품질 생성을 위한 높은 GPU 요구사항
- 생성 시간: 실시간 생성의 어려움
- 제어 정밀도: 세밀한 음악적 요소 제어의 한계
- 일관성: 긴 오디오에서의 일관성 유지 문제
10.2 개선 방향
# 미래 개선 방향 (예시)
class NextGenAudioCraft:
"""차세대 AudioCraft 시스템"""
def __init__(self):
self.streaming_generator = StreamingMusicGen()
self.fine_control = FinegrainedController()
self.quality_enhancer = AudioQualityEnhancer()
def real_time_generation(self, prompt: str):
"""실시간 스트리밍 생성"""
# 청크 단위 실시간 생성
for chunk in self.streaming_generator.generate_stream(prompt):
yield chunk
def style_transfer(self, content_audio: torch.Tensor, style_prompt: str):
"""오디오 스타일 전송"""
# 기존 오디오의 스타일 변경
return self.fine_control.transfer_style(content_audio, style_prompt)
def adaptive_quality(self, audio: torch.Tensor, target_quality: str):
"""적응적 품질 향상"""
# 사용 목적에 따른 품질 최적화
return self.quality_enhancer.enhance(audio, target_quality)
결론
AudioCraft Custom 프로젝트는 최첨단 AI 오디오 생성 기술을 실용적인 플랫폼으로 구현한 훌륭한 사례입니다.
핵심 성과:
- 완전한 파이프라인: 생성부터 분석까지 통합 솔루션
- 확장 가능한 아키텍처: 모듈화된 설계로 쉬운 확장
- 실용적인 API: RESTful 인터페이스로 쉬운 통합
- Docker 지원: 간편한 배포와 확장성
이 프로젝트는 음악 제작, 게임 개발, 미디어 제작 등 다양한 분야에서 AI 오디오 생성 기술의 실제 활용 가능성을 보여줍니다. 앞으로 실시간 생성, 더 정밀한 제어, 향상된 품질 등의 개선을 통해 더욱 강력한 오디오 생성 플랫폼으로 발전할 것으로 기대됩니다.