Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[feat] : 딥러닝 가격 예측 조회 API 구현 (GET) #81

Merged
merged 11 commits into from
Jun 18, 2024
2 changes: 0 additions & 2 deletions backend/src/main/java/org/dgu/backend/BackendApplication.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,8 @@

import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.scheduling.annotation.EnableScheduling;

@SpringBootApplication
@EnableScheduling
public class BackendApplication {

public static void main(String[] args) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,9 @@ public enum SuccessStatus implements BaseCode {
// Trading
SUCCESS_START_TRADING(HttpStatus.CREATED, "201", "자동매매 등록에 성공했습니다."),
SUCCESS_DELETE_TRADING(HttpStatus.OK, "200", "자동매매 삭제에 성공했습니다."),
SUCCESS_GET_TRADING_LOGS(HttpStatus.OK, "200", "자동매매 거래 로그 조회에 성공했습니다.");
SUCCESS_GET_TRADING_LOGS(HttpStatus.OK, "200", "자동매매 거래 로그 조회에 성공했습니다."),
// Prediction
SUCCESS_GET_PREDICTIONS(HttpStatus.OK, "200", "딥러닝 가격 예측 값 조회에 성공했습니다.");

private final HttpStatus httpStatus;
private final String code;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package org.dgu.backend.controller;

import lombok.RequiredArgsConstructor;
import org.dgu.backend.common.ApiResponse;
import org.dgu.backend.common.constant.SuccessStatus;
import org.dgu.backend.dto.PredictionDto;
import org.dgu.backend.service.PredictionDataScheduler;
import org.dgu.backend.service.PredictionService;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;

import java.util.List;

@RestController
@RequestMapping("/api/v1/prediction")
@RequiredArgsConstructor
public class PredictionController {
private final PredictionService predictionService;
private final PredictionDataScheduler predictionDataScheduler;

// 딥러닝 가격 예측 값 조회 API
@GetMapping
public ResponseEntity<ApiResponse<List<PredictionDto.PredictionResponse>>> getPredictions() {

List<PredictionDto.PredictionResponse> predictionResponses = predictionService.getPredictions();
return ApiResponse.onSuccess(SuccessStatus.SUCCESS_GET_PREDICTIONS, predictionResponses);
}

// Train 수동 API
@GetMapping("/train")
public void startTrain() {

predictionDataScheduler.startTrain();
}

// 가격 예측 값 업데이트 수동 API
@GetMapping("/update")
public void getPrediction() {

predictionDataScheduler.getPrediction();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
@RequiredArgsConstructor
public class TradingController {
private final TradingService tradingService;
private final UpbitAutoTrader upbitAutoTrader;

// 자동매매 등록 API
@PostMapping
Expand All @@ -37,12 +38,18 @@ public ResponseEntity<ApiResponse<Object>> removeAutoTrading(
return ApiResponse.onSuccess(SuccessStatus.SUCCESS_DELETE_TRADING);
}

// 자동매매 수동 테스트 API
// 자동매매 거래 로그 조회 API
@GetMapping("/logs")
public ResponseEntity<ApiResponse<List<TradingDto.TradingLog>>> getUserTradingLogs(
@RequestHeader("Authorization") String authorizationHeader) {

List<TradingDto.TradingLog> tradingLogs = tradingService.getUserTradingLogs(authorizationHeader);
return ApiResponse.onSuccess(SuccessStatus.SUCCESS_GET_TRADING_LOGS, tradingLogs);
}

// 자동매매 수동 테스트 API
@GetMapping("/test")
public void startTrading() {
upbitAutoTrader.performAutoTrading();
}
}
33 changes: 33 additions & 0 deletions backend/src/main/java/org/dgu/backend/domain/Prediction.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package org.dgu.backend.domain;

import jakarta.persistence.*;
import lombok.*;
import org.dgu.backend.common.BaseEntity;

import java.time.Instant;
import java.time.LocalDateTime;
import java.time.ZoneOffset;

@Entity
@NoArgsConstructor(access = AccessLevel.PROTECTED)
@AllArgsConstructor(access = AccessLevel.PROTECTED)
@Builder
@Getter
@Table(name = "predictions")
public class Prediction extends BaseEntity {
@Id
@GeneratedValue(strategy = GenerationType.IDENTITY)
@Column(name = "predictions_id")
private Long id;

@Column(name = "date")
private LocalDateTime date;

@Column(name = "close")
private Long close;

public Prediction(String epochTime, Long close) {
this.date = LocalDateTime.ofInstant(Instant.ofEpochMilli(Long.parseLong(epochTime)), ZoneOffset.UTC);
this.close = (long) Math.round(close);
}
}
38 changes: 38 additions & 0 deletions backend/src/main/java/org/dgu/backend/dto/PredictionDto.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package org.dgu.backend.dto;

import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.databind.PropertyNamingStrategies;
import com.fasterxml.jackson.databind.annotation.JsonNaming;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Getter;
import org.dgu.backend.domain.Prediction;

import java.util.ArrayList;
import java.util.List;

public class PredictionDto {
@Builder
@Getter
@AllArgsConstructor
@JsonNaming(value = PropertyNamingStrategies.SnakeCaseStrategy.class)
@JsonInclude(JsonInclude.Include.NON_NULL)
public static class PredictionResponse {
@JsonProperty("date")
private String date;
@JsonProperty("close")
private Long close;

public static List<PredictionResponse> ofPredictions(List<Prediction> predictions) {
List<PredictionResponse> predictionResponses = new ArrayList<>();
for (Prediction prediction : predictions) {
predictionResponses.add(PredictionResponse.builder()
.date(String.valueOf(prediction.getDate()))
.close(prediction.getClose())
.build());
}
return predictionResponses;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import lombok.extern.slf4j.Slf4j;
import org.dgu.backend.common.ApiResponse;
import org.dgu.backend.common.code.BaseErrorCode;
import org.dgu.backend.domain.Market;
import org.springframework.http.HttpStatus;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.MissingRequestHeaderException;
Expand Down Expand Up @@ -87,4 +86,10 @@ public ResponseEntity<ApiResponse<BaseErrorCode>> handleTradingException(Trading
TradingErrorResult errorResult = e.getTradingErrorResult();
return ApiResponse.onFailure(errorResult);
}
// Prediction
@ExceptionHandler(PredictionException.class)
public ResponseEntity<ApiResponse<BaseErrorCode>> handlePredictionException(PredictionException e) {
PredictionErrorResult errorResult = e.getPredictionErrorResult();
return ApiResponse.onFailure(errorResult);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package org.dgu.backend.exception;

import lombok.Getter;
import lombok.RequiredArgsConstructor;
import org.dgu.backend.common.code.BaseErrorCode;
import org.dgu.backend.common.dto.ErrorReasonDto;
import org.springframework.http.HttpStatus;

@Getter
@RequiredArgsConstructor
public enum PredictionErrorResult implements BaseErrorCode {
FAIL_TO_TRAINING(HttpStatus.NOT_FOUND, "404", "딥러닝 트레이닝에 실패했습니다."),
FAIL_TO_PREDICTION(HttpStatus.NOT_FOUND, "404", "딥러닝 가격 예측 데이터 받아 오기에 실패했습니다."),
FAIL_TO_PARSE_RESPONSE(HttpStatus.NOT_FOUND, "404", "가격 예측 데이터 파싱에 실패했습니다");

private final HttpStatus httpStatus;
private final String code;
private final String message;

@Override
public ErrorReasonDto getReason() {
return ErrorReasonDto.builder()
.isSuccess(false)
.code(code)
.message(message)
.build();
}

@Override
public ErrorReasonDto getReasonHttpStatus() {
return ErrorReasonDto.builder()
.isSuccess(false)
.httpStatus(httpStatus)
.code(code)
.message(message)
.build();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package org.dgu.backend.exception;

import lombok.Getter;
import lombok.RequiredArgsConstructor;

@Getter
@RequiredArgsConstructor
public class PredictionException extends RuntimeException {
private final PredictionErrorResult predictionErrorResult;

@Override
public String getMessage() {
return predictionErrorResult.getMessage();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package org.dgu.backend.repository;

import org.dgu.backend.domain.Prediction;
import org.springframework.data.jpa.repository.JpaRepository;

public interface PredictionRepository extends JpaRepository<Prediction,Long> {
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
package org.dgu.backend.service;

import com.fasterxml.jackson.databind.ObjectMapper;
import jakarta.transaction.Transactional;
import lombok.RequiredArgsConstructor;
import org.dgu.backend.domain.Prediction;
import org.dgu.backend.dto.ChartDto;
import org.dgu.backend.dto.PredictionDto;
import org.dgu.backend.exception.PredictionErrorResult;
import org.dgu.backend.exception.PredictionException;
import org.dgu.backend.repository.PredictionRepository;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.http.HttpEntity;
import org.springframework.http.HttpMethod;
import org.springframework.http.ResponseEntity;
import org.springframework.scheduling.annotation.EnableScheduling;
import org.springframework.scheduling.annotation.Scheduled;
import org.springframework.stereotype.Component;
import org.springframework.web.client.RestTemplate;

import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;

@Component
@RequiredArgsConstructor
@Transactional
@EnableScheduling
public class PredictionDataScheduler {
@Value("${ai.url.train}")
private String AI_URL_TRAIN;
@Value("${ai.url.predict}")
private String AI_URL_PREDICT;
private final ChartService chartService;
private final PredictionRepository predictionRepository;
private final RestTemplate restTemplate;
private final ObjectMapper objectMapper;

// Train 실행 메서드
//@Scheduled(cron = "0 0 0 * * *") // 매일 00:00에 실행
public void startTrain() {
List<ChartDto.OHLCVResponse> ohlcvResponses = chartService.getOHLCVCharts("비트코인", "days", null);
// Train 요청
ResponseEntity<String> trainResponseEntity = restTemplate.exchange(
AI_URL_TRAIN,
HttpMethod.POST,
new HttpEntity<>(ohlcvResponses),
String.class
);
String trainMessage = trainResponseEntity.getBody();
if (Objects.isNull(trainMessage)) {
throw new PredictionException(PredictionErrorResult.FAIL_TO_TRAINING);
}
System.out.println("Train Message: " + trainMessage);

}

// Prediction 값을 받아오는 메서드
//@Scheduled(cron = "0 10 0 * * *") // 매일 00:10에 실행
public void getPrediction() {
List<ChartDto.OHLCVResponse> ohlcvResponses = chartService.getOHLCVCharts("비트코인", "days", null);
// Prediction 요청
ResponseEntity<String> predictResponseEntity = restTemplate.exchange(
AI_URL_PREDICT,
HttpMethod.POST,
new HttpEntity<>(ohlcvResponses),
String.class
);
if (Objects.isNull(predictResponseEntity.getBody())) {
throw new PredictionException(PredictionErrorResult.FAIL_TO_PREDICTION);
}
String responseBody = predictResponseEntity.getBody();

// JSON 문자열을 PredictionDto 배열로 변환
PredictionDto.PredictionResponse[] predictions;
try {
predictions = objectMapper.readValue(responseBody, PredictionDto.PredictionResponse[].class);
} catch (IOException e) {
throw new PredictionException(PredictionErrorResult.FAIL_TO_PARSE_RESPONSE);
}

// 기존 값 제거
List<Prediction> existPredictions = predictionRepository.findAll();
if (!Objects.isNull(existPredictions)) {
predictionRepository.deleteAll(existPredictions);
predictionRepository.flush();
}
// 변환된 데이터를 엔티티로 저장
Arrays.stream(predictions)
.map(prediction -> new Prediction(prediction.getDate(), prediction.getClose()))
.forEach(predictionRepository::save);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package org.dgu.backend.service;

import org.dgu.backend.dto.PredictionDto;

import java.util.List;

public interface PredictionService {
List<PredictionDto.PredictionResponse> getPredictions();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package org.dgu.backend.service;

import jakarta.transaction.Transactional;
import lombok.RequiredArgsConstructor;
import org.dgu.backend.domain.Prediction;
import org.dgu.backend.dto.PredictionDto;
import org.dgu.backend.repository.PredictionRepository;
import org.springframework.stereotype.Service;

import java.util.List;

@Service
@Transactional
@RequiredArgsConstructor
public class PredictionServiceImpl implements PredictionService {
private final PredictionRepository predictionRepository;

// 딥러닝 가격 예측 값 반환 메서드
@Override
public List<PredictionDto.PredictionResponse> getPredictions() {
List<Prediction> predictions = predictionRepository.findAll();
return PredictionDto.PredictionResponse.ofPredictions(predictions);
}
}
Loading
Loading