Source code for optrade.data.forecasting

from copy import deepcopy
from pathlib import Path
from typing import Tuple, List, Union, Optional
from rich.console import Console
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler
import torch
from torch.utils.data import DataLoader, Dataset, ConcatDataset
from datetime import datetime, timedelta
import pandas_market_calendars as mcal

# Datasets
from optrade.data.thetadata import find_optimal_strike, get_expirations
from optrade.data.contracts import Contract, ContractDataset
from optrade.utils.error_handlers import (
    DataValidationError,
    INCOMPATIBLE_START_DATE,
    INCOMPATIBLE_END_DATE,
    NAN_FEATURES,
)

# Features
from optrade.data.features import transform_features
from optrade.utils.misc import datetime_to_tensor, tensor_to_datetime

# Get absolute path for this script
SCRIPT_DIR = Path(__file__).resolve().parent


class ForecastingDataset(Dataset):
    """Sliding-window forecasting dataset built from one time-series DataFrame.

    Each sample pairs a lookback window of ``seq_len`` rows with a forecast
    window of the following ``pred_len`` rows. Tensors are returned
    channel-first, i.e. ``(num_features, window_length)``.

    Attributes set by ``__init__``:
        has_datetime: True when ``data`` has a ``"datetime"`` column.
        feature_names: Column names of the numeric features (datetime excluded).
        datetime: NumPy array of the datetime column (only if ``has_datetime``).
        data: Feature tensor of shape ``(num_rows, num_features)``.
        target_data: Independent clone of ``data`` that targets are read from,
            so inputs can be normalized without touching the targets.
        target_channels_idx: Column indices of the target channels. Only set
            when ``target_channels`` is a non-empty list.
    """

    def __init__(
        self,
        data: pd.DataFrame,
        seq_len: int,
        pred_len: int,
        target_channels: Optional[List[str]] = None,
        target_type: str = "multistep",
        dtype: str = "float32",
        normalize_target: bool = False,
    ) -> None:
        """Initializes the ForecastingDataset class.

        Args:
            data (pd.DataFrame): Input DataFrame containing the time series
                data. An optional ``"datetime"`` column is kept aside and not
                used as a feature.
            seq_len (int): Length of the lookback window for each sample.
            pred_len (int): Length of the forecast window (number of steps
                ahead to predict).
            target_channels (Optional[List[str]]): Column names to include as
                target channels. If None or empty, all columns are used.
            target_type (str): Type of target to predict. Must be one of:
                - "multistep": the full future sequence (regression).
                - "average": the average value over the forecast window.
                - "average_direction": 1.0 if the window average is > 0,
                  else 0.0 (binary classification).
            dtype (str): Name of the dtype used for the tensors and NumPy
                arrays (e.g. "float32", "float64"). Must exist as an attribute
                of both ``torch`` and ``numpy``.
            normalize_target (bool): Flag stored on the instance; it is read
                by ``normalize_concat_dataset`` to decide whether the targets
                are scaled as well.

        Raises:
            DataValidationError: If any feature column contains NaNs.
            AttributeError: If ``dtype`` is not a valid torch/numpy dtype name.
            ValueError: If a name in ``target_channels`` is not a feature.
        """
        self.has_datetime = "datetime" in data.columns

        # The datetime column is stored separately and never used as a feature.
        features = data.drop(columns=["datetime"]) if self.has_datetime else data
        self.feature_names = features.columns.to_list()
        if self.has_datetime:
            self.datetime = data["datetime"].values  # Store as numpy array

        # Check for NaNs on the DataFrame so offending columns can be named.
        nan_counts = features.isna().sum()
        if nan_counts.any():
            bad_feats = nan_counts[nan_counts > 0]
            message = (
                f"[ForecastingDataset] Found NaNs in {len(bad_feats)} feature(s):\n"
                f"{bad_feats.to_string()}\n"
                "This is often caused by illiquid instruments (e.g., no volume or quote data) "
                "or unstable features like LOB imbalance or spread when interval_min is too low.\n"
                "Consider dropping these features or resampling to a longer interval by increasing interval_min parmaeter."
            )
            raise DataValidationError(
                message=message, error_code=NAN_FEATURES, verbose=True, warning=False
            )

        # Resolve dtype names with getattr rather than eval: same result for
        # valid names, but no arbitrary code execution on a caller string.
        self.torch_dtype = getattr(torch, dtype)
        self.np_dtype = getattr(np, dtype)

        self.data = torch.tensor(features.to_numpy(), dtype=self.torch_dtype)
        self.seq_len = seq_len
        self.pred_len = pred_len
        self.target_type = target_type
        self.normalize_target = normalize_target

        # Targets are read from a separate clone so that replacing ``data``
        # (e.g. with a normalized version) leaves the targets untouched.
        self.target_data = self.data.clone()
        if target_channels is not None and len(target_channels) > 0:
            self.target_channels_idx = [
                self.feature_names.index(channel) for channel in target_channels
            ]

    def __len__(self) -> int:
        """Returns the number of input-target pairs in the dataset.

        A series of N rows holds ``N - seq_len - pred_len + 1`` complete
        windows (the last one ends exactly on the final row); 0 is returned
        when the series is too short to hold a single window.
        """
        return max(0, self.data.shape[0] - self.seq_len - self.pred_len + 1)

    def __getitem__(self, idx: int) -> Union[
        Tuple[torch.Tensor, torch.Tensor],
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
    ]:
        """Get a sample from the dataset.

        Args:
            idx: Index of the starting row of the lookback window. Negative
                indices count from the end.

        Returns:
            If datetime is available, a tuple
            ``(input_tensor, target_tensor, input_datetime, target_datetime)``;
            otherwise ``(input_tensor, target_tensor)``.
                - input_tensor: Lookback window of shape (num_features, seq_len).
                - target_tensor: Shape depends on ``target_type``:
                    - "multistep": (num_target_features, pred_len)
                    - "average": (num_target_features, 1)
                    - "average_direction": (num_target_features, 1)
                - input_datetime / target_datetime: Outputs of
                  ``datetime_to_tensor`` for the two windows.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
            NotImplementedError: If ``target_type`` is "triple_barrier".
            ValueError: If ``target_type`` is not recognized.
        """
        num_samples = len(self)
        if idx < 0:
            idx += num_samples
        # Raising IndexError (instead of silently returning truncated slices)
        # also lets plain ``for sample in dataset`` iteration terminate.
        if not 0 <= idx < num_samples:
            raise IndexError(
                f"Index {idx} is out of range for dataset of length {num_samples}."
            )

        split = idx + self.seq_len
        end = split + self.pred_len

        lookback = self.data[idx:split]
        if hasattr(self, "target_channels_idx"):
            horizon = self.target_data[split:end, self.target_channels_idx]
        else:
            horizon = self.target_data[split:end]

        # (time, channels) -> (channels, time)
        input_tensor = lookback.transpose(0, 1)
        target_tensor = horizon.transpose(0, 1)

        if self.target_type == "average":
            # keepdim keeps the documented (num_target_features, 1) layout.
            target_tensor = target_tensor.mean(dim=1, keepdim=True)
        elif self.target_type == "average_direction":
            target_tensor = (target_tensor.mean(dim=1, keepdim=True) > 0).float()
        elif self.target_type == "triple_barrier":
            raise NotImplementedError
        elif self.target_type == "multistep":
            pass
        else:
            raise ValueError(
                "Invalid target_type. Options: 'multistep', 'average', or 'average_direction'."
            )

        if not self.has_datetime:
            return input_tensor, target_tensor

        # Convert datetime arrays to tensors
        input_datetime_tensor = datetime_to_tensor(self.datetime[idx:split])
        target_datetime_tensor = datetime_to_tensor(self.datetime[split:end])
        return (
            input_tensor,
            target_tensor,
            input_datetime_tensor,
            target_datetime_tensor,
        )

    def to_numpy(self) -> Union[
        Tuple[np.ndarray, np.ndarray],
        Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray],
    ]:
        """Converts the dataset into NumPy arrays, e.g. for scikit-learn models.

        Returns:
            A tuple containing:
                - inputs: Array of shape (num_samples, seq_len, num_features).
                - targets: Array of shape (num_samples, T, num_target_features)
                  where T is ``pred_len`` for "multistep" and 1 for "average"
                  and "average_direction".
            If datetime is available, additionally:
                - input_datetimes: Array of shape (num_samples, seq_len).
                - target_datetimes: Array of shape (num_samples, pred_len).

        Raises:
            NotImplementedError: If ``target_type`` is "triple_barrier".
            ValueError: If ``target_type`` is not recognized.
        """
        num_samples = len(self)
        num_features = self.data.shape[1]
        num_target_features = (
            len(self.target_channels_idx)
            if hasattr(self, "target_channels_idx")
            else num_features
        )

        if self.target_type == "multistep":
            target_dtype = self.np_dtype
            target_shape = (num_samples, self.pred_len, num_target_features)
        elif self.target_type == "average":
            target_dtype = self.np_dtype
            target_shape = (num_samples, 1, num_target_features)
        elif self.target_type == "average_direction":
            # __getitem__ casts the direction labels with .float() (float32).
            target_dtype = np.float32
            target_shape = (num_samples, 1, num_target_features)
        elif self.target_type == "triple_barrier":
            raise NotImplementedError
        else:
            raise ValueError(
                "Invalid target_type. Options: 'multistep', 'average', or 'average_direction'."
            )

        inputs = np.empty((num_samples, self.seq_len, num_features), dtype=self.np_dtype)
        targets = np.empty(target_shape, dtype=target_dtype)

        if self.has_datetime:
            # Specify the unit here to match what tensor_to_datetime uses
            unit = "s"  # seconds
            input_datetimes = np.empty(
                (num_samples, self.seq_len), dtype=f"datetime64[{unit}]"
            )
            target_datetimes = np.empty(
                (num_samples, self.pred_len), dtype=f"datetime64[{unit}]"
            )

        for i in range(num_samples):
            item = self[i]
            # Samples are channel-first; the arrays are time-first.
            inputs[i] = item[0].numpy().T
            targets[i] = item[1].numpy().T

            if self.has_datetime:
                input_datetimes[i] = tensor_to_datetime(item[2], batch_mode=False)
                target_datetimes[i] = tensor_to_datetime(item[3], batch_mode=False)

        if self.has_datetime:
            return inputs, targets, input_datetimes, target_datetimes
        return inputs, targets

    def get_item(self, idx: int) -> Union[
        Tuple[torch.Tensor, torch.Tensor],
        Tuple[torch.Tensor, torch.Tensor, np.ndarray, np.ndarray],
    ]:
        """Get a sample from the dataset.

        Named alias of ``dataset[idx]``; it returns exactly what
        ``__getitem__`` returns for the same index.

        Args:
            idx: Index of the starting row of the lookback window.

        Returns:
            ``(input_tensor, target_tensor)`` or, when a datetime column is
            available, ``(input_tensor, target_tensor, input_datetime,
            target_datetime)``.
        """
        return self.__getitem__(idx)
