Predicting Stocks Using LSTM
The term time series refers to a collection of data points ordered chronologically.
In this notebook, we will explore data from the stock market, focusing on a few technology stocks (Apple, Amazon, Google, and Microsoft). We will use yfinance to download the stock data and Seaborn and Matplotlib to visualize various aspects of it. We will also look at a few ways to assess a stock’s risk based on its past performance. Finally, we will use a Long Short-Term Memory (LSTM) network to predict future stock prices!
Along the way, we’ll answer the following questions:
1.) How did the stock’s price change over time?
2.) What was the average daily return of the stock?
3.) What was the moving average of each stock?
4.) What was the correlation between stocks?
5.) What is the risk of investing in a particular stock?
6.) How can we predict future stock performance? (LSTM-based prediction of Apple Inc.’s closing stock price)
Getting the Data
The first step is to get the data and load it into memory. We will get our stock data from Yahoo Finance, a rich resource of financial market data and tools for finding compelling investments. To download the data, we will use the yfinance library, which offers a threaded and Pythonic way to download market data from Yahoo.
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_style('whitegrid')
plt.style.use("fivethirtyeight")
%matplotlib inline
# For reading stock data from yahoo
from pandas_datareader.data import DataReader
import yfinance as yf
from pandas_datareader import data as pdr
yf.pdr_override()
# For time stamps
from datetime import datetime
# The tech stocks we'll use for this analysis
tech_list = ['AAPL', 'GOOG', 'MSFT', 'AMZN']
# Set up End and Start times for data grab (one year back from today).
# pd.DateOffset handles February 29, where datetime(end.year - 1, ...) would fail.
end = datetime.now()
start = (end - pd.DateOffset(years=1)).to_pydatetime()
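for stock in tech_list:
    # auto_adjust=False keeps the 'Adj Close' column used in the plots below
    # (some yfinance versions adjust prices by default and drop that column)
    globals()[stock] = yf.download(stock, start, end, auto_adjust=False)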
company_list = [AAPL, GOOG, MSFT, AMZN]
company_name = ["APPLE", "GOOGLE", "MICROSOFT", "AMAZON"]
for company, com_name in zip(company_list, company_name):
    company["company_name"] = com_name
df = pd.concat(company_list, axis=0)
df.tail(10)
In this code, several Python packages are imported, variables are defined, and Yahoo Finance stock data is downloaded. Each part of the code is explained briefly below:
The first few lines import several Python libraries, including pandas, numpy, matplotlib, and seaborn, under their conventional aliases (pd, np, plt, and sns). The sns.set_style('whitegrid') and plt.style.use("fivethirtyeight") calls choose the visual style of the plots.
The %matplotlib inline magic command displays Matplotlib plots inline, directly below the notebook cell that produces them.
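The pandas_datareader imports (DataReader and data as pdr) and the import yfinance as yf line bring in the tools used to download financial data from Yahoo Finance.
The yf.pdr_override() call replaces pandas_datareader’s default Yahoo data reader with the one provided by yfinance. Because the download loop calls yf.download directly, neither the override nor the DataReader import is strictly required.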
Datetime is imported from the Python standard library using the from datetime import datetime line.
Data will be downloaded for the stock symbols listed in the tech_list variable.
Stock data will be downloaded for the time range defined by the start and end variables. This is one year’s worth of data up to the present.
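Rebuilding the date as datetime(end.year - 1, end.month, end.day) works on almost every day of the year, but when end falls on February 29 it raises a ValueError, because the previous year has no such date. A minimal sketch of a leap-day-safe alternative using pd.DateOffset (one_year_back is a hypothetical helper, not part of the notebook):

```python
from datetime import datetime

import pandas as pd


def one_year_back(end: datetime) -> datetime:
    """Return the datetime one calendar year before `end` (hypothetical helper).

    pd.DateOffset rolls February 29 back to February 28 instead of raising.
    """
    return (pd.Timestamp(end) - pd.DateOffset(years=1)).to_pydatetime()


leap_day = datetime(2024, 2, 29)

# Rebuilding the date field by field fails: 2023 has no February 29.
try:
    datetime(leap_day.year - 1, leap_day.month, leap_day.day)
except ValueError as err:
    print("datetime(...) failed:", err)

print("DateOffset result:", one_year_back(leap_day))
```

On ordinary dates both approaches give the same result, including the time of day, so the offset version only changes behavior on leap days.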
For each stock in tech_list, data is downloaded using the yf.download function from the yfinance library and assigned to a variable with the same name as the stock symbol.
The company_list variable holds the downloaded DataFrames, and company_name holds the corresponding company names.
The second for loop adds a new column called “company_name” to each DataFrame in company_list, so every row can still be traced to its company once the DataFrames are combined.
pd.concat then combines all the DataFrames in company_list into a single DataFrame called df. The axis=0 parameter stacks them vertically (row by row), and .tail(10) displays the last 10 rows of the result.
Looking at the data, we can see that the price and volume columns are numeric and that the date is the index. Also note that weekends (and market holidays) are missing from the records, because no trading takes place on those days.
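A quick way to confirm that an index contains no weekend dates is to look at its dayofweek attribute. A minimal sketch, using pd.bdate_range as a hypothetical stand-in for the downloaded index (real Yahoo Finance data also skips market holidays, which bdate_range does not):

```python
import pandas as pd

# Hypothetical stand-in for a downloaded price index: every weekday in 2023.
weekdays_2023 = pd.bdate_range("2023-01-01", "2023-12-31")

# dayofweek is 0 for Monday through 6 for Sunday, so 5 and 6 are weekend days.
has_weekend = (weekdays_2023.dayofweek >= 5).any()

print(len(weekdays_2023), "weekdays; weekends present:", has_weekend)
```

With the notebook’s data, the same check would be (AAPL.index.dayofweek >= 5).any(), assuming the date index returned by yfinance.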
Setting the names of the DataFrames with globals() is a little sloppy, but it’s easy. Now that we have our data, let’s check and analyze it.
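A common alternative to globals() is to keep the DataFrames in a dictionary keyed by ticker. The sketch below assumes the same tickers and company names as the notebook and replaces yf.download with randomly generated prices (a hypothetical stand-in) so that it runs offline:

```python
import numpy as np
import pandas as pd

# Same tickers and display names as the notebook.
tickers = {"AAPL": "APPLE", "GOOG": "GOOGLE", "MSFT": "MICROSOFT", "AMZN": "AMAZON"}
dates = pd.bdate_range("2023-01-02", periods=5)  # five business days
rng = np.random.default_rng(0)

# Store each DataFrame in a dict keyed by ticker instead of creating
# global variables with globals().
stocks = {}
for ticker, name in tickers.items():
    # Hypothetical stand-in for yf.download(ticker, start, end): random-walk prices.
    frame = pd.DataFrame(
        {"Adj Close": 100 + rng.normal(0, 1, len(dates)).cumsum()},
        index=dates,
    )
    frame["company_name"] = name
    stocks[ticker] = frame

# Stack the four DataFrames vertically, as in the notebook.
df = pd.concat(list(stocks.values()), axis=0)

print(df.tail(10))
print(stocks["AAPL"].describe())
```

Individual stocks are then reached as stocks["AAPL"] rather than through a bare AAPL variable, and looping over stocks.items() yields both the ticker and its data.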
Descriptive Statistics about the Data
The .describe() method generates descriptive statistics that summarize the central tendency, dispersion, and shape of a dataset’s distribution, excluding NaN values.
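It can analyze numeric and object Series as well as DataFrames whose columns have mixed data types, and the output varies with the input: numeric columns are summarized by count, mean, standard deviation, minimum, quartiles, and maximum, while object (text) columns are summarized by count, number of unique values, most frequent value (top), and its frequency (freq). For a DataFrame with mixed types, only the numeric columns are included by default.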
# Summary Stats
AAPL.describe()
In this example, we call the describe() method on the AAPL DataFrame. It generates summary statistics for each numeric column, such as the count, mean, standard deviation, minimum, quartiles, and maximum.
Calling this method on AAPL outputs a table of summary statistics for the numeric columns of the AAPL DataFrame, which gives a quick overview of the data and helps identify potential issues or outliers. We have only 255 records for one year because weekends and market holidays are not included in the data.
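To see how the describe() output depends on the column types, here is a small sketch with invented values: a DataFrame shaped like the notebook’s data, with numeric price and volume columns plus the text company_name column.

```python
import pandas as pd

# Made-up values; the column names mirror the notebook's data.
prices = pd.DataFrame(
    {
        "Close": [10.0, 11.0, 12.0, 13.0],
        "Volume": [100, 200, 300, 400],
        "company_name": ["APPLE"] * 4,
    }
)

# Default behavior: with mixed dtypes, only the numeric columns are summarized.
numeric_summary = prices.describe()

# include="all" adds the text column, described by count, unique, top, and freq.
full_summary = prices.describe(include="all")

print(numeric_summary)
print(full_summary)
```

In the include="all" table, statistics that do not apply to a column (such as the mean of company_name) are shown as NaN.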
Information About the Data
The .info() method prints information about a DataFrame, including the index dtype and columns, non-null values, and memory usage.
# General info
AAPL.info()
The code above calls the info() method on the AAPL DataFrame. It prints general information about the DataFrame: the index type and range, the number of columns, each column’s name, non-null count, and data type, and the memory usage.
This summary of the DataFrame’s structure is useful for checking data types, spotting missing values (a column whose non-null count is lower than the number of entries has gaps), and identifying other data quality issues. The memory usage figure can also help when optimizing the DataFrame’s memory footprint.
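info() prints its report rather than returning it, which makes it awkward to inspect in code. A minimal sketch, with a small invented DataFrame, of capturing the report through the buf parameter and pairing it with isna().sum(), which counts missing values per column directly:

```python
import io

import numpy as np
import pandas as pd

# Invented three-day DataFrame with one missing closing price.
prices = pd.DataFrame(
    {"Close": [10.0, np.nan, 12.0], "company_name": ["APPLE", "APPLE", "APPLE"]},
    index=pd.bdate_range("2023-01-02", periods=3),
)

# info() writes to stdout by default; buf= redirects the report into a string.
buffer = io.StringIO()
prices.info(buf=buffer)
report = buffer.getvalue()
print(report)

# info() shows non-null counts; isna().sum() gives the missing count per column.
missing = prices.isna().sum()
print(missing)
```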
Closing Price
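A stock’s closing price is the last price at which it traded during a regular trading day. Investors use the closing price to track a stock’s performance over time. The plots below use the adjusted closing price (“Adj Close”), which also accounts for corporate actions such as stock splits and dividends.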
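One simple way to track performance from closing prices, and to compare stocks that trade at very different price levels, is to divide each series by its first value so that it starts at 1. A minimal sketch with invented prices; with the notebook’s data the same idea would apply to a column such as AAPL['Adj Close'], assuming a single-level column layout:

```python
import pandas as pd

# Made-up adjusted closing prices for five trading days.
adj_close = pd.Series(
    [100.0, 102.0, 101.0, 105.0, 110.0],
    index=pd.bdate_range("2023-01-02", periods=5),
    name="Adj Close",
)

# Divide every price by the first one: 1.10 means the stock is up 10%
# since the start of the period.
growth = adj_close / adj_close.iloc[0]
total_return_pct = (growth.iloc[-1] - 1) * 100

print(growth)
print(f"Total return over the period: {total_return_pct:.1f}%")
```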
# Let's see a historical view of the closing price
plt.figure(figsize=(15, 10))
plt.subplots_adjust(top=1.25, bottom=1.2)
for i, company in enumerate(company_list, 1):
    ax = plt.subplot(2, 2, i)
    # ax= draws on this subplot even if 'Adj Close' comes back as a one-column DataFrame
    company['Adj Close'].plot(ax=ax)
    plt.ylabel('Adj Close')
    plt.xlabel(None)
    plt.title(f"Closing Price of {tech_list[i - 1]}")

plt.tight_layout()
This code displays the historical adjusted closing price of each company in company_list in a figure with a 2x2 grid of subplots.
plt.figure() creates a new figure of 15 by 10 inches. plt.subplots_adjust() sets the subplot margins, although the plt.tight_layout() call at the end recalculates the spacing anyway.
For each company in company_list, the for loop calls plt.subplot() to add a new subplot to the figure. Its third argument, i, sets the subplot’s position in the grid; enumerate(company_list, 1) starts i at 1 because Matplotlib numbers subplots from 1. The company[‘Adj Close’].plot() call then draws the company’s adjusted closing price in that subplot.
plt.ylabel() sets the y-axis label to “Adj Close”, plt.xlabel(None) removes the x-axis label, and plt.title() sets the subplot title to “Closing Price of [company symbol]”.
Finally, plt.tight_layout() adjusts the spacing between the subplots so that their titles and labels do not overlap, which improves readability.
The result is a clear, side-by-side view that makes it easy to compare the historical closing prices of the four companies.
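The same 2x2 grid can also be built with Matplotlib’s object-oriented interface, where plt.subplots creates all four Axes at once and each plot is drawn on an explicit Axes object. A minimal sketch, using random-walk prices as hypothetical stand-ins for the downloaded data so that it runs offline:

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

tickers = ["AAPL", "GOOG", "MSFT", "AMZN"]
dates = pd.bdate_range("2023-01-02", periods=60)
rng = np.random.default_rng(42)

# Hypothetical random-walk prices standing in for each company's 'Adj Close'.
closes = {
    t: pd.Series(100 + rng.normal(0, 1, len(dates)).cumsum(), index=dates)
    for t in tickers
}

# plt.subplots returns the figure and a 2x2 array of Axes in one call.
fig, axes = plt.subplots(2, 2, figsize=(15, 10))

for ax, ticker in zip(axes.flat, tickers):
    closes[ticker].plot(ax=ax)  # draw on this specific Axes
    ax.set_ylabel("Adj Close")
    ax.set_xlabel("")
    ax.set_title(f"Closing Price of {ticker}")

fig.tight_layout()
```

Passing ax= explicitly avoids relying on whichever subplot is currently active, which is what the plt.subplot() version depends on.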