"""Simulate a SARIMA series and forecast it with TimeGPT.

The script:

1. Simulates 200 daily observations from a SARIMA(1, 1, 1)x(1, 1, 1, 12)
   process with fixed parameters.
2. Forecasts everything from ``test_start_date`` onwards with TimeGPT, once
   zero-shot and once with 100 fine-tuning steps.
3. Plots actuals against both forecasts and writes the combined table to
   ``data/sarima_forecasts.csv``.

Requires a ``NIXTLA_API_KEY`` environment variable (a ``.env`` file is loaded
via python-dotenv).
"""

# Import useful libraries
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Import stuff to setup TimeGPT
from dotenv import load_dotenv
from nixtla import NixtlaClient
from statsmodels.tsa.statespace.sarimax import SARIMAX

# Setup TimeGPT
load_dotenv()
NIXTLA_API_KEY = os.getenv("NIXTLA_API_KEY")
nixtla_client = NixtlaClient(
    api_key=NIXTLA_API_KEY,
)
# The validation outcome is reported through the return value; fail fast here
# instead of letting a bad key surface later as an opaque forecast error.
if not nixtla_client.validate_api_key():
    raise RuntimeError(
        "NIXTLA_API_KEY is missing or invalid; set it in the environment "
        "or in a .env file."
    )

# Define the SARIMA orders
order = (1, 1, 1)  # (p, d, q)
seasonal_order = (1, 1, 1, 12)  # (P, D, Q, s)

# Generate a SARIMA model. The single dummy observation only gives the model
# an endog to attach to; the series itself comes from simulate() below.
np.random.seed(42)  # global seed; the simulation itself uses random_state below
n_samples = 200
sarima_model = SARIMAX([0], order=order, seasonal_order=seasonal_order)

# Define the SARIMA parameters explicitly, in the order the model expects
# (see ``sarima_model.param_names``). Differencing orders (d, D) are part of
# the model structure and have no coefficients, so there are five values.
# NOTE(review): the previous seven-entry list also carried "I" / "seasonal I"
# values and a trailing 0.1 error term. Parameters are read positionally, so
# those extras were apparently never used and the effective innovation
# variance was 0.5. That effective value is kept here so the simulated data
# stays the same -- change the last entry if a different variance was meant.
params = [
    0.5,  # non-seasonal AR coefficient
    0.5,  # non-seasonal MA coefficient
    0.5,  # seasonal AR coefficient
    0.5,  # seasonal MA coefficient
    0.5,  # innovation variance (sigma2 -- a variance, not a std. deviation)
]
if len(params) != len(sarima_model.param_names):
    raise ValueError(
        f"Expected {len(sarima_model.param_names)} parameters "
        f"{sarima_model.param_names}, got {len(params)}."
    )

# Simulate the SARIMA time series
simulated_data = sarima_model.simulate(
    params=params, nsimulations=n_samples, random_state=91
)

# Create a tidy DataFrame (columns "date" and "Actual") with the simulated data
date_range = pd.date_range(start="2000-01-01", periods=n_samples, freq="D")
sarima_data = (
    pd.Series(simulated_data, index=date_range, name="Actual")
    .rename_axis("date")
    .reset_index()
)

# Everything strictly before this date is training data; the rest is forecast
test_start_date = pd.Timestamp("2000-05-15")
train_sarima_data = sarima_data.loc[sarima_data["date"] < test_start_date]
horizon = len(sarima_data) - len(train_sarima_data)

# Arguments shared by the zero-shot and the fine-tuned forecast
forecast_kwargs = {
    "df": train_sarima_data,
    "h": horizon,
    "freq": "D",
    "time_col": "date",
    "target_col": "Actual",
}

# Forecast with the TimeGPT model
timegpt_forecast = nixtla_client.forecast(**forecast_kwargs)
timegpt_forecast["date"] = pd.to_datetime(timegpt_forecast["date"])
sarima_data = sarima_data.merge(timegpt_forecast, on="date", how="left")

# And with a fine-tuned TimeGPT model
tuned_forecast = nixtla_client.forecast(
    finetune_steps=100,
    **forecast_kwargs,
).rename(columns={"TimeGPT": "TimeGPT (Fine-tuned)"})
tuned_forecast["date"] = pd.to_datetime(tuned_forecast["date"])
sarima_data = sarima_data.merge(tuned_forecast, on="date", how="left")

# Plot the forecasted SARIMA time series
fig, ax = plt.subplots(figsize=(10, 5))
sarima_data.plot(
    ax=ax,
    x="date",
    y=["Actual", "TimeGPT", "TimeGPT (Fine-tuned)"],
    title="SARIMA Forecast",
    xlabel="Date",
    ylabel="Value",
)
ax.axvline(test_start_date, color="red", linestyle="--", label="Test Start Date")
ax.legend()
plt.show()

# Make sure the output directory exists so the forecasts are not lost to a
# missing-directory error after the API calls have already been made.
output_path = "data/sarima_forecasts.csv"
os.makedirs(os.path.dirname(output_path), exist_ok=True)
sarima_data.to_csv(output_path, index=False)