def normalize_concat_dataset(
    concat_dataset: ConcatDataset,
    scaler: StandardScaler,
) -> None:
    """
    Scales every dataset inside a ConcatDataset in place with a fitted scaler.

    For each member of ``concat_dataset.datasets`` the ``data`` tensor is
    replaced by its scaled version. The ``target_data`` tensor is scaled as
    well, but only for members whose ``normalize_target`` flag is set.

    Args:
        concat_dataset: ConcatDataset object containing ForecastingDatasets
        scaler: Fitted StandardScaler from scikit-learn.

    Returns:
        None
    """
    for member in concat_dataset.datasets:
        tensor_dtype = member.torch_dtype

        # Inputs are always scaled.
        scaled_inputs = scaler.transform(member.data.numpy())
        member.data = torch.tensor(scaled_inputs, dtype=tensor_dtype)

        # Targets are scaled only on request.
        if not member.normalize_target:
            continue
        scaled_targets = scaler.transform(member.target_data.numpy())
        member.target_data = torch.tensor(scaled_targets, dtype=tensor_dtype)
def normalize_datasets(
    train_dataset: ConcatDataset,
    val_dataset: ConcatDataset,
    test_dataset: ConcatDataset,
) -> Tuple[ConcatDataset, ConcatDataset, ConcatDataset, StandardScaler]:
    """
    Normalizes financial time series datasets using StandardScaler.

    The scaler is fitted on the training data only (all training members
    stacked row-wise) to prevent look-ahead bias, and is then applied in place
    to the training, validation, and test datasets.

    Args:
        train_dataset: Training dataset (ConcatDataset of ForecastingDatasets)
        val_dataset: Validation dataset
        test_dataset: Test dataset

    Returns:
        Tuple[ConcatDataset, ConcatDataset, ConcatDataset, StandardScaler]:
            The same three dataset objects (now normalized) and the fitted
            StandardScaler.
    """
    # Gather the raw feature matrix of every training member.
    train_arrays = [member.data.numpy() for member in train_dataset.datasets]

    # Fit on the stacked training rows only.
    scaler = StandardScaler()
    scaler.fit(np.vstack(train_arrays))

    # Apply the training statistics to every split, in place.
    for split in (train_dataset, val_dataset, test_dataset):
        normalize_concat_dataset(split, scaler)

    return train_dataset, val_dataset, test_dataset, scaler
