개요
Open-Sora의 유틸리티 모듈은 AI 비디오 생성 모델의 핵심 기능을 지원하는 다양한 헬퍼 함수와 클래스들을 포함하고 있습니다. 이번 포스트에서는 추론 엔진, 학습 파이프라인, 체크포인트 관리, 메모리 최적화 등 Open-Sora 시스템의 실제 동작을 뒷받침하는 핵심 유틸리티들을 상세히 분석하겠습니다.
1. 유틸리티 모듈 구조 개요
1.1 전체 구조
opensora/utils/
├── inference.py # 추론 엔진 및 조건부 생성
├── train.py # 학습 파이프라인 및 최적화
├── ckpt.py # 체크포인트 관리 시스템
├── misc.py # 메모리 모니터링 및 기타 유틸리티
├── optimizer.py # 옵티마이저 및 스케줄러
├── sampling.py # 샘플링 및 텍스트 처리
├── config.py # 설정 관리
├── logger.py # 로깅 시스템
├── prompt_refine.py # 프롬프트 개선
└── cai.py # ColossalAI 통합
1.2 핵심 기능 영역
- 추론 시스템: 조건부 생성 및 샘플 처리
- 학습 파이프라인: 분산 학습 및 최적화
- 체크포인트 관리: 모델 저장/로드 시스템
- 메모리 관리: 메모리 모니터링 및 최적화
- 헬퍼 함수: 다양한 보조 기능들
2. 추론 시스템 상세 분석
2.1 조건부 비디오 생성
# opensora/utils/inference.py
def prepare_inference_condition(
z: torch.Tensor,
mask_cond: str,
ref_list: list[list[torch.Tensor]] = None,
causal: bool = True,
) -> torch.Tensor:
"""
추론을 위한 시각적 조건 준비
Args:
z: 잠재 노이즈 텐서 [B, C, T, H, W]
mask_cond: 조건 타입 ("i2v_head", "v2v_head", "t2v" 등)
ref_list: 참조 미디어 리스트
causal: Causal VAE 사용 여부
Returns:
masks, masked_z: 마스크와 조건부 잠재 벡터
"""
B, C, T, H, W = z.shape
# 마스크 및 조건부 텐서 초기화
masks = torch.zeros(B, 1, T, H, W)
masked_z = torch.zeros(B, C, T, H, W)
if ref_list is None:
assert mask_cond == "t2v", f"reference is required for {mask_cond}"
for i in range(B):
ref = ref_list[i]
if ref is not None and T > 1: # 비디오 생성
if mask_cond == "i2v_head": # 첫 프레임 조건
masks[i, :, 0, :, :] = 1
masked_z[i, :, 0, :, :] = ref[0][:, 0, :, :]
elif mask_cond == "i2v_tail": # 마지막 프레임 조건
masks[i, :, -1, :, :] = 1
masked_z[i, :, -1, :, :] = ref[-1][:, -1, :, :]
elif mask_cond == "v2v_head": # 비디오 시작 부분 조건
k = 8 + int(causal)
masks[i, :, :k, :, :] = 1
masked_z[i, :, :k, :, :] = ref[0][:, :k, :, :]
elif mask_cond == "v2v_tail": # 비디오 끝 부분 조건
k = 8 + int(causal)
masks[i, :, -k:, :, :] = 1
masked_z[i, :, -k:, :, :] = ref[0][:, -k:, :, :]
elif mask_cond == "i2v_loop": # 루프 비디오 (시작+끝 조건)
masks[i, :, 0, :, :] = 1
masks[i, :, -1, :, :] = 1
masked_z[i, :, 0, :, :] = ref[0][:, 0, :, :]
masked_z[i, :, -1, :, :] = ref[-1][:, -1, :, :]
else:
assert mask_cond == "t2v", f"Unknown mask condition {mask_cond}"
masks = masks.to(z.device, z.dtype)
masked_z = masked_z.to(z.device, z.dtype)
return masks, masked_z
조건부 생성 타입:
- i2v_head: 이미지 → 비디오 (첫 프레임 고정)
- i2v_tail: 이미지 → 비디오 (마지막 프레임 고정)
- i2v_loop: 루프 비디오 (시작과 끝 프레임 고정)
- v2v_head/tail: 비디오 → 비디오 (일부 프레임 조건)
- t2v: 텍스트 → 비디오 (무조건부)
2.2 참조 미디어 수집
def collect_references_batch(
reference_paths: list[str],
cond_type: str,
model_ae: nn.Module,
image_size: tuple[int, int],
is_causal=False,
):
"""
배치 단위로 참조 미디어 수집 및 인코딩
"""
refs_x = []
device = next(model_ae.parameters()).device
dtype = next(model_ae.parameters()).dtype
for reference_path in reference_paths:
if reference_path == "":
refs_x.append(None)
continue
ref_path = reference_path.split(";")
ref = []
if "v2v" in cond_type:
# 비디오-투-비디오: 연속 프레임 처리
r = read_from_path(ref_path[0], image_size, transform_name="resize_crop")
actual_t = r.size(1)
target_t = 64 if (actual_t >= 64 and "easy" in cond_type) else 32
if is_causal:
target_t += 1
assert actual_t >= target_t, f"need at least {target_t} reference frames"
if "head" in cond_type:
r = r[:, :target_t]
elif "tail" in cond_type:
r = r[:, -target_t:]
r_x = model_ae.encode(r.unsqueeze(0).to(device, dtype))
ref.append(r_x.squeeze(0))
elif cond_type == "i2v_head":
# 이미지-투-비디오: 첫 프레임
r = read_from_path(ref_path[0], image_size, transform_name="resize_crop")
r = r[:, :1] # 첫 프레임만
r_x = model_ae.encode(r.unsqueeze(0).to(device, dtype))
ref.append(r_x.squeeze(0))
elif cond_type == "i2v_loop":
# 루프 비디오: 첫 프레임 + 마지막 프레임
r_head = read_from_path(ref_path[0], image_size, transform_name="resize_crop")
r_head = r_head[:, :1]
r_x_head = model_ae.encode(r_head.unsqueeze(0).to(device, dtype))
ref.append(r_x_head.squeeze(0))
r_tail = read_from_path(ref_path[-1], image_size, transform_name="resize_crop")
r_tail = r_tail[:, -1:]
r_x_tail = model_ae.encode(r_tail.unsqueeze(0).to(device, dtype))
ref.append(r_x_tail.squeeze(0))
refs_x.append(ref)
return refs_x
2.3 샘플 처리 및 저장
def process_and_save(
x: torch.Tensor,
batch: dict,
cfg: dict,
sub_dir: str,
generate_sampling_option,
epoch: int,
start_index: int,
saving: bool = True,
):
"""
생성된 샘플 처리 및 디스크 저장
"""
fallback_name = cfg.dataset.data_path.split("/")[-1].split(".")[0]
prompt_as_path = cfg.get("prompt_as_path", False)
fps_save = cfg.get("fps_save", 16)
save_dir = cfg.save_dir
names = batch["name"] if "name" in batch else [None] * len(x)
indices = batch["index"] if "index" in batch else [None] * len(x)
prompts = batch["text"]
ret_names = []
is_image = generate_sampling_option.num_frames == 1
for img, name, index, prompt in zip(x, names, indices, prompts):
# 저장 경로 생성
save_path = get_save_path_name(
save_dir, sub_dir,
save_prefix=cfg.get("save_prefix", ""),
name=name, fallback_name=fallback_name,
index=index, num_sample_pos=epoch,
prompt_as_path=prompt_as_path, prompt=prompt,
)
ret_name = get_names_from_path(save_path)
ret_names.append(ret_name)
if saving:
# 프롬프트 텍스트 저장
with open(save_path + ".txt", "w", encoding="utf-8") as f:
f.write(prompt)
# 샘플 저장 (비디오/이미지)
save_sample(img, save_path=save_path, fps=fps_save)
# T2I2V를 위한 이미지 리사이징
if (cfg.get("use_t2i2v", False) and is_image and
generate_sampling_option.resolution != generate_sampling_option.resized_resolution):
height, width = get_image_size(
generate_sampling_option.resized_resolution,
generate_sampling_option.aspect_ratio
)
rescale_image_by_path(save_path + ".png", width, height)
return ret_names
3. 학습 파이프라인 분석
3.1 분산 학습 환경 설정
# opensora/utils/train.py
def setup_device() -> tuple[torch.device, DistCoordinator]:
"""
디바이스 및 분산 코디네이터 설정
"""
assert torch.cuda.is_available(), "Training currently requires at least one GPU."
# 매우 긴 타임아웃 설정 (24시간)
dist.init_process_group(backend="nccl", timeout=timedelta(hours=24))
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
coordinator = DistCoordinator()
device = get_current_device()
return device, coordinator
def create_colossalai_plugin(
plugin: str,
dtype: str,
grad_clip: float,
**kwargs,
) -> LowLevelZeroPlugin | HybridParallelPlugin:
"""
ColossalAI 플러그인 생성
"""
plugin_kwargs = dict(
precision=dtype,
initial_scale=2**16,
max_norm=grad_clip,
overlap_allgather=True,
cast_inputs=False,
reduce_bucket_size_in_m=20,
)
plugin_kwargs.update(kwargs)
sp_size = plugin_kwargs.get("sp_size", 1)
if plugin == "zero1" or plugin == "zero2":
assert sp_size == 1, "Zero plugin does not support sequence parallelism"
stage = 1 if plugin == "zero1" else 2
plugin = LowLevelZeroPlugin(stage=stage, **plugin_kwargs)
set_data_parallel_group(dist.group.WORLD)
elif plugin == "hybrid":
plugin_kwargs["find_unused_parameters"] = True
plugin_kwargs["enable_metadata_cache"] = False
custom_policy = plugin_kwargs.pop("custom_policy", None)
if custom_policy is not None:
custom_policy = custom_policy()
plugin = HybridParallelPlugin(custom_policy=custom_policy, **plugin_kwargs)
set_tensor_parallel_group(plugin.tp_group)
set_sequence_parallel_group(plugin.sp_group)
set_data_parallel_group(plugin.dp_group)
else:
raise ValueError(f"Unknown plugin {plugin}")
return plugin
3.2 시각적 조건 준비 (학습용)
def prepare_visual_condition_causal(
x: torch.Tensor,
condition_config: dict,
model_ae: torch.nn.Module
) -> torch.Tensor:
"""
Causal VAE를 위한 시각적 조건 준비
"""
B = x.shape[0]
C = model_ae.cfg.latent_channels
T, H, W = model_ae.get_latent_size(x.shape[-3:])
masks = torch.zeros(B, 1, T, H, W).to(x.device, x.dtype)
latent = torch.zeros(B, C, T, H, W).to(x.device, x.dtype)
x_0 = torch.zeros(B, C, T, H, W).to(x.device, x.dtype)
if T > 1: # 비디오
# 짧은 비디오에 적용되지 않는 조건들 제거
if T <= (32 // model_ae.time_compression_ratio) + 1:
condition_config.pop("v2v_head", None)
condition_config.pop("v2v_tail", None)
condition_config.pop("v2v_head_easy", None)
condition_config.pop("v2v_tail_easy", None)
mask_cond_options = list(condition_config.keys())
mask_cond_weights = list(condition_config.values())
for i in range(B):
# 확률에 따른 마스크 조건 랜덤 선택
mask_cond = random.choices(mask_cond_options, weights=mask_cond_weights, k=1)[0]
if mask_cond == "i2v_head":
masks[i, :, 0, :, :] = 1
x_0[i] = model_ae.encode(x[i].unsqueeze(0))[0]
latent[i, :, :1, :, :] = model_ae.encode(x[i, :, :1, :, :].unsqueeze(0))
elif mask_cond == "i2v_loop":
masks[i, :, 0, :, :] = 1
masks[i, :, -1, :, :] = 1
x_0[i] = model_ae.encode(x[i].unsqueeze(0))[0]
latent[i, :, :1, :, :] = model_ae.encode(x[i, :, :1, :, :].unsqueeze(0))
latent[i, :, -1:, :, :] = model_ae.encode(x[i, :, -1:, :, :].unsqueeze(0))
elif "v2v_head" in mask_cond:
ref_t = 33 if not "easy" in mask_cond else 65
assert (ref_t - 1) % model_ae.time_compression_ratio == 0
conditioned_t = (ref_t - 1) // model_ae.time_compression_ratio + 1
masks[i, :, :conditioned_t, :, :] = 1
x_0[i] = model_ae.encode(x[i].unsqueeze(0))[0]
latent[i, :, :conditioned_t, :, :] = model_ae.encode(x[i, :, :ref_t, :, :].unsqueeze(0))
elif "v2v_tail" in mask_cond:
ref_t = 33 if not "easy" in mask_cond else 65
conditioned_t = (ref_t - 1) // model_ae.time_compression_ratio + 1
masks[i, :, -conditioned_t:, :, :] = 1
x_0[i] = model_ae.encode(x[i].unsqueeze(0))[0]
latent[i, :, -conditioned_t:, :, :] = model_ae.encode(x[i, :, -ref_t:, :, :].unsqueeze(0))
else:
assert mask_cond == "t2v", f"Unknown mask condition {mask_cond}"
x_0[i] = model_ae.encode(x[i].unsqueeze(0))[0]
else: # 이미지
x_0 = model_ae.encode(x)
latent = masks * latent
cond = torch.cat((masks, latent), dim=1)
return x_0, cond
3.3 EMA 업데이트 시스템
@torch.no_grad()
def update_ema(
ema_model: torch.nn.Module,
model: torch.nn.Module,
optimizer=None,
decay: float = 0.9999,
sharded: bool = True
):
"""
EMA 모델을 현재 모델 방향으로 업데이트
"""
ema_params = OrderedDict(ema_model.named_parameters())
model_params = OrderedDict(model.named_parameters())
for name, param in model_params.items():
if name == "pos_embed":
continue
if not param.requires_grad:
continue
if not sharded:
param_data = param.data
ema_params[name].mul_(decay).add_(param_data, alpha=1 - decay)
else:
if param.data.dtype != torch.float32:
param_id = id(param)
master_param = optimizer.get_working_to_master_map()[param_id]
param_data = master_param.data
else:
param_data = param.data
ema_params[name].mul_(decay).add_(param_data, alpha=1 - decay)
3.4 배치 손실 계산
def get_batch_loss(model_pred, v_t, masks=None):
"""
I2V를 위한 배치 손실 계산 (생성된 프레임만 포함)
"""
if masks is not None:
num_frames, height, width = masks.shape[-3:]
masks = masks[:, :, 0, 0] # [B, T]만 보기
# 텐서 재배열
model_pred = rearrange(
model_pred,
"b (t h w) (c ph pw) -> b c t (h ph) (w pw)",
h=height // 2, w=width // 2, t=num_frames, ph=2, pw=2,
)
v_t = rearrange(
v_t,
"b (t h w) (c ph pw) -> b c t (h ph) (w pw)",
h=height // 2, w=width // 2, t=num_frames, ph=2, pw=2,
)
batch_loss = 0
for i in range(model_pred.size(0)):
pred_val = model_pred[i]
target_val = v_t[i]
# 앞/뒤 패딩이 있는 경우 제외
if masks[i][0] == 1 and (not 1 in masks[i][1:-1]):
pred_val = pred_val[:, 1:]
target_val = target_val[:, 1:]
if masks[i][-1] == 1 and (not 1 in masks[i][1:-1]):
pred_val = pred_val[:, :-1]
target_val = target_val[:, :-1]
batch_loss += F.mse_loss(pred_val.float(), target_val.float(), reduction="mean")
loss = batch_loss / model_pred.size(0)
else:
loss = F.mse_loss(model_pred.float(), v_t.float(), reduction="mean")
return loss
4. 체크포인트 관리 시스템
4.1 다양한 체크포인트 로드
# opensora/utils/ckpt.py
def load_checkpoint(
model: nn.Module,
path: str,
cache_dir: str = None,
device_map: torch.device | str = "cpu",
cai_model_name: str = "model",
strict: bool = False,
rename_keys: dict = None,
) -> nn.Module:
"""
다양한 형태의 체크포인트 로드 지원:
1. Hugging Face safetensors
2. 로컬 .pt/.pth 파일
3. ColossalAI 샤드 체크포인트
"""
if not os.path.exists(path):
log_message(f"Checkpoint not found at {path}, trying to download from Hugging Face Hub")
path = load_from_hf_hub(path, cache_dir)
assert os.path.exists(path), f"Could not find checkpoint at {path}"
log_message(f"Loading checkpoint from {path}")
if path.endswith(".safetensors"):
ckpt = load_file(path, device='cpu')
# 키 이름 변경 (fine-tuning 지원)
if rename_keys is not None:
renamed_ckpt = {}
for old_key, v in ckpt.items():
new_key = old_key
for old_key_prefix, new_key_prefix in rename_keys.items():
if old_key_prefix in old_key:
new_key = old_key.replace(old_key_prefix, new_key_prefix)
print(f"Renamed {old_key} to {new_key} in the loaded state_dict")
break
renamed_ckpt[new_key] = v
ckpt = renamed_ckpt
missing, unexpected = model.load_state_dict(ckpt, strict=strict)
print_load_warning(missing, unexpected)
elif path.endswith(".pt") or path.endswith(".pth"):
ckpt = torch.load(path, map_location=device_map)
missing, unexpected = model.load_state_dict(ckpt, strict=strict)
print_load_warning(missing, unexpected)
else:
assert os.path.isdir(path), f"Invalid checkpoint path: {path}"
load_from_sharded_state_dict(model, path, model_name=cai_model_name, strict=strict)
return model
4.2 고성능 체크포인트 I/O
class CheckpointIO:
"""
비동기 I/O를 지원하는 고성능 체크포인트 관리자
"""
def __init__(self, n_write_entries: int = 32):
self.n_write_entries = n_write_entries
self.writer: Optional[AsyncFileWriter] = None
self.pinned_state_dict: Optional[Dict[str, torch.Tensor]] = None
self.master_pinned_state_dict: Optional[Dict[str, torch.Tensor]] = None
self.master_writer: Optional[AsyncFileWriter] = None
def save(
self,
booster: Booster,
save_dir: str,
model: nn.Module = None,
ema: nn.Module = None,
optimizer: Optimizer = None,
lr_scheduler: _LRScheduler = None,
sampler=None,
epoch: int = None,
step: int = None,
global_step: int = None,
batch_size: int = None,
lora: bool = False,
actual_update_step: int = None,
ema_shape_dict: dict = None,
async_io: bool = True,
include_master_weights: bool = False,
) -> str:
"""
포괄적인 체크포인트 저장
"""
self._sync_io()
save_dir = os.path.join(save_dir, f"epoch{epoch}-global_step{actual_update_step}")
os.environ["TENSORNVME_DEBUG_LOG"] = os.path.join(save_dir, "async_file_io.log")
# 모델 저장
if model is not None:
if not lora:
os.makedirs(os.path.join(save_dir, "model"), exist_ok=True)
booster.save_model(
model, os.path.join(save_dir, "model"),
shard=True, use_safetensors=True, size_per_shard=4096,
use_async=async_io,
)
else:
os.makedirs(os.path.join(save_dir, "lora"), exist_ok=True)
booster.save_lora_as_pretrained(model, os.path.join(save_dir, "lora"))
# 옵티마이저 저장
if optimizer is not None:
booster.save_optimizer(
optimizer, os.path.join(save_dir, "optimizer"),
shard=True, size_per_shard=4096, use_async=async_io
)
if include_master_weights:
self._prepare_master_pinned_state_dict(model, optimizer)
master_weights_gathering(model, optimizer, self.master_pinned_state_dict)
# EMA 모델 저장
if ema is not None:
self._prepare_pinned_state_dict(ema, ema_shape_dict)
model_gathering(ema, ema_shape_dict, self.pinned_state_dict)
# 메타데이터 저장 (rank 0만)
if dist.get_rank() == 0:
running_states = {
"epoch": epoch,
"step": step,
"global_step": global_step,
"batch_size": batch_size,
"actual_update_step": actual_update_step,
}
save_json(running_states, os.path.join(save_dir, "running_states.json"))
if ema is not None:
if async_io:
self.writer = async_save(os.path.join(save_dir, "ema.safetensors"), self.pinned_state_dict)
else:
torch.save(ema.state_dict(), os.path.join(save_dir, "ema.pt"))
if optimizer is not None and include_master_weights:
self.master_writer = async_save(
os.path.join(save_dir, "master.safetensors"), self.master_pinned_state_dict
)
dist.barrier()
return save_dir
4.3 분산 모델 수집
def model_gathering(model: torch.nn.Module, model_shape_dict: dict, pinned_state_dict: dict) -> None:
"""
여러 GPU에서 모델 파라미터 수집
"""
global_rank = dist.get_rank()
global_size = dist.get_world_size()
params = set()
for name, param in model.named_parameters():
params.add(name)
# 모든 rank에서 파라미터 수집
all_params = [torch.empty_like(param.data) for _ in range(global_size)]
dist.all_gather(all_params, param.data, group=dist.group.WORLD)
if int(global_rank) == 0:
all_params = torch.cat(all_params)
gathered_param = remove_padding(all_params, model_shape_dict[name]).view(model_shape_dict[name])
pinned_state_dict[name].copy_(gathered_param)
# 버퍼 처리 (rank 0만)
if int(global_rank) == 0:
for k, v in model.state_dict(keep_vars=True).items():
if k not in params:
pinned_state_dict[k].copy_(v)
dist.barrier()
5. 메모리 관리 및 모니터링
5.1 CUDA 메모리 모니터링
# opensora/utils/misc.py
GIGABYTE = 1024**3
def log_cuda_memory(stage: str = None):
"""
현재 CUDA 메모리 사용량 로깅
"""
text = "CUDA memory usage"
if stage is not None:
text += f" at {stage}"
log_message(text + ": %.1f GB", torch.cuda.memory_allocated() / GIGABYTE)
def log_cuda_max_memory(stage: str = None):
"""
최대 CUDA 메모리 사용량 로깅
"""
torch.cuda.synchronize()
max_memory_allocated = torch.cuda.max_memory_allocated()
max_memory_reserved = torch.cuda.max_memory_reserved()
log_message("CUDA max memory allocated at " + stage + ": %.1f GB", max_memory_allocated / GIGABYTE)
log_message("CUDA max memory reserved at " + stage + ": %.1f GB", max_memory_reserved / GIGABYTE)
def get_model_numel(model: torch.nn.Module) -> tuple[int, int]:
"""
모델 파라미터 수 계산
"""
num_params = 0
num_params_trainable = 0
for p in model.parameters():
num_params += p.numel()
if p.requires_grad:
num_params_trainable += p.numel()
return num_params, num_params_trainable
def log_model_params(model: nn.Module):
"""
모델 파라미터 수 로깅
"""
num_params, num_params_trainable = get_model_numel(model)
log_message(f"Model parameters: {num_params:,} total, {num_params_trainable:,} trainable")
5.2 Tensorboard 및 로깅
def create_tensorboard_writer(exp_dir: str) -> SummaryWriter:
"""
Tensorboard writer 생성
"""
tensorboard_dir = f"{exp_dir}/tensorboard"
os.makedirs(tensorboard_dir, exist_ok=True)
writer = SummaryWriter(tensorboard_dir)
return writer
6. 옵티마이저 및 스케줄러
6.1 옵티마이저 생성
# opensora/utils/optimizer.py
def create_optimizer(
model: torch.nn.Module,
optimizer_config: dict,
) -> torch.optim.Optimizer:
"""
옵티마이저 생성
"""
optimizer_name = optimizer_config.pop("cls", "HybridAdam")
if optimizer_name == "HybridAdam":
optimizer_cls = HybridAdam
else:
raise ValueError(f"Unknown optimizer: {optimizer_name}")
optimizer = optimizer_cls(
filter(lambda p: p.requires_grad, model.parameters()),
**optimizer_config,
)
return optimizer
6.2 학습률 스케줄러
class LinearWarmupLR(_LRScheduler):
"""
선형 웜업 학습률 스케줄러
"""
def __init__(self, optimizer, initial_lr=0, warmup_steps: int = 0, last_epoch: int = -1):
self.initial_lr = initial_lr
self.warmup_steps = warmup_steps
super().__init__(optimizer, last_epoch=last_epoch)
def get_lr(self):
if self.last_epoch < self.warmup_steps:
# 웜업 단계: 선형 증가
return [
self.initial_lr + (self.last_epoch + 1) / (self.warmup_steps + 1) * (lr - self.initial_lr)
for lr in self.base_lrs
]
else:
# 웜업 완료: 기본 학습률 사용
return self.base_lrs
def create_lr_scheduler(
optimizer: torch.optim.Optimizer,
num_steps_per_epoch: int,
epochs: int = 1000,
warmup_steps: int | None = None,
use_cosine_scheduler: bool = False,
initial_lr: float = 1e-6,
) -> _LRScheduler | None:
"""
학습률 스케줄러 생성
"""
if warmup_steps is None and not use_cosine_scheduler:
lr_scheduler = None
elif use_cosine_scheduler:
lr_scheduler = CosineAnnealingWarmupLR(
optimizer,
total_steps=num_steps_per_epoch * epochs,
warmup_steps=warmup_steps,
)
else:
lr_scheduler = LinearWarmupLR(optimizer, initial_lr=initial_lr, warmup_steps=warmup_steps)
return lr_scheduler
7. 실제 사용 예제
7.1 추론 파이프라인 예제
# 추론을 위한 조건 준비 예제
def inference_example():
# 설정
device = torch.device("cuda")
batch_size = 2
channels = 16
frames = 64
height = 32
width = 32
# 잠재 노이즈 텐서
z = torch.randn(batch_size, channels, frames, height, width).to(device)
# 참조 이미지/비디오 (예시)
ref_list = [
[torch.randn(channels, 1, height, width).to(device)], # 첫 번째 배치: 이미지
[torch.randn(channels, 8, height, width).to(device)] # 두 번째 배치: 비디오
]
# I2V 조건 준비
masks, masked_z = prepare_inference_condition(
z=z,
mask_cond="i2v_head",
ref_list=ref_list,
causal=True
)
print(f"Masks shape: {masks.shape}")
print(f"Masked z shape: {masked_z.shape}")
print(f"Number of conditioned frames: {masks.sum()}")
# 텍스트 처리 예제
def text_processing_example():
prompts = [
"A beautiful sunset over the ocean",
"A cat playing in the garden"
]
# FPS 정보 추가
modified_prompts = add_fps_info_to_text(prompts, fps=24)
print("Modified prompts:", modified_prompts)
# 모션 스코어 추가
motion_prompts = add_motion_score_to_text(prompts, motion_score=5)
print("Motion prompts:", motion_prompts)
7.2 학습 파이프라인 예제
# 분산 학습 설정 예제
def training_setup_example():
# 디바이스 설정
device, coordinator = setup_device()
# ColossalAI 플러그인 생성
plugin = create_colossalai_plugin(
plugin="hybrid",
dtype="bf16",
grad_clip=1.0,
sp_size=2,
tp_size=2,
zero_stage=2,
)
# 가상 모델 생성
model = torch.nn.Linear(1024, 1024).to(device)
# 옵티마이저 설정
optimizer_config = {
"cls": "HybridAdam",
"lr": 1e-4,
"betas": (0.9, 0.95),
"weight_decay": 0.1,
}
optimizer = create_optimizer(model, optimizer_config)
# 스케줄러 설정
lr_scheduler = create_lr_scheduler(
optimizer=optimizer,
num_steps_per_epoch=1000,
epochs=100,
warmup_steps=1000,
use_cosine_scheduler=True,
)
return model, optimizer, lr_scheduler
# EMA 업데이트 예제
def ema_update_example():
# 메인 모델과 EMA 모델
main_model = torch.nn.Linear(1024, 512)
ema_model = torch.nn.Linear(1024, 512)
# EMA 모델 초기화 (메인 모델 가중치 복사)
ema_model.load_state_dict(main_model.state_dict())
# 학습 루프에서 EMA 업데이트
for step in range(100):
# ... 실제 학습 코드 ...
# EMA 업데이트 (매 스텝마다)
update_ema(
ema_model=ema_model,
model=main_model,
decay=0.9999,
sharded=False
)
if step % 10 == 0:
print(f"Step {step}: EMA updated")
7.3 체크포인트 관리 예제
# 체크포인트 저장/로드 예제
def checkpoint_example():
# 모델 생성
model = torch.nn.TransformerEncoder(
torch.nn.TransformerEncoderLayer(512, 8),
num_layers=6
)
# 체크포인트 로드 (다양한 형식 지원)
model = load_checkpoint(
model=model,
path="path/to/checkpoint.safetensors",
strict=False,
rename_keys={
"old_prefix": "new_prefix" # Fine-tuning 지원
}
)
# CheckpointIO를 사용한 고성능 저장
checkpoint_io = CheckpointIO()
# EMA 모델과 함께 저장
ema_model = copy.deepcopy(model)
ema_shape_dict = record_model_param_shape(ema_model)
save_path = checkpoint_io.save(
booster=booster, # ColossalAI Booster
save_dir="./checkpoints",
model=model,
ema=ema_model,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
epoch=10,
step=1000,
global_step=10000,
batch_size=32,
ema_shape_dict=ema_shape_dict,
async_io=True,
include_master_weights=True,
)
print(f"Checkpoint saved to: {save_path}")
8. 성능 최적화 및 모니터링
8.1 메모리 사용량 추적
# 메모리 모니터링 예제
def memory_monitoring_example():
# 초기 메모리 상태
log_cuda_memory("initialization")
# 모델 로드
model = torch.nn.Linear(10000, 10000).cuda()
log_cuda_memory("after model load")
# 데이터 로드
data = torch.randn(1000, 10000).cuda()
log_cuda_memory("after data load")
# Forward pass
output = model(data)
log_cuda_memory("after forward")
# Backward pass
loss = output.sum()
loss.backward()
log_cuda_memory("after backward")
# 최대 메모리 사용량 로그
log_cuda_max_memory("training step")
# 모델 파라미터 수 로그
log_model_params(model)
8.2 성능 프로파일링
# 성능 측정 예제
def performance_profiling():
import time
model = torch.nn.Linear(1024, 1024).cuda()
data = torch.randn(32, 1024).cuda()
# 웜업
for _ in range(10):
_ = model(data)
torch.cuda.synchronize()
# 실제 측정
start_time = time.time()
torch.cuda.synchronize()
for _ in range(100):
output = model(data)
torch.cuda.synchronize()
end_time = time.time()
avg_time = (end_time - start_time) / 100
throughput = 32 / avg_time # 배치 크기 / 시간
print(f"Average inference time: {avg_time:.4f}s")
print(f"Throughput: {throughput:.2f} samples/s")
9. 한계점 및 개선 방향
9.1 현재 한계점
- 메모리 오버헤드: 다양한 조건부 생성으로 인한 메모리 사용량 증가
- I/O 병목: 대용량 체크포인트 저장/로드 시간
- 복잡성: 다양한 조건 타입으로 인한 코드 복잡도
- 디버깅: 분산 환경에서의 디버깅 어려움
9.2 개선 방향
# 미래 개선 방향 (예시)
class NextGenUtilities:
"""차세대 유틸리티 시스템"""
def __init__(self):
self.smart_memory_manager = SmartMemoryManager()
self.adaptive_checkpoint_io = AdaptiveCheckpointIO()
self.unified_condition_system = UnifiedConditionSystem()
def smart_memory_optimization(self):
"""지능형 메모리 최적화"""
# 동적 메모리 할당 및 해제
# 예측 기반 메모리 관리
pass
def compressed_checkpoint_io(self):
"""압축된 체크포인트 I/O"""
# 실시간 압축/압축해제
# 점진적 체크포인트 저장
pass
def unified_condition_handling(self):
"""통합된 조건 처리 시스템"""
# 단일 인터페이스로 모든 조건 타입 지원
# 자동 조건 최적화
pass
결론
Open-Sora의 유틸리티 모듈은 AI 비디오 생성 시스템의 핵심 기능을 지원하는 포괄적인 도구 모음입니다.
핵심 성과:
- 유연한 추론: 다양한 조건부 생성 모드 지원
- 효율적 학습: 분산 학습 및 EMA 업데이트 시스템
- 강력한 체크포인트: 비동기 I/O 및 다양한 형식 지원
- 메모리 관리: 실시간 모니터링 및 최적화 도구
이러한 유틸리티들은 Open-Sora가 대규모 비디오 생성 태스크를 안정적이고 효율적으로 수행할 수 있게 하는 핵심 인프라를 제공합니다. 앞으로 더욱 지능적이고 자동화된 시스템으로 발전하여 사용자 편의성과 성능을 동시에 향상시킬 것으로 기대됩니다.