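(Advanced) Predicting Samsung Electronics' stock price using stock market data
Required stock data
- KOSPI index
- NASDAQ index (because the stock is also influenced by the US market)
- Moving averages
Importing libraries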
!pip install finance-datareader
!pip install requests_cache
import FinanceDataReader as fdr
import pandas_datareader as pdr
import requests_cache
import matplotlib.pyplot as plt
import pandas as pd
Data preparation and analysis
start_date = '20180101'
end_date = '20220225'
df = fdr.DataReader('005930', start_date, end_date) # Samsung Electronics: 005930
df
Date | Open | High | Low | Close | Volume | Change |
---|---|---|---|---|---|---|
2018-01-02 | 51380 | 51400 | 50780 | 51020 | 169485 | 0.001177 |
2018-01-03 | 52540 | 52560 | 51420 | 51620 | 200270 | 0.011760 |
2018-01-04 | 52120 | 52180 | 50640 | 51080 | 233909 | -0.010461 |
2018-01-05 | 51300 | 52120 | 51200 | 52120 | 189623 | 0.020360 |
2018-01-08 | 52400 | 52520 | 51500 | 52020 | 167673 | -0.001919 |
... | ... | ... | ... | ... | ... | ... |
2022-02-21 | 73200 | 74300 | 72600 | 74200 | 10489717 | -0.001346 |
2022-02-22 | 73000 | 73400 | 72800 | 73400 | 11692469 | -0.010782 |
2022-02-23 | 73800 | 73800 | 72800 | 73000 | 10397964 | -0.005450 |
2022-02-24 | 72300 | 72300 | 71300 | 71500 | 15759283 | -0.020548 |
2022-02-25 | 72100 | 72600 | 71900 | 71900 | 13062251 | 0.005594 |
1023 rows × 6 columns
1. KOSPI index
session = requests_cache.CachedSession(cache_name='cache', backend='sqlite')
# just add headers to your session and provide it to the reader
session.headers = {'User-Agent': 'Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:89.0) Gecko/20100101 Firefox/89.0', 'Accept': 'application/json;charset=utf-8'}
kospi = pdr.DataReader('^KS11','yahoo', start_date, end_date, session=session)
kospi
Date | High | Low | Open | Close | Volume | Adj Close |
---|---|---|---|---|---|---|
2018-01-02 | 2481.020020 | 2465.939941 | 2474.860107 | 2479.649902 | 262200 | 2479.649902 |
2018-01-03 | 2493.399902 | 2481.909912 | 2484.629883 | 2486.350098 | 331100 | 2486.350098 |
2018-01-04 | 2502.500000 | 2466.449951 | 2502.500000 | 2466.459961 | 333800 | 2466.459961 |
2018-01-05 | 2497.520020 | 2475.510010 | 2476.850098 | 2497.520020 | 308800 | 2497.520020 |
2018-01-08 | 2515.370117 | 2494.179932 | 2510.699951 | 2513.280029 | 311400 | 2513.280029 |
... | ... | ... | ... | ... | ... | ... |
2022-02-21 | 2746.620117 | 2694.899902 | 2706.649902 | 2743.800049 | 495000 | 2743.800049 |
2022-02-22 | 2721.840088 | 2690.090088 | 2705.080078 | 2706.790039 | 648100 | 2706.790039 |
2022-02-23 | 2729.560059 | 2705.310059 | 2727.429932 | 2719.530029 | 537700 | 2719.530029 |
2022-02-24 | 2694.550049 | 2642.629883 | 2689.280029 | 2648.800049 | 925900 | 2648.800049 |
2022-02-25 | 2694.810059 | 2665.959961 | 2678.469971 | 2676.760010 | 664100 | 2676.760010 |
1022 rows × 6 columns
# The row counts do not match, so find and drop the extra date from df
sum(df.index[:-1] == kospi.index)
986
df.drop('2022-01-03', axis=0, inplace=True)
df['kospi'] = kospi['Close'].values
df
Date | Open | High | Low | Close | Volume | Change | kospi |
---|---|---|---|---|---|---|---|
2018-01-02 | 51380 | 51400 | 50780 | 51020 | 169485 | 0.001177 | 2479.649902 |
2018-01-03 | 52540 | 52560 | 51420 | 51620 | 200270 | 0.011760 | 2486.350098 |
2018-01-04 | 52120 | 52180 | 50640 | 51080 | 233909 | -0.010461 | 2466.459961 |
2018-01-05 | 51300 | 52120 | 51200 | 52120 | 189623 | 0.020360 | 2497.520020 |
2018-01-08 | 52400 | 52520 | 51500 | 52020 | 167673 | -0.001919 | 2513.280029 |
... | ... | ... | ... | ... | ... | ... | ... |
2022-02-21 | 73200 | 74300 | 72600 | 74200 | 10489717 | -0.001346 | 2743.800049 |
2022-02-22 | 73000 | 73400 | 72800 | 73400 | 11692469 | -0.010782 | 2706.790039 |
2022-02-23 | 73800 | 73800 | 72800 | 73000 | 10397964 | -0.005450 | 2719.530029 |
2022-02-24 | 72300 | 72300 | 71300 | 71500 | 15759283 | -0.020548 | 2648.800049 |
2022-02-25 | 72100 | 72600 | 71900 | 71900 | 13062251 | 0.005594 | 2676.760010 |
1022 rows × 7 columns
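Assigning `kospi['Close'].values` by position only works when both frames hold exactly the same dates in the same order. As a minimal sketch of an index-aligned alternative, the example below joins on the date index instead. The frames `stock` and `index_df` and their values are made up for illustration and are not the real data.

```python
import pandas as pd

# Toy frames standing in for the Samsung and KOSPI data (values are made up).
stock = pd.DataFrame(
    {"Close": [100, 101, 102, 103]},
    index=pd.to_datetime(["2022-01-03", "2022-01-04", "2022-01-05", "2022-01-06"]),
)
index_df = pd.DataFrame(
    {"Close": [2990.0, 2955.0, 2920.0]},
    index=pd.to_datetime(["2022-01-04", "2022-01-05", "2022-01-06"]),
)

# Dates that exist in the stock frame but not in the index frame.
missing = stock.index.difference(index_df.index)
print(missing)

# Join on the date index; an inner join keeps only dates present in both frames.
merged = stock.join(index_df["Close"].rename("kospi"), how="inner")
print(merged)
```

With this approach there is no need to know in advance which date is missing, and a mismatch in the middle of the series cannot silently shift values onto the wrong day.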
2. NASDAQ index
nasdaq = pdr.get_data_yahoo('^IXIC', start_date, end_date, session=session)
nasdaq
Date | High | Low | Open | Close | Volume | Adj Close |
---|---|---|---|---|---|---|
2018-01-02 | 7006.910156 | 6924.080078 | 6937.649902 | 7006.899902 | 1914930000 | 7006.899902 |
2018-01-03 | 7069.149902 | 7016.700195 | 7017.069824 | 7065.529785 | 2166780000 | 7065.529785 |
2018-01-04 | 7098.049805 | 7072.379883 | 7089.500000 | 7077.910156 | 2098890000 | 7077.910156 |
2018-01-05 | 7137.040039 | 7097.080078 | 7105.740234 | 7136.560059 | 2020900000 | 7136.560059 |
2018-01-08 | 7161.350098 | 7124.089844 | 7135.379883 | 7157.390137 | 2051430000 | 7157.390137 |
... | ... | ... | ... | ... | ... | ... |
2022-02-18 | 13762.400391 | 13465.559570 | 13735.400391 | 13548.070312 | 4475010000 | 13548.070312 |
2022-02-22 | 13618.719727 | 13249.650391 | 13424.360352 | 13381.519531 | 4830920000 | 13381.519531 |
2022-02-23 | 13533.780273 | 13032.169922 | 13511.750000 | 13037.490234 | 4614090000 | 13037.490234 |
2022-02-24 | 13486.110352 | 12587.879883 | 12587.879883 | 13473.589844 | 6131410000 | 13473.589844 |
2022-02-25 | 13696.860352 | 13358.290039 | 13485.259766 | 13694.620117 | 4614110000 | 13694.620117 |
1046 rows × 6 columns
nasdaq['nasdaq'] = nasdaq['Close']
nas = nasdaq[['Adj Close', 'nasdaq']]
df = nas.join(df)
df.drop('Adj Close', axis=1, inplace=True)
df.dropna(axis=0, inplace=True)
df
Date | nasdaq | Open | High | Low | Close | Volume | Change | kospi |
---|---|---|---|---|---|---|---|---|
2018-01-02 | 7006.899902 | 51380.0 | 51400.0 | 50780.0 | 51020.0 | 169485.0 | 0.001177 | 2479.649902 |
2018-01-03 | 7065.529785 | 52540.0 | 52560.0 | 51420.0 | 51620.0 | 200270.0 | 0.011760 | 2486.350098 |
2018-01-04 | 7077.910156 | 52120.0 | 52180.0 | 50640.0 | 51080.0 | 233909.0 | -0.010461 | 2466.459961 |
2018-01-05 | 7136.560059 | 51300.0 | 52120.0 | 51200.0 | 52120.0 | 189623.0 | 0.020360 | 2497.520020 |
2018-01-08 | 7157.390137 | 52400.0 | 52520.0 | 51500.0 | 52020.0 | 167673.0 | -0.001919 | 2513.280029 |
... | ... | ... | ... | ... | ... | ... | ... | ... |
2022-02-18 | 13548.070312 | 74600.0 | 74800.0 | 73700.0 | 74300.0 | 10122226.0 | -0.009333 | 2744.520020 |
2022-02-22 | 13381.519531 | 73000.0 | 73400.0 | 72800.0 | 73400.0 | 11692469.0 | -0.010782 | 2706.790039 |
2022-02-23 | 13037.490234 | 73800.0 | 73800.0 | 72800.0 | 73000.0 | 10397964.0 | -0.005450 | 2719.530029 |
2022-02-24 | 13473.589844 | 72300.0 | 72300.0 | 71300.0 | 71500.0 | 15759283.0 | -0.020548 | 2648.800049 |
2022-02-25 | 13694.620117 | 72100.0 | 72600.0 | 71900.0 | 71900.0 | 13062251.0 | 0.005594 | 2676.760010 |
990 rows × 8 columns
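One thing to keep in mind with a same-date join: the NASDAQ close for a given calendar date is only known after the Korean market has already closed for that date. If the goal is to use only information available at the Korean close, one possible approach is to attach the most recent NASDAQ close from a strictly earlier date. The sketch below uses `pd.merge_asof` for this; the frames `kr` and `us` and their values are illustrative stand-ins, not the real data.

```python
import pandas as pd

# Toy frames (values are made up): Korean closes and NASDAQ closes by date.
kr = pd.DataFrame(
    {"Close": [100.0, 101.0, 102.0]},
    index=pd.to_datetime(["2022-02-24", "2022-02-25", "2022-02-28"]),
)
us = pd.DataFrame(
    {"nasdaq": [13000.0, 13400.0, 13700.0]},
    index=pd.to_datetime(["2022-02-23", "2022-02-24", "2022-02-25"]),
)

# For each Korean date, take the last NASDAQ row dated strictly before it.
aligned = pd.merge_asof(
    kr, us,
    left_index=True, right_index=True,
    direction="backward",
    allow_exact_matches=False,
)
print(aligned)
```

Because the match is by "latest earlier date" rather than by equal date, Korean trading days that fall on a US holiday still receive a value instead of being dropped.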
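3. Moving averages (5, 20, 60)
Moving averages are commonly consulted to gauge the direction in which a stock is heading.
We therefore add 5-day, 20-day, and 60-day moving averages.
The pandas rolling function is used to compute them.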
# Samsung Electronics stock
df['ma5'] = df['Close'].rolling(window=5).mean()
df['ma20'] = df['Close'].rolling(window=20).mean()
df['ma60'] = df['Close'].rolling(window=60).mean()
# KOSPI
df['kospi_ma5'] = df['kospi'].rolling(window=5).mean()
df['kospi_ma20'] = df['kospi'].rolling(window=20).mean()
df['kospi_ma60'] = df['kospi'].rolling(window=60).mean()
# NASDAQ
df['nasdaq_ma5'] = df['nasdaq'].rolling(window=5).mean()
df['nasdaq_ma20'] = df['nasdaq'].rolling(window=20).mean()
df['nasdaq_ma60'] = df['nasdaq'].rolling(window=60).mean()
df.dropna(inplace=True)
df
Date | nasdaq | Open | High | Low | Close | Volume | Change | kospi | ma5 | ma20 | ma60 | kospi_ma5 | kospi_ma20 | kospi_ma60 | nasdaq_ma5 | nasdaq_ma20 | nasdaq_ma60 |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
2018-04-03 | 6941.279785 | 47880.0 | 48140.0 | 47280.0 | 48120.0 | 255365.0 | -0.008653 | 2442.429932 | 48876.0 | 49979.0 | 49066.333333 | 2438.862012 | 2457.324512 | 2472.361169 | 6966.578027 | 7279.195020 | 7243.661987 |
2018-04-04 | 7042.109863 | 48160.0 | 48260.0 | 46920.0 | 46920.0 | 247684.0 | -0.024938 | 2408.060059 | 48264.0 | 49974.0 | 48998.000000 | 2430.062012 | 2457.157019 | 2471.168005 | 6973.237988 | 7262.700024 | 7244.248820 |
2018-04-05 | 7076.549805 | 47400.0 | 49380.0 | 47340.0 | 48740.0 | 264912.0 | 0.038789 | 2437.520020 | 48272.0 | 49980.0 | 48950.000000 | 2433.708008 | 2458.942017 | 2470.354171 | 6998.701953 | 7246.695020 | 7244.432487 |
2018-04-06 | 6915.109863 | 48000.0 | 48580.0 | 47400.0 | 48400.0 | 250654.0 | -0.006976 | 2429.580078 | 48144.0 | 49940.0 | 48905.333333 | 2432.350000 | 2458.767017 | 2469.739506 | 6969.033887 | 7221.053003 | 7241.719149 |
2018-04-09 | 6950.339844 | 48260.0 | 49440.0 | 48200.0 | 49200.0 | 199008.0 | 0.016529 | 2444.080078 | 48276.0 | 49913.0 | 48856.666667 | 2432.334033 | 2457.998523 | 2468.848840 | 6985.077832 | 7190.529492 | 7238.615479 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
2022-02-18 | 13548.070312 | 74600.0 | 74800.0 | 73700.0 | 74300.0 | 10122226.0 | -0.009333 | 2744.520020 | 74300.0 | 74285.0 | 75821.666667 | 2719.862012 | 2743.806006 | 2892.907992 | 13863.911914 | 13914.837988 | 14926.351660 |
2022-02-22 | 13381.519531 | 73000.0 | 73400.0 | 72800.0 | 73400.0 | 11692469.0 | -0.010782 | 2706.790039 | 74240.0 | 74140.0 | 75866.666667 | 2720.324023 | 2737.031506 | 2888.647493 | 13782.031836 | 13866.900977 | 14884.017480 |
2022-02-23 | 13037.490234 | 73800.0 | 73800.0 | 72800.0 | 73000.0 | 10397964.0 | -0.005450 | 2719.530029 | 74100.0 | 73965.0 | 75913.333333 | 2728.922021 | 2729.874011 | 2884.849996 | 13561.577930 | 13811.074512 | 14834.747152 |
2022-02-24 | 13473.589844 | 72300.0 | 72300.0 | 71300.0 | 71500.0 | 15759283.0 | -0.020548 | 2648.800049 | 73440.0 | 73760.0 | 75918.333333 | 2712.746045 | 2720.599512 | 2879.479663 | 13431.477930 | 13796.308008 | 14791.682975 |
2022-02-25 | 13694.620117 | 72100.0 | 72600.0 | 71900.0 | 71900.0 | 13062251.0 | 0.005594 | 2676.760010 | 72820.0 | 73600.0 | 75868.333333 | 2699.280029 | 2714.837512 | 2873.871497 | 13427.058008 | 13788.282520 | 14755.680648 |
931 rows × 17 columns
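The row count falls from 990 to 931 because `rolling(window=n).mean()` returns NaN until a full window of n values is available, and `dropna` then removes those rows. A small sketch on a toy series (the values are made up) shows the effect:

```python
import pandas as pd

# Toy price series 1.0 .. 10.0 (made-up values).
close = pd.Series(range(1, 11), dtype="float64")

# A 3-day moving average has no value for the first 2 rows (window - 1).
ma3 = close.rolling(window=3).mean()
print(ma3.isna().sum())  # 2
print(ma3.iloc[2])       # 2.0, the mean of 1, 2, 3

# Same arithmetic for the data above: the 60-day window leaves 59 incomplete rows.
rows_before, longest_window = 990, 60
print(rows_before - (longest_window - 1))  # 931
```

The longest window determines how many leading rows are lost, so the 5-day and 20-day averages do not remove any additional rows.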
Correlation analysis
df.corr()
nasdaq | Open | High | Low | Close | Volume | Change | kospi | ma5 | ma20 | ma60 | kospi_ma5 | kospi_ma20 | kospi_ma60 | nasdaq_ma5 | nasdaq_ma20 | nasdaq_ma60 | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
nasdaq | 1.000000 | 0.920460 | 0.918155 | 0.921570 | 0.930217 | 0.249382 | 0.004849 | 0.885036 | 0.930421 | 0.929457 | 0.928235 | 0.883946 | 0.877384 | 0.852673 | 0.998609 | 0.993128 | 0.981108 |
Open | 0.920460 | 1.000000 | 0.999120 | 0.999460 | 0.978007 | 0.324054 | -0.013461 | 0.896568 | 0.977900 | 0.969013 | 0.939894 | 0.894583 | 0.879538 | 0.829076 | 0.920682 | 0.917328 | 0.904685 |
High | 0.918155 | 0.999120 | 1.000000 | 0.999093 | 0.978724 | 0.341167 | 0.007400 | 0.894216 | 0.977584 | 0.968007 | 0.936988 | 0.891717 | 0.876309 | 0.824230 | 0.918345 | 0.915196 | 0.901917 |
Low | 0.921570 | 0.999460 | 0.999093 | 1.000000 | 0.978943 | 0.314104 | 0.001723 | 0.899391 | 0.978152 | 0.969074 | 0.940333 | 0.896813 | 0.881334 | 0.830768 | 0.921693 | 0.917962 | 0.905152 |
Close | 0.930217 | 0.978007 | 0.978724 | 0.978943 | 1.000000 | 0.313428 | 0.026625 | 0.921314 | 0.997249 | 0.985993 | 0.954659 | 0.917392 | 0.900758 | 0.848280 | 0.930282 | 0.926925 | 0.914300 |
Volume | 0.249382 | 0.324054 | 0.341167 | 0.314104 | 0.313428 | 1.000000 | 0.028491 | 0.124365 | 0.311064 | 0.306407 | 0.271780 | 0.125034 | 0.120924 | 0.072458 | 0.252054 | 0.262411 | 0.247607 |
Change | 0.004849 | -0.013461 | 0.007400 | 0.001723 | 0.026625 | 0.028491 | 1.000000 | 0.006820 | -0.025139 | -0.034819 | -0.036615 | -0.031905 | -0.041020 | -0.046557 | -0.008287 | -0.015503 | -0.018017 |
kospi | 0.885036 | 0.896568 | 0.894216 | 0.899391 | 0.921314 | 0.124365 | 0.006820 | 1.000000 | 0.921146 | 0.918214 | 0.910455 | 0.997411 | 0.985278 | 0.948553 | 0.884667 | 0.878097 | 0.863776 |
ma5 | 0.930421 | 0.977900 | 0.977584 | 0.978152 | 0.997249 | 0.311064 | -0.025139 | 0.921146 | 1.000000 | 0.991432 | 0.960779 | 0.921459 | 0.907117 | 0.855765 | 0.931799 | 0.929856 | 0.917830 |
ma20 | 0.929457 | 0.969013 | 0.968007 | 0.969074 | 0.985993 | 0.306407 | -0.034819 | 0.918214 | 0.991432 | 1.000000 | 0.979135 | 0.921126 | 0.921536 | 0.880267 | 0.931490 | 0.935086 | 0.927714 |
ma60 | 0.928235 | 0.939894 | 0.936988 | 0.940333 | 0.954659 | 0.271780 | -0.036615 | 0.910455 | 0.960779 | 0.979135 | 1.000000 | 0.913742 | 0.922310 | 0.918139 | 0.930058 | 0.935835 | 0.942810 |
kospi_ma5 | 0.883946 | 0.894583 | 0.891717 | 0.896813 | 0.917392 | 0.125034 | -0.031905 | 0.997411 | 0.921459 | 0.921126 | 0.913742 | 1.000000 | 0.990936 | 0.955392 | 0.885354 | 0.880910 | 0.866878 |
kospi_ma20 | 0.877384 | 0.879538 | 0.876309 | 0.881334 | 0.900758 | 0.120924 | -0.041020 | 0.985278 | 0.907117 | 0.921536 | 0.922310 | 0.990936 | 1.000000 | 0.976539 | 0.879932 | 0.884127 | 0.875298 |
kospi_ma60 | 0.852673 | 0.829076 | 0.824230 | 0.830768 | 0.848280 | 0.072458 | -0.046557 | 0.948553 | 0.855765 | 0.880267 | 0.918139 | 0.955392 | 0.976539 | 1.000000 | 0.855724 | 0.865486 | 0.877998 |
nasdaq_ma5 | 0.998609 | 0.920682 | 0.918345 | 0.921693 | 0.930282 | 0.252054 | -0.008287 | 0.884667 | 0.931799 | 0.931490 | 0.930058 | 0.885354 | 0.879932 | 0.855724 | 1.000000 | 0.995843 | 0.983974 |
nasdaq_ma20 | 0.993128 | 0.917328 | 0.915196 | 0.917962 | 0.926925 | 0.262411 | -0.015503 | 0.878097 | 0.929856 | 0.935086 | 0.935835 | 0.880910 | 0.884127 | 0.865486 | 0.995843 | 1.000000 | 0.991960 |
nasdaq_ma60 | 0.981108 | 0.904685 | 0.901917 | 0.905152 | 0.914300 | 0.247607 | -0.018017 | 0.863776 | 0.917830 | 0.927714 | 0.942810 | 0.866878 | 0.875298 | 0.877998 | 0.983974 | 0.991960 | 1.000000 |
import seaborn as sns
plt.figure(figsize=(12, 8))
sns.heatmap(df.corr(), linewidths=1, annot=True)
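A 17 × 17 matrix is hard to scan, so one option is to pull out only the `Close` column and sort it. In the matrix above, `ma5` correlates with `Close` at about 0.997 while `Change` sits near 0.027. The sketch below shows the idea on made-up data; the frame `toy` and its columns are illustrative only.

```python
import numpy as np
import pandas as pd

# Made-up data: two series sharing a trend, plus one unrelated noise column.
rng = np.random.default_rng(0)
trend = np.linspace(0, 1, 200)
toy = pd.DataFrame({
    "Close": trend + rng.normal(0, 0.05, 200),
    "kospi": trend + rng.normal(0, 0.10, 200),
    "Change": rng.normal(0, 1, 200),
})

# Correlation of every other column with Close, strongest first.
corr_with_close = toy.corr()["Close"].drop("Close").sort_values(ascending=False)
print(corr_with_close)
```

Note that price levels which share a long-term trend tend to correlate strongly regardless of any day-to-day relationship, so high values here are not by themselves evidence of predictive power.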
from sklearn.preprocessing import MinMaxScaler
import numpy as np
data = df[['Close', 'kospi', 'nasdaq', 'ma5', 'ma20', 'ma60', 'kospi_ma5', 'kospi_ma20', 'kospi_ma60', 'nasdaq_ma5', 'nasdaq_ma20', 'nasdaq_ma60']]
scaler = MinMaxScaler(feature_range=(0, 1))
scaled = scaler.fit_transform(data)
scaled
array([[0.19925303, 0.53301902, 0.07586379, ..., 0.05554597, 0.06195208, 0.02720884],
       [0.17684407, 0.51441627, 0.08608527, ..., 0.05624535, 0.06014227, 0.02727798],
       [0.210831  , 0.53036152, 0.08957657, ..., 0.0589194 , 0.05838623, 0.02729961],
       ...,
       [0.66386555, 0.68299986, 0.69385738, ..., 0.74810644, 0.77861905, 0.92152542],
       [0.63585434, 0.64471715, 0.73806628, ..., 0.73444425, 0.7769989 , 0.91645197],
       [0.643324  , 0.65985052, 0.76047287, ..., 0.7339801 , 0.77611835, 0.91221048]])
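Fitting the scaler on the full data set means the minimum and maximum of the later test period influence how the training data is scaled. A common alternative is to fit on the training rows only and reuse those statistics for the test rows. The sketch below illustrates this on a made-up array; `values` stands in for `data` and is not the real feature matrix.

```python
import numpy as np
from sklearn.preprocessing import MinMaxScaler

# Toy feature matrix standing in for `data` (values are made up).
values = np.arange(20, dtype="float64").reshape(10, 2)

# Chronological 80/20 split before any scaling.
split = int(len(values) * 0.8)
train_raw, test_raw = values[:split], values[split:]

scaler = MinMaxScaler(feature_range=(0, 1))
train_scaled = scaler.fit_transform(train_raw)  # min/max come from training rows only
test_scaled = scaler.transform(test_raw)        # reuse them for the test rows

print(train_scaled.max(), test_scaled.max())
```

With this setup, test values can fall outside the 0 to 1 range when prices move beyond anything seen in training, which is expected.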
test_idx = int(len(scaled) * 0.8)
train = scaled[:test_idx]
test = scaled[test_idx:]
x_train = []
y_train = []
x_test = []
y_test = []
pasts = 15
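# Build sliding windows: the previous `pasts` days of column 0 (scaled Close) predict the next day.
# Only column 0 is used here; the other scaled columns are not fed to the model.
for i in range(pasts, len(train)):
    x_train.append(train[i-pasts:i, 0])
    y_train.append(train[i, 0])
for i in range(pasts, len(test)):
    x_test.append(test[i-pasts:i, 0])
    y_test.append(test[i, 0])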
x_train = np.array(x_train)
x_train = np.reshape(x_train, (x_train.shape[0], x_train.shape[1], 1))
y_train = np.array(y_train)
x_test = np.array(x_test)
x_test = np.reshape(x_test, (x_test.shape[0], x_test.shape[1], 1))
y_test = np.array(y_test)
x_train.shape, y_train.shape, x_test.shape, y_test.shape
((729, 15, 1), (729,), (172, 15, 1), (172,))
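The shapes above end in 1 because the loops slice only column 0, so the KOSPI, NASDAQ, and moving-average columns never reach the model. If the intent is to use them as inputs, one possible approach is to keep every column in each window. The helper `make_windows` below is a hypothetical name introduced for this sketch, and the input is a random array with the same shape as the training split (744 rows, 12 features).

```python
import numpy as np

def make_windows(arr, pasts, target_col=0):
    """Build (samples, pasts, features) inputs and next-step targets from a 2-D array."""
    x, y = [], []
    for i in range(pasts, len(arr)):
        x.append(arr[i - pasts:i, :])  # all feature columns for the previous `pasts` rows
        y.append(arr[i, target_col])   # the value to predict at row i
    return np.array(x), np.array(y)

# Random stand-in for the scaled training array (744 rows, 12 features).
toy = np.random.default_rng(0).random((744, 12))
x_multi, y_multi = make_windows(toy, pasts=15)
print(x_multi.shape, y_multi.shape)  # (729, 15, 12) (729,)
```

Since the result is already three-dimensional, the `np.reshape` step is no longer needed, and an LSTM whose `input_shape` is taken from `x_train.shape[1]` and `x_train.shape[2]` would pick up the 12 features automatically.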
from sklearn.model_selection import train_test_split
train_x, val_x, train_y, val_y = train_test_split(x_train, y_train, test_size=0.1)
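`train_test_split` shuffles by default, so validation windows end up interleaved with, and overlapping, the training windows. For time series, a chronological split that holds out the most recent windows usually gives a more honest validation loss. A minimal sketch, using zero-filled stand-ins with the same shapes as `x_train` and `y_train` (the names `x_all`, `y_all`, `x_tr`, and `x_val` are illustrative):

```python
import numpy as np

# Stand-ins with the same shapes as x_train / y_train above.
x_all = np.zeros((729, 15, 1))
y_all = np.arange(729, dtype="float64")

# Hold out the last 10% of windows, keeping time order intact.
val_size = int(len(x_all) * 0.1)
x_tr, x_val = x_all[:-val_size], x_all[-val_size:]
y_tr, y_val = y_all[:-val_size], y_all[-val_size:]

print(x_tr.shape, x_val.shape)  # (657, 15, 1) (72, 15, 1)
```

Passing `shuffle=False` to `train_test_split` is another way to keep the order.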
Building the model
import keras
from keras.models import Sequential
from keras.layers import Dense
from keras.callbacks import EarlyStopping, ModelCheckpoint
from keras.layers import LSTM
model = Sequential()
model.add(LSTM(50, input_shape=(x_train.shape[1], x_train.shape[2]), activation='relu', return_sequences=True))
model.add(LSTM(10, input_shape=(x_train.shape[1], x_train.shape[2]), activation='relu', return_sequences=False))
model.add(Dense(1))
early_stop = keras.callbacks.EarlyStopping(monitor='val_loss', patience=8)
model.compile(loss='mean_squared_error', optimizer='adam')
history = model.fit(train_x, train_y, epochs=100, batch_size=32, validation_data=(val_x, val_y), callbacks=[early_stop])
Epoch 1/100 21/21 [==============================] - 4s 66ms/step - loss: 0.1326 - val_loss: 0.1627
Epoch 2/100 21/21 [==============================] - 1s 38ms/step - loss: 0.0520 - val_loss: 0.0168
Epoch 3/100 21/21 [==============================] - 1s 33ms/step - loss: 0.0085 - val_loss: 0.0062
Epoch 4/100 21/21 [==============================] - 1s 31ms/step - loss: 0.0042 - val_loss: 0.0041
Epoch 5/100 21/21 [==============================] - 1s 28ms/step - loss: 0.0030 - val_loss: 0.0035
Epoch 6/100 21/21 [==============================] - 1s 27ms/step - loss: 0.0026 - val_loss: 0.0031
Epoch 7/100 21/21 [==============================] - 1s 30ms/step - loss: 0.0021 - val_loss: 0.0027
Epoch 8/100 21/21 [==============================] - 1s 37ms/step - loss: 0.0021 - val_loss: 0.0026
Epoch 9/100 21/21 [==============================] - 1s 30ms/step - loss: 0.0021 - val_loss: 0.0027
Epoch 10/100 21/21 [==============================] - 1s 33ms/step - loss: 0.0021 - val_loss: 0.0024
Epoch 11/100 21/21 [==============================] - 1s 26ms/step - loss: 0.0020 - val_loss: 0.0024
Epoch 12/100 21/21 [==============================] - 1s 28ms/step - loss: 0.0018 - val_loss: 0.0024
Epoch 13/100 21/21 [==============================] - 1s 31ms/step - loss: 0.0018 - val_loss: 0.0023
Epoch 14/100 21/21 [==============================] - 1s 29ms/step - loss: 0.0018 - val_loss: 0.0022
Epoch 15/100 21/21 [==============================] - 1s 33ms/step - loss: 0.0018 - val_loss: 0.0024
Epoch 16/100 21/21 [==============================] - 1s 32ms/step - loss: 0.0019 - val_loss: 0.0023
Epoch 17/100 21/21 [==============================] - 1s 26ms/step - loss: 0.0017 - val_loss: 0.0021
Epoch 18/100 21/21 [==============================] - 1s 30ms/step - loss: 0.0017 - val_loss: 0.0020
Epoch 19/100 21/21 [==============================] - 1s 25ms/step - loss: 0.0016 - val_loss: 0.0020
Epoch 20/100 21/21 [==============================] - 1s 29ms/step - loss: 0.0016 - val_loss: 0.0019
Epoch 21/100 21/21 [==============================] - 1s 30ms/step - loss: 0.0015 - val_loss: 0.0019
Epoch 22/100 21/21 [==============================] - 1s 28ms/step - loss: 0.0015 - val_loss: 0.0018
Epoch 23/100 21/21 [==============================] - 1s 39ms/step - loss: 0.0015 - val_loss: 0.0018
Epoch 24/100 21/21 [==============================] - 1s 34ms/step - loss: 0.0016 - val_loss: 0.0029
Epoch 25/100 21/21 [==============================] - 1s 32ms/step - loss: 0.0018 - val_loss: 0.0020
Epoch 26/100 21/21 [==============================] - 1s 38ms/step - loss: 0.0015 - val_loss: 0.0018
Epoch 27/100 21/21 [==============================] - 1s 35ms/step - loss: 0.0015 - val_loss: 0.0017
Epoch 28/100 21/21 [==============================] - 1s 27ms/step - loss: 0.0013 - val_loss: 0.0016
Epoch 29/100 21/21 [==============================] - 1s 25ms/step - loss: 0.0013 - val_loss: 0.0015
Epoch 30/100 21/21 [==============================] - 1s 35ms/step - loss: 0.0013 - val_loss: 0.0016
Epoch 31/100 21/21 [==============================] - 1s 27ms/step - loss: 0.0012 - val_loss: 0.0013
Epoch 32/100 21/21 [==============================] - 1s 33ms/step - loss: 0.0012 - val_loss: 0.0015
Epoch 33/100 21/21 [==============================] - 1s 35ms/step - loss: 0.0013 - val_loss: 0.0014
Epoch 34/100 21/21 [==============================] - 1s 30ms/step - loss: 0.0012 - val_loss: 0.0013
Epoch 35/100 21/21 [==============================] - 1s 32ms/step - loss: 0.0012 - val_loss: 0.0012
Epoch 36/100 21/21 [==============================] - 1s 27ms/step - loss: 0.0012 - val_loss: 0.0013
Epoch 37/100 21/21 [==============================] - 1s 29ms/step - loss: 0.0012 - val_loss: 0.0012
Epoch 38/100 21/21 [==============================] - 1s 36ms/step - loss: 0.0011 - val_loss: 0.0011
Epoch 39/100 21/21 [==============================] - 1s 35ms/step - loss: 0.0012 - val_loss: 0.0014
Epoch 40/100 21/21 [==============================] - 1s 28ms/step - loss: 0.0011 - val_loss: 0.0012
Epoch 41/100 21/21 [==============================] - 1s 29ms/step - loss: 0.0010 - val_loss: 0.0011
Epoch 42/100 21/21 [==============================] - 1s 30ms/step - loss: 0.0011 - val_loss: 0.0011
Epoch 43/100 21/21 [==============================] - 1s 29ms/step - loss: 9.7280e-04 - val_loss: 0.0015
Epoch 44/100 21/21 [==============================] - 1s 28ms/step - loss: 0.0011 - val_loss: 0.0013
Epoch 45/100 21/21 [==============================] - 1s 26ms/step - loss: 9.6635e-04 - val_loss: 0.0010
Epoch 46/100 21/21 [==============================] - 1s 24ms/step - loss: 9.9354e-04 - val_loss: 0.0011
Epoch 47/100 21/21 [==============================] - 1s 29ms/step - loss: 9.8607e-04 - val_loss: 0.0017
Epoch 48/100 21/21 [==============================] - 1s 24ms/step - loss: 0.0012 - val_loss: 0.0011
Epoch 49/100 21/21 [==============================] - 1s 25ms/step - loss: 0.0010 - val_loss: 9.8806e-04
Epoch 50/100 21/21 [==============================] - 1s 24ms/step - loss: 0.0010 - val_loss: 0.0010
Epoch 51/100 21/21 [==============================] - 1s 31ms/step - loss: 0.0011 - val_loss: 0.0010
Epoch 52/100 21/21 [==============================] - 1s 25ms/step - loss: 0.0011 - val_loss: 0.0010
Epoch 53/100 21/21 [==============================] - 1s 29ms/step - loss: 9.5178e-04 - val_loss: 0.0010
Epoch 54/100 21/21 [==============================] - 1s 24ms/step - loss: 9.1897e-04 - val_loss: 0.0011
Epoch 55/100 21/21 [==============================] - 0s 16ms/step - loss: 9.0765e-04 - val_loss: 9.5983e-04
Epoch 56/100 21/21 [==============================] - 0s 16ms/step - loss: 9.4123e-04 - val_loss: 9.6640e-04
Epoch 57/100 21/21 [==============================] - 0s 16ms/step - loss: 8.8457e-04 - val_loss: 9.2010e-04
Epoch 58/100 21/21 [==============================] - 0s 17ms/step - loss: 9.0110e-04 - val_loss: 0.0011
Epoch 59/100 21/21 [==============================] - 0s 16ms/step - loss: 9.0515e-04 - val_loss: 0.0015
Epoch 60/100 21/21 [==============================] - 0s 17ms/step - loss: 0.0011 - val_loss: 0.0011
Epoch 61/100 21/21 [==============================] - 0s 17ms/step - loss: 9.8273e-04 - val_loss: 0.0013
Epoch 62/100 21/21 [==============================] - 0s 16ms/step - loss: 9.4312e-04 - val_loss: 9.7527e-04
Epoch 63/100 21/21 [==============================] - 0s 17ms/step - loss: 9.4060e-04 - val_loss: 0.0012
Epoch 64/100 21/21 [==============================] - 0s 16ms/step - loss: 9.4910e-04 - val_loss: 0.0010
Epoch 65/100 21/21 [==============================] - 0s 17ms/step - loss: 8.6672e-04 - val_loss: 0.0010
plt.plot(history.history['loss'])
plt.show()
Model prediction
preds = model.predict(x_test)
plt.plot(preds, label='pred')
plt.plot(y_test, label='label')
plt.legend()
plt.show()
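The plot compares values on the 0 to 1 scale produced by `MinMaxScaler`. To read the predictions as prices in won, the scaling of the `Close` column can be undone with the scaler's `data_min_` and `data_max_` attributes. The sketch below is one way to do this under the assumption that `Close` is column 0 of the fitted scaler, as it is in `data`. The helper `to_price` is a hypothetical name and all numbers are made up.

```python
import numpy as np
from sklearn.preprocessing import MinMaxScaler

# Toy two-column data; column 0 plays the role of `Close` (values are made up).
raw = np.array([[50000.0, 2400.0], [60000.0, 2600.0], [90000.0, 3300.0]])
scaler = MinMaxScaler(feature_range=(0, 1)).fit(raw)

def to_price(scaled_close, scaler, col=0):
    """Undo min-max scaling for one column of a scaler fitted on several columns."""
    lo, hi = scaler.data_min_[col], scaler.data_max_[col]
    return np.asarray(scaled_close).ravel() * (hi - lo) + lo

scaled_preds = np.array([[0.0], [0.25], [1.0]])  # shaped like model.predict output
scaled_true = np.array([0.0, 0.25, 0.75])

pred_price = to_price(scaled_preds, scaler)
true_price = to_price(scaled_true, scaler)

# Root mean squared error in the original price units.
rmse = np.sqrt(np.mean((pred_price - true_price) ** 2))
print(pred_price, rmse)
```

Calling `scaler.inverse_transform` directly on the predictions would fail here, because the scaler was fitted on 12 columns while the model outputs only one.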