def get_forecasting_dataset(
    contract_dataset: ContractDataset,
    tte_tolerance: Tuple[int, int],
    seq_len: Optional[int] = None,
    pred_len: Optional[int] = None,
    core_feats: List[str] = ["option_returns"],
    tte_feats: Optional[List[str]] = None,
    datetime_feats: Optional[List[str]] = None,
    vol_feats: Optional[List[str]] = None,
    rolling_volatility_range: Optional[List[int]] = None,
    keep_datetime: bool = False,
    target_type: str = "multistep",
    clean_up: bool = False,
    offline: bool = False,
    intraday: bool = False,
    target_channels: Optional[List[str]] = None,
    dtype: str = "float32",
    normalize_target: bool = False,
    save_dir: Optional[str] = None,
    download_only: bool = False,
    validate_contracts: bool = False,
    modify_contracts: bool = False,
    verbose: bool = False,
    warning: bool = True,
    dev_mode: bool = False,
) -> Union[ContractDataset, Tuple[ConcatDataset, ContractDataset]]:
    """
    Creates a PyTorch dataset object composed of multiple ForecastingDatasets, each
    representing a different option contract.

    Every contract in ``contract_dataset`` is loaded in turn. When loading fails with
    an incompatible start/end date, a recalibrated contract is requested from
    ``calibrate_new_contract`` and retried; contracts that cannot be loaded are dropped.
    On return, ``contract_dataset.contracts`` has been replaced by the list of contracts
    that loaded successfully.

    Args:
        contract_dataset: ContractDataset object containing option contract parameters
        tte_tolerance: Tuple of (min, max) time to expiration tolerance in days
        seq_len: Sequence length of lookback window (input). Required unless
            `download_only` or `validate_contracts` is True.
        pred_len: Prediction length of forecast window (target). Required unless
            `download_only` or `validate_contracts` is True.
        core_feats: List of core features to include
        tte_feats: List of time-to-expiration features to include
        datetime_feats: List of datetime features to include
        vol_feats: List of volatility features to include
        rolling_volatility_range: List of rolling volatility ranges to include
        keep_datetime: Whether to keep the datetime column in the dataset
        target_type: Type of forecasting target. Options: "multistep" (float),
            "average" (float), or "average_direction" (binary).
        clean_up: Whether to clean up the data after use. Overridden to False when
            `download_only` is True and to True when `validate_contracts` is True.
        offline: Whether to load saved contracts from disk. Overridden to False when
            `download_only` or `validate_contracts` is True.
        intraday: Whether to use intraday data. NOTE(review): accepted but not
            referenced anywhere in this function body.
        target_channels: List of target channels to include in the target tensor.
            If None, all channels will be included.
        dtype: Data type for the PyTorch tensors
        normalize_target: Whether to normalize the target variable(s)
        save_dir: Save directory
        download_only: Whether to download data only (used mainly for Universe class)
        validate_contracts: Whether to validate contracts by requesting data from
            ThetaData API and adjusting start and end dates if necessary.
        modify_contracts: Whether to delete old contracts .pkl file and save the (new)
            validated contracts in the same path. Warning: This will overwrite the old
            contracts.
        verbose: Whether to print verbose output
        warning: Whether to print verbose DataValidationError statements as warnings
            or errors.
        dev_mode: Whether to run in development mode.

    Returns:
        ContractDataset: The updated ContractDataset object if `download_only`=True or
            `validate_contracts`=True.
        Tuple[ConcatDataset, ContractDataset]: A tuple containing the concatenated
            PyTorch dataset and the updated ContractDataset otherwise.

    Raises:
        ValueError: If both `download_only` and `validate_contracts` are True, or if no
            contract produced a dataset.
        AssertionError: If `seq_len` or `pred_len` is None when a dataset is requested.
        DataValidationError: Re-raised when a contract's features contain NaNs.
        Exception: Any other error raised while loading a contract is re-raised.
    """
    ctx = Console()
    dataset_list = []
    validated_contracts = []  # Track which contracts are actually valid
    updated_contracts = []  # Track contracts that were updated (not in initial list)
    # Passed to calibrate_new_contract; flipped to True after the first calibration attempt
    expirations_exist = False
    # (start_date, exp) keys of calibrated contracts already attempted. Shared across
    # all contracts of this call, so a calibrated key is only ever tried once.
    tried_contracts = set()

    # First, validate that download_only and validate_contracts aren't both True
    if download_only and validate_contracts:
        raise ValueError(
            "Please use download_only and validate_contracts separately. Both cannot be True."
        )

    if download_only:
        clean_up = False
        offline = False
    elif validate_contracts:
        clean_up = True
        offline = False
    else:
        # NOTE(review): asserts are stripped under `python -O`; these guards then vanish.
        assert seq_len is not None, "seq_len must be provided for forecasting dataset"
        assert pred_len is not None, "pred_len must be provided for forecasting dataset"

    # Save initial contracts to compare adjusted contracts later
    initial_contracts = deepcopy(contract_dataset.contracts)

    # Iterate through each contract in the ContractDataset
    for (
        contract
    ) in (
        initial_contracts
    ):  # Use initial_contracts instead to avoid modifying during iteration
        # Flag to track if we should try the next contract
        move_to_next_contract = False
        original_contract = deepcopy(contract)  # Keep original for comparison

        # Continue trying the current contract until we succeed or decide to move on
        while not move_to_next_contract:
            try:
                df = contract.load_data(
                    save_dir=save_dir,
                    clean_up=clean_up,
                    offline=offline,
                    warning=warning,
                    dev_mode=dev_mode,
                )

                # Datasets are only built in the default mode; download/validate modes
                # stop after a successful load.
                if not (download_only or validate_contracts):
                    # Select and add features
                    data = transform_features(
                        df=df,
                        core_feats=core_feats,
                        tte_feats=tte_feats,
                        datetime_feats=datetime_feats,
                        vol_feats=vol_feats,
                        rolling_volatility_range=rolling_volatility_range,
                        root=contract.root,
                        right=contract.right,
                        strike=contract.strike,
                        exp=contract.exp,
                        keep_datetime=keep_datetime,
                    )

                    # Convert to PyTorch dataset
                    dataset = ForecastingDataset(
                        data=data,
                        seq_len=seq_len,
                        pred_len=pred_len,
                        target_channels=target_channels,
                        target_type=target_type,
                        dtype=dtype,
                        normalize_target=normalize_target,
                    )
                    dataset_list.append(dataset)

                # If contract was modified from its original state, add to updated contracts
                if contract != original_contract:
                    updated_contracts.append(contract)

                # Add valid contract to validated list
                validated_contracts.append(contract)
                move_to_next_contract = True

            except DataValidationError as e:
                # Handle date-related errors that might be recoverable
                if e.error_code in [INCOMPATIBLE_START_DATE, INCOMPATIBLE_END_DATE]:
                    candidate_start_date = e.real_start_date
                    # Only an end-date error supplies a new expiration candidate;
                    # otherwise the current contract's expiration is kept.
                    candidate_exp = (
                        e.real_end_date
                        if e.error_code == INCOMPATIBLE_END_DATE
                        else contract.exp
                    )

                    # Calibration always starts from the original contract, not from a
                    # previously calibrated one.
                    move_to_next_contract, new_contract = calibrate_new_contract(
                        contract_dataset=contract_dataset,
                        original_contract=original_contract,
                        candidate_start_date=candidate_start_date,
                        candidate_exp=candidate_exp,
                        tte_tolerance=tte_tolerance,
                        expirations_exist=expirations_exist,
                        save_dir=save_dir,
                        verbose=verbose,
                        dev_mode=dev_mode,
                    )
                    expirations_exist = True

                    # If calibration suggests retrying with a new contract...
                    if not move_to_next_contract and new_contract is not None:
                        contract_key = (new_contract.start_date, new_contract.exp)
                        # Guard against retrying the same calibrated dates forever
                        if contract_key in tried_contracts:
                            if verbose:
                                ctx.log(f"Previously tried contract {contract_key} failed. Skipping.")
                            move_to_next_contract = True  # Skip this contract
                            continue
                        tried_contracts.add(contract_key)
                        contract = new_contract  # Try the calibrated contract

                elif e.error_code == NAN_FEATURES:
                    raise e  # Raise the error to be handled outside
                else:
                    if verbose:
                        ctx.log(
                            f"DataValidationError for {contract}: {e}. Moving to next contract."
                        )
                    move_to_next_contract = True

            except Exception as e:
                # NOTE(review): the log text says "Moving to next contract", but the
                # error is re-raised, which aborts the whole call.
                if verbose:
                    ctx.log(
                        f"Unknown error for {contract}: {e}. Moving to next contract."
                    )
                raise e
                # move_to_next_contract = True

    # Find removed contracts (in initial but not in validated)
    removed_contracts = [c for c in initial_contracts if c not in validated_contracts]

    # Find truly new contracts (in updated but not equivalent to any in initial)
    # This is different from just comparing sets because contracts might be updated versions
    new_contracts = [c for c in updated_contracts]

    # Display changes if verbose
    if verbose and (removed_contracts or new_contracts):
        from rich.table import Table

        # Create a table for removed contracts
        if removed_contracts:
            removed_table = Table(title="Invalid Contracts", show_header=True)
            removed_table.add_column("Root", style="cyan")
            removed_table.add_column("Start Date", style="cyan")
            removed_table.add_column("Expiration", style="cyan")
            removed_table.add_column("Strike", style="cyan")
            removed_table.add_column("Right", style="cyan")

            for contract in removed_contracts:
                removed_table.add_row(
                    contract.root,
                    contract.start_date,
                    contract.exp,
                    str(contract.strike),
                    contract.right,
                    style="red",
                )

            ctx.print(removed_table)

        # Create a table for updated/new contracts
        if updated_contracts:
            updated_table = Table(title="Updated (New Contracts)", show_header=True)
            updated_table.add_column("Root", style="cyan")
            updated_table.add_column("Start Date", style="cyan")
            updated_table.add_column("Expiration", style="cyan")
            updated_table.add_column("Strike", style="cyan")
            updated_table.add_column("Right", style="cyan")
            updated_table.add_column("Status", style="cyan")

            for contract in updated_contracts:
                # Check if this is a modified version of an original contract:
                # same root, strike, interval and right as some initial contract.
                is_update = False
                for orig in initial_contracts:
                    if (
                        contract.root == orig.root
                        and contract.strike == orig.strike
                        and contract.interval_min == orig.interval_min
                        and contract.right == orig.right
                    ):
                        is_update = True
                        break

                status = "Updated" if is_update else "New"
                updated_table.add_row(
                    contract.root,
                    contract.start_date,
                    contract.exp,
                    str(contract.strike),
                    contract.right,
                    status,
                    style="green",
                )

            ctx.print(updated_table)

    # Replace the contract list with only validated contracts
    contract_dataset.contracts = validated_contracts

    # Save contract changes if needed
    # (the set comparison requires Contract objects to be hashable)
    contracts_changed = set(validated_contracts) != set(initial_contracts)
    if contracts_changed and modify_contracts:
        if verbose:
            ctx.log(
                f"Contract changes detected: {len(removed_contracts)} removed, {len(updated_contracts)} updated/added"
            )
            ctx.log("Saving and overwriting ContractDataset .pkl file")
        contract_dataset.save(clean_file=True)

    if download_only or validate_contracts:
        return contract_dataset
    else:
        if len(dataset_list) == 0:
            raise ValueError(
                "No valid contracts found. All contracts were invalid or contained errors."
            )
        return ConcatDataset(dataset_list), contract_dataset
