주식 트레이드 강화 전략 추가
This commit is contained in:
@@ -138,13 +138,19 @@ class PricePredictor:
|
||||
else:
|
||||
print("[AI] No CUDA GPU detected. Running on CPU.")
|
||||
|
||||
self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=0.0005, weight_decay=1e-4)
|
||||
self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=0.001, weight_decay=1e-4)
|
||||
# [v2.0] Learning Rate Scheduler (ReduceLROnPlateau: val_loss 정체 시 lr 감소)
|
||||
self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
|
||||
self.optimizer, mode='min', factor=0.5, patience=7, min_lr=1e-6, verbose=False
|
||||
)
|
||||
self.scaler_amp = torch.amp.GradScaler('cuda') if self.use_amp else None
|
||||
|
||||
self.batch_size = 64
|
||||
self.max_epochs = 200
|
||||
self.seq_length = 60
|
||||
self.patience = 15
|
||||
# [v2.0] Gradient Clipping 값 (exploding gradient 방지)
|
||||
self.max_grad_norm = 1.0
|
||||
|
||||
self.training_status = {
|
||||
"is_training": False,
|
||||
@@ -237,12 +243,19 @@ class PricePredictor:
|
||||
max_epochs = 50 if has_checkpoint else self.max_epochs
|
||||
|
||||
# 4. 학습 (전체 데이터 GPU 상주, DataLoader 미사용)
|
||||
# [v2.0] LR Scheduler 리셋
|
||||
self.optimizer.param_groups[0]['lr'] = 0.001 if not has_checkpoint else 0.0005
|
||||
self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
|
||||
self.optimizer, mode='min', factor=0.5, patience=7, min_lr=1e-6, verbose=False
|
||||
)
|
||||
|
||||
self.model.train()
|
||||
self.training_status["is_training"] = True
|
||||
if ticker:
|
||||
self.training_status["current_ticker"] = ticker
|
||||
|
||||
best_val_loss = float('inf')
|
||||
best_model_state = None # [v2.0] Best Model 저장
|
||||
patience_counter = 0
|
||||
final_loss = 0.0
|
||||
actual_epochs = 0
|
||||
@@ -268,12 +281,17 @@ class PricePredictor:
|
||||
outputs = self.model(batch_x)
|
||||
loss = self.criterion(outputs, batch_y)
|
||||
self.scaler_amp.scale(loss).backward()
|
||||
# [v2.0] Gradient Clipping (AMP 호환)
|
||||
self.scaler_amp.unscale_(self.optimizer)
|
||||
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
|
||||
self.scaler_amp.step(self.optimizer)
|
||||
self.scaler_amp.update()
|
||||
else:
|
||||
outputs = self.model(batch_x)
|
||||
loss = self.criterion(outputs, batch_y)
|
||||
loss.backward()
|
||||
# [v2.0] Gradient Clipping
|
||||
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
|
||||
self.optimizer.step()
|
||||
|
||||
epoch_loss += loss.item()
|
||||
@@ -293,17 +311,26 @@ class PricePredictor:
|
||||
val_loss = self.criterion(val_out, y_val).item()
|
||||
self.model.train()
|
||||
|
||||
# [v2.0] LR Scheduler step (val_loss 기반)
|
||||
self.lr_scheduler.step(val_loss)
|
||||
|
||||
final_loss = train_loss
|
||||
actual_epochs = epoch + 1
|
||||
|
||||
if val_loss < best_val_loss:
|
||||
best_val_loss = val_loss
|
||||
patience_counter = 0
|
||||
# [v2.0] Best model 상태 저장 (overfitting 방지)
|
||||
best_model_state = {k: v.clone() for k, v in self.model.state_dict().items()}
|
||||
else:
|
||||
patience_counter += 1
|
||||
if patience_counter >= self.patience:
|
||||
break
|
||||
|
||||
# [v2.0] Best model 복원 (early stopping 후 최적 상태로 복구)
|
||||
if best_model_state:
|
||||
self.model.load_state_dict(best_model_state)
|
||||
|
||||
self.training_status["is_training"] = False
|
||||
self.training_status["loss"] = final_loss
|
||||
|
||||
@@ -346,7 +373,30 @@ class PricePredictor:
|
||||
current_price = prices[-1]
|
||||
trend = "UP" if predicted_price > current_price else "DOWN"
|
||||
change_rate = ((predicted_price - current_price) / current_price) * 100
|
||||
confidence = 1.0 / (1.0 + (final_loss * 100))
|
||||
|
||||
# [v2.0] 개선된 신뢰도 계산
|
||||
# 1. 학습 손실 기반 (낮을수록 좋음)
|
||||
loss_confidence = 1.0 / (1.0 + (best_val_loss * 50))
|
||||
|
||||
# 2. Train/Val 괴리도 (overfitting 감지)
|
||||
overfit_ratio = final_loss / (best_val_loss + 1e-9)
|
||||
if overfit_ratio < 0.5:
|
||||
# Train loss가 Val loss보다 훨씬 낮음 = overfitting
|
||||
overfit_penalty = 0.7
|
||||
elif overfit_ratio > 2.0:
|
||||
# Train loss가 Val loss보다 훨씬 높음 = underfitting
|
||||
overfit_penalty = 0.8
|
||||
else:
|
||||
overfit_penalty = 1.0
|
||||
|
||||
# 3. 에포크 수 기반 (너무 적거나 많으면 불신)
|
||||
epoch_factor = 1.0
|
||||
if actual_epochs < 10:
|
||||
epoch_factor = 0.6 # 학습 부족
|
||||
elif actual_epochs >= max_epochs:
|
||||
epoch_factor = 0.8 # 수렴 실패
|
||||
|
||||
confidence = min(0.95, loss_confidence * overfit_penalty * epoch_factor)
|
||||
|
||||
return {
|
||||
"current": current_price,
|
||||
@@ -354,9 +404,11 @@ class PricePredictor:
|
||||
"change_rate": round(change_rate, 2),
|
||||
"trend": trend,
|
||||
"loss": final_loss,
|
||||
"val_loss": best_val_loss,
|
||||
"confidence": round(confidence, 2),
|
||||
"epochs": actual_epochs,
|
||||
"device": str(self.device)
|
||||
"device": str(self.device),
|
||||
"lr": self.optimizer.param_groups[0]['lr']
|
||||
}
|
||||
|
||||
def batch_predict(self, prices_dict):
|
||||
|
||||
Reference in New Issue
Block a user