def calibrate_new_contract(
    contract_dataset: ContractDataset,
    original_contract: Contract,
    candidate_start_date: str,
    candidate_exp: str,
    tte_tolerance: Tuple[int, int],
    expirations_exist: bool = False,
    save_dir: Optional[str] = None,
    verbose: bool = False,
    dev_mode: bool = False,
) -> Tuple[bool, Optional[Contract]]:
    """
    Builds a recalibrated copy of a contract from candidate start/expiration dates.

    The candidate expiration is snapped to the closest date returned by
    ``get_expirations`` and the candidate start date is passed through
    ``get_valid_start_date``. If the resulting span is at least
    ``tte_tolerance[0]`` days, a deep copy of ``original_contract`` is returned
    with the new dates and a strike recomputed by ``find_optimal_strike``.

    Args:
        contract_dataset: Source of the strike-selection settings (right,
            interval_min, moneyness, strike_band, volatility settings).
        original_contract: Contract to copy; it is not modified.
        candidate_start_date: Proposed start date in 'YYYYMMDD' format.
        candidate_exp: Proposed expiration in 'YYYYMMDD' format.
        tte_tolerance: Tuple of (min, max) time to expiration tolerance in days.
            Only the minimum (index 0) is checked here.
        expirations_exist: Forwarded as the ``offline`` flag of
            ``get_expirations`` so the API is not queried again.
        save_dir: Save directory forwarded to ``get_expirations``.
        verbose: Whether to log progress messages.
        dev_mode: Whether to run in development mode.

    Returns:
        Tuple[bool, Optional[Contract]]: ``(False, new_contract)`` when the
        caller should retry with the calibrated contract, or ``(True, None)``
        when the span is too short and the caller should move on.
    """
    # Rich console
    console = Console()
    date_format = "%Y%m%d"

    # <------Calibrate New Expiration------>
    if verbose:
        console.log(
            f"Calibrating new expiration date '{candidate_exp}' for original contract..."
        )

    # Do not query the API if the expirations were already fetched.
    expirations = get_expirations(
        root=original_contract.root,
        save_dir=save_dir,
        clean_up=False,
        offline=bool(expirations_exist),
        dev_mode=dev_mode,
    )

    # Work in datetime space to find the listed expiration nearest the candidate.
    expirations["date"] = pd.to_datetime(
        expirations["date"].astype(str), format=date_format
    )
    target_exp = pd.to_datetime(candidate_exp, format=date_format)
    nearest_label = (expirations["date"] - target_exp).abs().idxmin()
    new_exp = expirations.loc[nearest_label, "date"].strftime(date_format)

    if verbose:
        console.log(f"Closest expiration date found: {new_exp}")

    # <------Calibrate New Start Date------>
    if verbose:
        console.log(
            f"Calibrating new start date '{candidate_start_date}' for original contract..."
        )

    # Weekends and market holidays are rolled forward to the next valid date.
    new_start_date = get_valid_start_date(candidate_start_date)

    if verbose:
        console.log(f"New start date found: {new_start_date}")

    # <------TTE Tolerance Check------>
    time_span = pd.to_datetime(new_exp, format=date_format) - pd.to_datetime(
        new_start_date, format=date_format
    )
    meets_minimum = time_span >= pd.Timedelta(days=tte_tolerance[0])

    if not meets_minimum:
        if verbose:
            console.log(
                f"Contract timespan ({time_span.days} days) is too short. Moving to next contract."
            )
        return True, None

    if verbose:
        console.log(
            f"TTE Tolerance = {time_span} check passed: updating contract with new start date: {new_start_date} and expiration: {new_exp}"
        )

    # Copy the original contract and overwrite the calibrated fields.
    calibrated = deepcopy(original_contract)
    calibrated.start_date = new_start_date
    calibrated.exp = new_exp

    # The strike depends on the dates, so it is recomputed for the new ones.
    calibrated.strike = find_optimal_strike(
        root=calibrated.root,
        start_date=new_start_date,
        exp=new_exp,
        right=contract_dataset.right,
        interval_min=contract_dataset.interval_min,
        moneyness=contract_dataset.moneyness,
        strike_band=contract_dataset.strike_band,
        volatility_scaled=contract_dataset.volatility_scaled,
        volatility_scalar=contract_dataset.volatility_scalar,
        hist_vol=contract_dataset.hist_vol,
        clean_up=True,
    )

    return False, calibrated
def get_valid_start_date(candidate_start_date: str) -> str:
    """
    Return the first valid NYSE trading day on or after a candidate date.

    If the candidate itself is a trading day it is returned unchanged;
    otherwise (weekend or NYSE holiday) the next trading day is returned.

    Args:
        candidate_start_date (str): The date to validate, in 'YYYYMMDD' format.

    Returns:
        str: The resulting NYSE trading day in 'YYYYMMDD' format.

    Raises:
        ValueError: If no valid trading day is found within the search buffer
            (the 10 calendar days following the candidate).
    """
    nyse = mcal.get_calendar("NYSE")
    candidate = datetime.strptime(candidate_start_date, "%Y%m%d")

    # Search window: the candidate date plus the following 10 calendar days.
    window_end = (candidate + timedelta(days=10)).strftime("%Y-%m-%d")
    sessions = nyse.valid_days(start_date=candidate_start_date, end_date=window_end)

    # Sessions are timezone-aware, so compare on the plain calendar date.
    first_session = next(
        (session for session in sessions if session.date() >= candidate.date()),
        None,
    )
    if first_session is None:
        raise ValueError(f"No valid trading day found after {candidate_start_date}")

    return first_session.strftime("%Y%m%d")
def get_forecasting_loaders(
    train_contract_dataset: ContractDataset,
    val_contract_dataset: ContractDataset,
    test_contract_dataset: ContractDataset,
    seq_len: int,
    pred_len: int,
    tte_tolerance: Tuple[int, int],
    core_feats: List[str] = ["option_returns"],
    tte_feats: Optional[List[str]] = None,
    datetime_feats: Optional[List[str]] = None,
    vol_feats: Optional[List[str]] = None,
    rolling_volatility_range: Optional[List[int]] = None,
    keep_datetime: bool = False,
    target_channels: Optional[List[str]] = None,
    target_type: str = "multistep",
    batch_size: int = 32,
    shuffle: bool = True,
    drop_last: bool = False,
    num_workers: int = 4,
    prefetch_factor: Optional[int] = None,
    pin_memory: bool = torch.cuda.is_available(),
    persistent_workers: bool = True,
    clean_up: bool = False,
    offline: bool = False,
    save_dir: Optional[str] = None,
    verbose: bool = False,
    scaling: bool = False,
    intraday: bool = False,
    dtype: str = "float32",
    normalize_target: bool = False,
    modify_contracts: bool = False,
    warning: bool = True,
    dev_mode: bool = False,
) -> Union[
    Tuple[DataLoader, DataLoader, DataLoader, None],
    Tuple[DataLoader, DataLoader, DataLoader, StandardScaler],
]:
    """
    Forms training, validation, and test dataloaders for option contract data.

    Each ContractDataset is turned into a ConcatDataset with
    ``get_forecasting_dataset`` using identical settings, optionally normalized
    with statistics fitted on the training split, and wrapped in a DataLoader.

    Args:
        train_contract_dataset: Contract dataset for training
        val_contract_dataset: Contract dataset for validation
        test_contract_dataset: Contract dataset for testing
        seq_len: Sequence length for input data
        pred_len: Prediction length for forecasting
        tte_tolerance: Tuple of (min, max) time to expiration tolerance in days
        core_feats: List of core features to include
        tte_feats: List of time-to-expiration features to include
        datetime_feats: List of datetime features to include
        vol_feats: List of volatility features to include
        rolling_volatility_range: List of rolling volatility ranges to include
        keep_datetime: Whether to keep the datetime column in the dataset
        target_channels: List of target channels for forecasting
        target_type: Type of forecasting target. Options: "multistep" (float),
            "average" (float), or "average_direction" (binary).
        batch_size: Number of samples per batch
        shuffle: Whether to shuffle the training and validation data. The test
            loader is never shuffled.
        drop_last: Whether to drop the last incomplete batch
        num_workers: Number of subprocesses to use for data loading
        prefetch_factor: Number of batches to prefetch per worker. Ignored when
            ``num_workers`` is 0.
        pin_memory: Whether to pin memory for faster GPU transfer
        persistent_workers: Whether to keep worker processes alive between
            epochs. Ignored when ``num_workers`` is 0.
        clean_up: Whether to clean up the data after use
        offline: Whether to load saved contracts from disk
        save_dir: Directory to save/load processed datasets
        verbose: Whether to print verbose output
        scaling: Whether to normalize the datasets
        intraday: Whether to use intraday data
        dtype: Data type for tensors
        normalize_target: Whether to normalize the target variable(s)
        modify_contracts: Whether to modify contracts if they are invalid in
            get_forecasting_dataset function calls.
        warning: Whether to show warnings
        dev_mode: Whether to run in development mode

    Returns:
        Tuple[DataLoader, DataLoader, DataLoader, Optional[StandardScaler]]:
            Train, validation, and test data loaders, followed by the fitted
            scaler when ``scaling=True`` or ``None`` when ``scaling=False``.
    """
    # Settings shared by all three splits; only the contract dataset differs.
    dataset_kwargs = dict(
        tte_tolerance=tte_tolerance,
        seq_len=seq_len,
        pred_len=pred_len,
        core_feats=core_feats,
        tte_feats=tte_feats,
        datetime_feats=datetime_feats,
        vol_feats=vol_feats,
        rolling_volatility_range=rolling_volatility_range,
        keep_datetime=keep_datetime,
        target_type=target_type,
        clean_up=clean_up,
        offline=offline,
        intraday=intraday,
        target_channels=target_channels,
        dtype=dtype,
        normalize_target=normalize_target,
        modify_contracts=modify_contracts,
        save_dir=save_dir,
        verbose=verbose,
        warning=warning,
        dev_mode=dev_mode,
    )

    # Build the splits in train -> val -> test order; the second element of
    # each result (the updated ContractDataset) is not needed here.
    train_dataset, val_dataset, test_dataset = [
        get_forecasting_dataset(contract_dataset=contracts, **dataset_kwargs)[0]
        for contracts in (
            train_contract_dataset,
            val_contract_dataset,
            test_contract_dataset,
        )
    ]

    if scaling:
        train_dataset, val_dataset, test_dataset, scaler = normalize_datasets(
            train_dataset=train_dataset,
            val_dataset=val_dataset,
            test_dataset=test_dataset,
        )
    else:
        scaler = None

    # DataLoader raises ValueError if persistent_workers or prefetch_factor is
    # set while num_workers == 0, so worker-only options are dropped then.
    use_workers = num_workers > 0
    loader_kwargs = dict(
        batch_size=batch_size,
        drop_last=drop_last,
        num_workers=num_workers,
        persistent_workers=persistent_workers and use_workers,
        prefetch_factor=prefetch_factor if use_workers else None,
        pin_memory=pin_memory,
    )

    # Create dataloaders for training, validation, and testing
    train_loader = DataLoader(dataset=train_dataset, shuffle=shuffle, **loader_kwargs)
    val_loader = DataLoader(dataset=val_dataset, shuffle=shuffle, **loader_kwargs)
    test_loader = DataLoader(dataset=test_dataset, shuffle=False, **loader_kwargs)

    return train_loader, val_loader, test_loader, scaler
def create_windows(
    df: pd.DataFrame,
    seq_len: int,
    pred_len: int,
    window_stride: int,
    intraday: bool = False,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Generates rolling (input, target) windows from a DataFrame.

    Should be used primarily for scikit-learn models and/or intraday modeling,
    otherwise default to optrade.data.forecasting.get_forecasting_loaders or
    optrade.data.forecasting.get_forecasting_datasets.

    Args:
        df (pd.DataFrame): DataFrame containing the data. Must contain a
            "datetime" column (datetime-like values) and an "option_returns"
            column. The DataFrame is not modified.
        seq_len (int): Length of the input sequence.
        pred_len (int): Length of the prediction sequence.
        window_stride (int): Number of steps to move the window forward.
            Must be at least 1.
        intraday (bool): If True, the data is split by calendar date of the
            "datetime" column and windows never cross over between days.
            Otherwise windows may span multiple days.

    Returns:
        input (np.ndarray): Array of input windows of shape
            (num_windows, seq_len, num_features), where num_features is the
            number of columns in the DataFrame excluding "datetime".
        target (np.ndarray): Array of target windows of shape
            (num_windows, pred_len, 1), taken from the "option_returns" column.

        When no window fits, both arrays are empty but keep these trailing
        dimensions, i.e. (0, seq_len, num_features) and (0, pred_len, 1).

    Raises:
        ValueError: If window_stride is smaller than 1, or if intraday=True and
            a day holds fewer than seq_len + pred_len rows (after its 9:30am
            row has been dropped).
    """
    if window_stride < 1:
        raise ValueError(
            f"window_stride must be a positive integer, got {window_stride}."
        )

    # Input features are all columns except datetime
    feature_columns = [col for col in df.columns if col != "datetime"]
    min_rows = seq_len + pred_len

    if intraday:
        segments = []
        # Group once by calendar date (in order of first appearance) instead of
        # rebuilding a boolean mask over the whole frame for every day. The keys
        # are passed as a plain array so they are matched to rows by position.
        day_keys = df["datetime"].dt.date.to_numpy()
        for day, day_data in df.groupby(day_keys, sort=False):
            day_data = _drop_market_open(day_data)
            if len(day_data) < min_rows:
                raise ValueError(
                    f"seq_len + pred_len = {min_rows} exceeds the length of the "
                    f"day ({day}: {len(day_data)} rows). Either set "
                    "intraday=False or reduce seq_len and/or pred_len."
                )
            segments.append(day_data)
    else:
        # Only the very first row of the frame is checked here; windows are
        # allowed to run across day boundaries.
        segments = [_drop_market_open(df)]

    inputs, targets = [], []
    for segment in segments:
        features = segment[feature_columns].to_numpy()
        returns = segment["option_returns"].to_numpy().reshape(-1, 1)
        for window_x, window_y in _window_pairs(
            features, returns, seq_len, pred_len, window_stride
        ):
            inputs.append(window_x)
            targets.append(window_y)

    if not inputs:
        # np.array([]) would collapse to shape (0,); keep the documented rank.
        return (
            np.empty((0, seq_len, len(feature_columns))),
            np.empty((0, pred_len, 1)),
        )

    return np.array(inputs), np.array(targets)


def _drop_market_open(frame: pd.DataFrame) -> pd.DataFrame:
    """Returns `frame` without its first row when that row is stamped 9:30am."""
    if frame.empty:
        return frame
    # Returns are part of the input features but do not exist for the market
    # open (9:30am), so that row is removed when it leads the frame.
    first_time = frame["datetime"].iloc[0].time()
    if first_time.hour == 9 and first_time.minute == 30:
        return frame.iloc[1:]
    return frame


def _window_pairs(features, targets, seq_len, pred_len, window_stride):
    """Yields (input, target) slices: seq_len rows followed by pred_len rows."""
    for start in range(0, len(features) - seq_len - pred_len + 1, window_stride):
        split = start + seq_len
        yield features[start:split], targets[split : split + pred_len]
if __name__ == "__main__":
    # Manual smoke test for get_forecasting_loaders: build the contract
    # datasets, wrap them in loaders, then print the split sizes and the
    # contents of the first training batch.
    from optrade.data.contracts import get_contract_datasets

    # Feature selection ------------------------------------------------------
    # TTE features
    tte_feats = ["sqrt", "exp_decay"]

    # Datetime features
    datetime_feats = [
        "sin_minute_of_day",
        "cos_minute_of_day",
        "sin_hour_of_week",
        "cos_hour_of_week",
    ]

    # Core features
    core_feats = [
        "log_option_returns",
        "log_stock_returns",
        "option_returns",
        "stock_returns",
        "distance_to_strike",
        "moneyness",
        # "option_lob_imbalance",
        # "option_quote_spread",
        "stock_lob_imbalance",
        "stock_quote_spread",
        "option_mid_price",
        "option_bid_size",
        "option_bid",
        "option_ask_size",
        "option_close",
        "option_volume",
        "option_count",
        "stock_mid_price",
        "stock_bid_size",
        "stock_bid",
        "stock_ask_size",
        "stock_ask",
        "stock_volume",
        "stock_count",
    ]

    # Settings shared by both calls below
    tte_tolerance = (15, 45)
    normalize_target = True

    # Contract selection -----------------------------------------------------
    contract_settings = dict(
        root="AMZN",
        start_date="20230101",
        end_date="20230901",
        contract_stride=3,
        interval_min=60,
        right="C",
        target_tte=30,
        tte_tolerance=tte_tolerance,
        moneyness="ATM",
        strike_band=0.05,
        volatility_type="period",
        volatility_scaled=True,
        volatility_scalar=0.01,
        verbose=True,
        train_split=0.5,
        val_split=0.25,
        dev_mode=True,
        offline=True,
    )
    train_cd, val_cd, test_cd = get_contract_datasets(**contract_settings)

    # Loader construction ----------------------------------------------------
    loader_settings = dict(
        train_contract_dataset=train_cd,
        val_contract_dataset=val_cd,
        test_contract_dataset=test_cd,
        tte_tolerance=tte_tolerance,
        keep_datetime=True,
        target_type="average",
        seq_len=100,
        pred_len=10,
        normalize_target=normalize_target,
        core_feats=core_feats,
        tte_feats=tte_feats,
        datetime_feats=datetime_feats,
        batch_size=32,
        modify_contracts=True,
        clean_up=False,
        offline=True,
        save_dir=None,
        verbose=True,
        scaling=True,
        dev_mode=True,
    )
    output = get_forecasting_loaders(**loader_settings)

    # Only the first three entries (the loaders) are used here
    train_loader, val_loader, test_loader = output[:3]

    print(f"Num train examples: {len(train_loader.dataset)}")
    print(f"Num val examples: {len(val_loader.dataset)}")
    print(f"Num test examples: {len(test_loader.dataset)}")

    # Inspect a single batch only, then stop iterating
    for x, y, x_dt, y_dt in train_loader:
        print(f"x ({x.dtype}) shape: {x.shape}. y ({y.dtype}) shape: {y.shape}")
        print(f"x_dt (before): {x_dt.shape}. y_dt (before): {y_dt.shape}")

        # Convert the timestamp tensors back to datetimes for display
        x_dt = tensor_to_datetime(timestamp_tensor=x_dt, batch_mode=True)
        y_dt = tensor_to_datetime(timestamp_tensor=y_dt, batch_mode=True)
        print(f"x_dt: {x_dt}. y_dt: {y_dt}")
        break

    # Disabled manual check for create_windows (kept for reference):
    #
    # from optrade.data.features import transform_features
    # from optrade.data.contracts import Contract
    # from rich.console import Console
    #
    # console = Console()
    #
    # contract = Contract.find_optimal(
    #     root="AAPL",
    #     start_date="20241107",
    #     volatility_scaled=False,
    #     strike_band=0.05,
    #     moneyness="OTM",
    #     interval_min=1,
    #     right="C",
    #     target_tte=30,
    #     tte_tolerance=(25, 35),
    # )
    #
    # df = contract.load_data(clean_up=True, offline=False, warning=True)
    #
    # # TTE features
    # tte_feats = ["sqrt", "exp_decay"]
    #
    # # Datetime features
    # datetime_feats = [
    #     "sin_minute_of_day",
    #     "cos_minute_of_day",
    #     "sin_hour_of_week",
    #     "cos_hour_of_week",
    # ]
    #
    # # Select features
    # core_feats = [
    #     "option_returns",
    #     "stock_returns",
    #     "distance_to_strike",
    #     "moneyness",
    #     "option_lob_imbalance",
    #     "option_quote_spread",
    #     "stock_lob_imbalance",
    #     "stock_quote_spread",
    #     "option_mid_price",
    #     "option_bid_size",
    #     "option_bid",
    #     "option_ask_size",
    #     "option_close",
    #     "option_volume",
    #     "option_count",
    #     "stock_mid_price",
    #     "stock_bid_size",
    #     "stock_bid",
    #     "stock_ask_size",
    #     "stock_ask",
    #     "stock_volume",
    #     "stock_count",
    # ]
    #
    # df = transform_features(
    #     df=df,
    #     core_feats=core_feats,
    #     tte_feats=tte_feats,
    #     datetime_feats=datetime_feats,
    #     strike=contract.strike,
    #     exp=contract.exp,
    #     keep_datetime=True,
    # )
    # print(df.columns)
    #
    # x, y = create_windows(
    #     df=df, seq_len=30, pred_len=6, window_stride=1, intraday=False
    # )
    # print(x.shape, y.shape)