import pickle
from pathlib import Path
import pandas as pd
from typing import Tuple, Iterator, Dict, Any, Optional
from datetime import datetime, timedelta
from rich.console import Console
from pydantic import BaseModel, Field
import pandas_market_calendars as mcal
# Custom modules
from optrade.data.thetadata import load_all_data
from optrade.data.thetadata import find_optimal_exp
from optrade.data.thetadata import find_optimal_strike
from optrade.utils.directories import set_contract_dir
from optrade.utils.error_handlers import DataValidationError, MARKET_HOLIDAY, WEEKEND
from optrade.utils.volatility import get_train_historical_vol
# Absolute, symlink-resolved directory containing this module; passed to set_contract_dir
# as the base location when computing where contract datasets are cached.
SCRIPT_DIR: Path = Path(__file__).resolve().parent
[docs]
class Contract:
    """
    A class representing an options contract with methods for optimal contract selection.

    The Contract class defines the structure of an options contract including the underlying security,
    dates, strike price, and other key parameters. Two contracts compare equal (and hash equally) when
    all six defining fields (root, start_date, exp, strike, interval_min, right) match.
    """

    def __init__(
        self,
        root: str,
        start_date: str,
        exp: str,
        strike: float,
        interval_min: int,
        right: str,
    ) -> None:
        """Initialize a Contract instance.

        Args:
            root: Root symbol of the underlying security (e.g., "AAPL" representing Apple Inc.)
            start_date: Start date in YYYYMMDD format (e.g., "20241107" representing November 7, 2024)
            exp: Expiration date in YYYYMMDD format (e.g., "20241206" representing December 6, 2024)
            strike: Strike price (e.g., 225 representing $225)
            interval_min: Interval in minutes (e.g., 1 representing 1 minute)
            right: Option type ('C' for call, 'P' for put)

        Returns:
            None
        """
        self.root = root
        self.start_date = start_date
        self.exp = exp
        self.strike = strike
        self.interval_min = interval_min
        self.right = right

    def _key(self) -> Tuple[str, str, str, float, int, str]:
        """Return the tuple of fields defining contract identity (shared by __eq__ and __hash__)."""
        return (
            self.root,
            self.start_date,
            self.exp,
            self.strike,
            self.interval_min,
            self.right,
        )

    def __eq__(self, other: object) -> bool:
        """Check if two contracts are equal.

        Returns NotImplemented for non-Contract operands so Python can try the reflected
        comparison; when nothing else handles it, ``contract == other`` still evaluates to False.
        """
        if not isinstance(other, Contract):
            return NotImplemented
        return self._key() == other._key()

    def __hash__(self) -> int:
        """Hash the contract for use in dictionaries and sets (consistent with __eq__)."""
        return hash(self._key())

    def __repr__(self) -> str:
        """Returns a string representation of the contract."""
        return f"Contract(root='{self.root}', start_date='{self.start_date}', exp='{self.exp}', strike={self.strike}, interval_min={self.interval_min}, right='{self.right}')"

    @classmethod
    def find_optimal(
        cls,
        root: str,
        start_date: str,
        interval_min: int,
        right: str,
        target_tte: int,
        tte_tolerance: Tuple[int, int],
        moneyness: str,
        strike_band: Optional[float] = 0.05,
        hist_vol: Optional[float] = None,
        volatility_scaled: bool = False,
        volatility_scalar: Optional[float] = 1.0,
        verbose: bool = True,
        warning: bool = False,
        dev_mode: bool = False,
    ) -> "Contract":
        """Find the optimal contract for a given security, start date, and approximate TTE.

        Args:
            root: Underlying stock symbol
            start_date: Start date for the contract in YYYYMMDD format
            interval_min: Interval in minutes
            right: Option type (C for call, P for put)
            target_tte: Target time to expiration in days
            tte_tolerance: Acceptable range for TTE as (min_days, max_days)
            moneyness: Contract moneyness (OTM, ATM, ITM)
            strike_band: Target percentage band for strike selection
            hist_vol: Historical volatility for dynamic strike selection
            volatility_scaled: Whether to select strike by volatility
            volatility_scalar: Scaling factor for volatility-based strike selection
            verbose: Whether to print verbose output
            warning: Forwarded to DataValidationError when the start date is rejected
            dev_mode: Forwarded to the expiration/strike selection helpers

        Returns:
            Contract: A new contract with the selected expiration and strike.

        Raises:
            DataValidationError: With error_code WEEKEND if start_date is a Saturday/Sunday, or
                MARKET_HOLIDAY if it is not an NYSE trading day.
        """
        # Validate the start date *before* any selection work: a weekend or holiday can never
        # yield a contract, so failing fast avoids a wasted find_optimal_exp call (presumably a
        # data lookup) and guarantees callers see the WEEKEND/MARKET_HOLIDAY codes they skip on.
        date_obj = datetime.strptime(start_date, "%Y%m%d")
        if date_obj.weekday() >= 5:  # 5 = Saturday, 6 = Sunday
            raise DataValidationError(
                message=f"Start date {start_date} falls on a weekend. Markets are closed.",
                error_code=WEEKEND,
                verbose=verbose,
                warning=warning,
            )

        nyse = mcal.get_calendar("NYSE")
        trading_days = nyse.valid_days(start_date=start_date, end_date=start_date)
        if len(trading_days) == 0:
            raise DataValidationError(
                message=f"Start date {start_date} is a market holiday. Markets are closed.",
                error_code=MARKET_HOLIDAY,
                verbose=verbose,
                warning=warning,
            )

        exp, _ = find_optimal_exp(
            root=root,
            start_date=start_date,
            target_tte=target_tte,
            tte_tolerance=tte_tolerance,
            clean_up=True,
            dev_mode=dev_mode,
        )

        strike = find_optimal_strike(
            root=root,
            start_date=start_date,
            exp=exp,
            right=right,
            interval_min=interval_min,
            moneyness=moneyness,
            strike_band=strike_band,
            hist_vol=hist_vol,
            volatility_scaled=volatility_scaled,
            volatility_scalar=volatility_scalar,
            clean_up=True,
            dev_mode=dev_mode,
        )

        if verbose:
            Console().log(
                f"Identified optimal contract with strike price of ${strike} expiring on {exp}"
            )

        return cls(
            root=root,
            start_date=start_date,
            exp=exp,
            strike=strike,
            interval_min=interval_min,
            right=right,
        )

    def load_data(
        self,
        clean_up: bool = False,
        offline: bool = False,
        save_dir: Optional[str] = None,
        warning: bool = False,
        dev_mode: bool = False,
    ) -> pd.DataFrame:
        """Load data for the selected contract.

        Args:
            clean_up: Whether to clean up the data after use
            offline: Whether to load saved data from disk
            save_dir: Directory to save/load data
            warning: Whether to display warnings
            dev_mode: Whether to use development mode

        Returns:
            pd.DataFrame: The loaded data containing NBBO quotes and OHLCVC data for the
            contract and the underlying
        """
        return load_all_data(
            root=self.root,
            start_date=self.start_date,
            exp=self.exp,
            interval_min=self.interval_min,
            right=self.right,
            strike=self.strike,
            clean_up=clean_up,
            offline=offline,
            save_dir=save_dir,
            warning=warning,
            dev_mode=dev_mode,
        )
[docs]
class ContractDataset:
    """
    A dataset containing options contracts generated with consistent parameters.
    """

    def __init__(
        self,
        root: str,
        total_start_date: str,
        total_end_date: str,
        contract_stride: int,
        interval_min: int,
        right: str,
        target_tte: int,
        tte_tolerance: Tuple[int, int],
        moneyness: str,
        strike_band: float = 0.05,
        volatility_scaled: bool = False,
        volatility_scalar: float = 1.0,
        hist_vol: Optional[float] = None,
        verbose: bool = False,
        save_dir: Optional[str] = None,
        warning: bool = True,
        dev_mode: bool = False,
        contract_dir: Optional[Path] = None,
    ) -> None:
        """
        Initialize the ContractDataset with the specified parameters.

        Args:
            root: The security root symbol
            total_start_date: Start date for the dataset (YYYYMMDD)
            total_end_date: End date for the dataset (YYYYMMDD)
            contract_stride: Days between consecutive contracts (must be >= 1 to generate)
            interval_min: Data interval in minutes
            right: Option type (C/P)
            target_tte: Target time to expiration in days
            tte_tolerance: Acceptable range for TTE as (min_days, max_days)
            moneyness: Contract moneyness (OTM/ATM/ITM)
            strike_band: Target percentage band for strike selection
            volatility_scaled: Whether to scale by volatility
            volatility_scalar: Scaling factor for volatility
            hist_vol: Historical volatility for dynamic strike selection
            verbose: Whether to print verbose output
            save_dir: Base directory forwarded to set_contract_dir (only used when
                contract_dir is None)
            warning: Forwarded to Contract.find_optimal during generation
            dev_mode: Forwarded to Contract.find_optimal and set_contract_dir
            contract_dir: Explicit directory for saved datasets; when None it is derived
                from the parameters via set_contract_dir
        """
        self.root = root
        self.total_start_date = total_start_date
        self.total_end_date = total_end_date
        self.contract_stride = contract_stride
        self.interval_min = interval_min
        self.right = right
        self.target_tte = target_tte
        self.tte_tolerance = tte_tolerance
        self.moneyness = moneyness
        self.strike_band = strike_band
        self.volatility_scaled = volatility_scaled
        self.volatility_scalar = volatility_scalar
        self.hist_vol = hist_vol
        self.verbose = verbose
        self.warning = warning
        self.dev_mode = dev_mode
        self.contracts = []

        if contract_dir is None:
            self.contract_dir = set_contract_dir(
                SCRIPT_DIR=SCRIPT_DIR,
                root=root,
                start_date=total_start_date,
                end_date=total_end_date,
                contract_stride=contract_stride,
                interval_min=interval_min,
                right=right,
                target_tte=target_tte,
                tte_tolerance=tte_tolerance,
                moneyness=moneyness,
                strike_band=strike_band,
                volatility_scaled=volatility_scaled,
                volatility_scalar=volatility_scalar,
                hist_vol=hist_vol,
                save_dir=save_dir,
                dev_mode=dev_mode,
            )
        else:
            self.contract_dir = Path(contract_dir)

    def generate(self) -> "ContractDataset":
        """
        Generate all contracts in the dataset based on configuration parameters. Contracts are
        generated by starting from total_start_date and advancing by contract_stride days until
        reaching the last valid date that allows for contracts within the specified time-to-expiration
        tolerance. Any previously generated contracts are discarded first.

        Returns:
            ContractDataset: The dataset with all generated contracts

        Raises:
            ValueError: If contract_stride is not a positive number of days (the scan would
                otherwise never advance).
        """
        if self.contract_stride <= 0:
            raise ValueError(
                f"contract_stride must be a positive number of days, got {self.contract_stride}"
            )

        ctx = Console()

        def log(message: str) -> None:
            if self.verbose:
                ctx.log(message)

        start_date = datetime.strptime(self.total_start_date, "%Y%m%d")
        end_date = datetime.strptime(self.total_end_date, "%Y%m%d")
        # Latest start date for which even the longest tolerated TTE can finish by end_date.
        latest_start = end_date - timedelta(days=max(self.tte_tolerance))

        # Regenerate from scratch so calling generate() twice does not duplicate contracts.
        self.contracts = []

        current_date = start_date
        while current_date <= latest_start:
            contract, found_date = self._find_contract_on_or_after(
                current_date, latest_start, log
            )
            if contract is None:
                # Every date from current_date through latest_start was already tried and
                # rejected; restarting from the next day would only repeat those failures.
                log(
                    f"Unable to find valid contract starting from {current_date.strftime('%Y%m%d')}"
                )
                break

            if found_date > current_date:
                log(
                    f"Found valid contract at shifted date: {found_date.strftime('%Y%m%d')}"
                )
            self.contracts.append(contract)
            log(f"Added contract: {contract}")

            current_date = found_date + timedelta(days=self.contract_stride)
            log(f"Next start date: {current_date.strftime('%Y%m%d')}")

        return self

    def _find_contract_on_or_after(
        self, first_date: datetime, latest_start: datetime, log
    ) -> Tuple[Optional["Contract"], Optional[datetime]]:
        """Return (contract, date) for the first date in [first_date, latest_start] that yields a
        contract, or (None, None) if none does. ``log`` is a one-argument logging callable."""
        attempt_date = first_date
        while attempt_date <= latest_start:
            attempt_date_str = attempt_date.strftime("%Y%m%d")
            try:
                contract = Contract.find_optimal(
                    root=self.root,
                    start_date=attempt_date_str,
                    interval_min=self.interval_min,
                    right=self.right,
                    target_tte=self.target_tte,
                    tte_tolerance=self.tte_tolerance,
                    moneyness=self.moneyness,
                    strike_band=self.strike_band,
                    hist_vol=self.hist_vol,
                    volatility_scaled=self.volatility_scaled,
                    volatility_scalar=self.volatility_scalar,
                    verbose=self.verbose,
                    warning=self.warning,
                    dev_mode=self.dev_mode,
                )
            except DataValidationError as e:
                if e.error_code == WEEKEND:
                    log(f"Skipping weekend: {attempt_date_str}")
                elif e.error_code == MARKET_HOLIDAY:
                    log(f"Skipping market holiday: {attempt_date_str}")
                else:
                    log(f"Unknown error: {str(e)}. Skipping date: {attempt_date_str}.")
                # Advance on *every* validation failure; retrying the same date after an
                # unrecognized error code would loop forever.
                attempt_date += timedelta(days=1)
                continue
            return contract, attempt_date
        return None, None

    def __len__(self) -> int:
        """Get the number of contracts in the dataset.

        Returns:
            int: Number of contracts
        """
        return len(self.contracts)

    def __getitem__(self, idx) -> "Contract":
        """Get a contract by index (a slice returns a list of contracts).

        Returns:
            Contract: The contract at the specified index
        """
        return self.contracts[idx]

    def __iter__(self) -> Iterator["Contract"]:
        """Iterate through contracts.

        Returns:
            Iterator: Iterator for contracts
        """
        return iter(self.contracts)

    def save(self, filename: Optional[str] = None, clean_file: bool = False) -> Path:
        """Save the dataset to a pickle file.

        Args:
            filename: Optional file name inside contract_dir. When given it always takes effect
                (and becomes the remembered path). When None, the previously used path is reused,
                or "contracts.pkl" on the first save.
            clean_file: Whether to delete the existing file before writing

        Returns:
            Path: Path where the pickle file was saved
        """
        self.contract_dir.mkdir(parents=True, exist_ok=True)

        if filename is not None:
            self.filepath = self.contract_dir / filename
        elif not hasattr(self, "filepath"):
            self.filepath = self.contract_dir / "contracts.pkl"

        if clean_file and self.filepath.exists():
            self.filepath.unlink()

        # Just pickle the whole dataset with contracts as-is
        with open(self.filepath, "wb") as f:
            pickle.dump(self, f)

        if self.verbose:
            Console().log(f'Contract dataset saved to "{self.filepath}"')

        return self.filepath

    @classmethod
    def load(cls, filepath: Path) -> "ContractDataset":
        """Load a dataset from a pickle file.

        Security: pickle.load can execute arbitrary code embedded in the file. Only load files
        written by this project from a trusted location.
        """
        with open(filepath, "rb") as f:
            instance = pickle.load(f)

        if instance.verbose:
            Console().log(f"Contract dataset loaded from {filepath}")

        return instance
[docs]
def _split_date_ranges(
    start_date: str, end_date: str, train_split: float, val_split: float
) -> Tuple[Tuple[str, str], Tuple[str, str], Tuple[str, str]]:
    """Split [start_date, end_date] into contiguous, non-overlapping (start, end) YYYYMMDD ranges
    for train, validation, and test."""
    if not (train_split > 0 and val_split > 0 and train_split + val_split < 1):
        raise ValueError(
            f"train_split and val_split must be positive and sum to less than 1 "
            f"(got train_split={train_split}, val_split={val_split})"
        )

    fmt = "%Y%m%d"
    start = datetime.strptime(start_date, fmt)
    end = datetime.strptime(end_date, fmt)
    total_days = (end - start).days
    if total_days <= 0:
        raise ValueError(f"end_date {end_date} must be after start_date {start_date}")

    train_end = start + timedelta(days=int(train_split * total_days))
    val_end = train_end + timedelta(days=int(val_split * total_days))
    # Each split starts the day after the previous one ends, so no calendar day is shared
    # between splits (prevents information leakage across train/val/test).
    val_start = train_end + timedelta(days=1)
    test_start = val_end + timedelta(days=1)

    return (
        (start_date, train_end.strftime(fmt)),
        (val_start.strftime(fmt), val_end.strftime(fmt)),
        (test_start.strftime(fmt), end_date),
    )


def get_contract_datasets(
    root: str,
    start_date: str,
    end_date: str,
    contract_stride: int,
    interval_min: int,
    right: str,
    target_tte: int,
    tte_tolerance: Tuple[int, int],
    moneyness: str,
    strike_band: Optional[float] = 0.05,
    volatility_type: Optional[str] = "period",
    volatility_scaled: Optional[bool] = False,
    volatility_scalar: Optional[float] = 1.0,
    train_split: float = 0.7,
    val_split: float = 0.1,
    clean_up: bool = False,
    offline: bool = False,
    save_dir: Optional[str] = None,
    verbose: bool = False,
    dev_mode: bool = False,
) -> Tuple[ContractDataset, ContractDataset, ContractDataset]:
    """
    Returns the training, validation, and test contract datasets. These contain mutually exclusive contracts
    at mutually exclusive time periods to prevent information leakage during training and evaluation.

    Args:
        root: Underlying stock symbol
        start_date: Start date for the total dataset in YYYYMMDD format
        end_date: End date for the total dataset in YYYYMMDD format
        contract_stride: Number of days between each contract
        interval_min: Interval in minutes for the underlying stock data
        right: Option type (C for call, P for put)
        target_tte: Target time to expiration in days
        tte_tolerance: Tuple of (min, max) time to expiration tolerance in days
        moneyness: Moneyness of the option contract (OTM, ATM, ITM)
        strike_band: Target band for moneyness selection, proportion of current underlying price
        volatility_type: Type of historical volatility to use
        volatility_scaled: Whether to scale strikes based on historical volatility
        volatility_scalar: Scalar to adjust historical volatility-based strike selection
        train_split: Proportion of total days to use for training
        val_split: Proportion of total days to use for validation
        clean_up: Whether to clean up the data after use (when False, datasets are saved to disk)
        offline: Whether to load saved contracts from disk
        save_dir: Directory to save/load contracts
        verbose: Whether to print verbose output
        dev_mode: Whether to use development mode

    Returns:
        Training, validation, and test contract datasets.

    Raises:
        FileNotFoundError: If offline=True and any of the saved dataset files is missing.
        ValueError: If the splits or date range are invalid, or any generated dataset is empty.
    """
    ctx = Console()

    # Volatility-based selection of strikes (Optional)
    if volatility_scaled:
        if verbose:
            ctx.log(f"Using volatility-scaled strike selection with {volatility_type} type")
        hist_vol = get_train_historical_vol(
            root=root,
            start_date=start_date,
            end_date=end_date,
            interval_min=interval_min,
            volatility_window=train_split,  # Use only the training data to compute historical volatility
            volatility_type=volatility_type,
        )
    else:
        hist_vol = 0.0

    contract_dir = set_contract_dir(
        SCRIPT_DIR=SCRIPT_DIR,
        root=root,
        start_date=start_date,
        end_date=end_date,
        contract_stride=contract_stride,
        interval_min=interval_min,
        right=right,
        target_tte=target_tte,
        tte_tolerance=tte_tolerance,
        moneyness=moneyness,
        strike_band=strike_band,
        volatility_scaled=volatility_scaled,
        volatility_scalar=volatility_scalar,
        hist_vol=hist_vol,
        save_dir=save_dir,
        dev_mode=dev_mode,
    )

    filenames = ("train_contracts.pkl", "val_contracts.pkl", "test_contracts.pkl")

    # Offline loading (if already saved)
    if offline:
        with ctx.status("Loading ContractDataset objects (offline)"):
            paths = [contract_dir / name for name in filenames]
            if not all(path.exists() for path in paths):
                raise FileNotFoundError(f"Missing contract files in {contract_dir}")
            train_contracts, val_contracts, test_contracts = (
                ContractDataset.load(path) for path in paths
            )
        return train_contracts, val_contracts, test_contracts

    # Contiguous, non-overlapping (start_date, end_date) pairs in YYYYMMDD format
    split_ranges = _split_date_ranges(start_date, end_date, train_split, val_split)

    datasets = []
    for label, (split_start, split_end) in zip(
        ("TRAINING", "VALIDATION", "TEST"), split_ranges
    ):
        if verbose:
            ctx.log(f"------------CREATING {label} CONTRACTS------------")
        datasets.append(
            ContractDataset(
                root=root,
                total_start_date=split_start,
                total_end_date=split_end,
                contract_stride=contract_stride,
                interval_min=interval_min,
                right=right,
                target_tte=target_tte,
                tte_tolerance=tte_tolerance,
                moneyness=moneyness,
                strike_band=strike_band,
                hist_vol=hist_vol,
                volatility_scaled=volatility_scaled,
                volatility_scalar=volatility_scalar,
                verbose=verbose,
                save_dir=save_dir,
                contract_dir=contract_dir,
                dev_mode=dev_mode,
            ).generate()
        )
    train_contracts, val_contracts, test_contracts = datasets

    # Check that train_contracts, val_contracts, and test_contracts are nonempty
    if not all(len(dataset) > 0 for dataset in datasets):
        raise ValueError(
            f"One or more contract datasets are empty. "
            f"Number of train contracts: {len(train_contracts)}, "
            f"Number of val contracts: {len(val_contracts)}, "
            f"Number of test contracts: {len(test_contracts)}. "
            f"Try adjusting contract_stride to sample more contracts or "
            f"train_split/val_split ratios for more equal distribution of contracts."
        )

    if not clean_up:
        for dataset, name in zip(datasets, filenames):
            dataset.save(name)

    return train_contracts, val_contracts, test_contracts
if __name__ == "__main__":
    # Smoke test: select a single optimal contract and load its market data.
    contract = Contract.find_optimal(
        root="AAPL",
        start_date="20241107",
        interval_min=1,
        right="C",
        target_tte=30,
        tte_tolerance=(25, 35),
        moneyness="OTM",
        strike_band=0.05,
        volatility_scaled=False,
        verbose=True,
    )
    df = contract.load_data(clean_up=True, offline=False, warning=True, dev_mode=True)
    print(df.head())

    # Smoke test: build train/val/test contract datasets, then save and reload one of them.
    root = "AMZN"
    total_start_date = "20230101"
    total_end_date = "20230601"
    right = "C"
    interval_min = 60
    contract_stride = 5
    target_tte = 30
    tte_tolerance = (15, 45)
    moneyness = "ATM"
    volatility_scaled = True
    volatility_scalar = 0.01
    volatility_type = "period"
    strike_band = 0.05

    train_contract_dataset, val_contract_dataset, test_contract_dataset = (
        get_contract_datasets(
            root=root,
            start_date=total_start_date,
            end_date=total_end_date,
            contract_stride=contract_stride,
            interval_min=interval_min,
            right=right,
            target_tte=target_tte,
            tte_tolerance=tte_tolerance,
            moneyness=moneyness,
            strike_band=strike_band,
            volatility_type=volatility_type,
            volatility_scaled=volatility_scaled,
            volatility_scalar=volatility_scalar,
            train_split=0.4,
            val_split=0.3,
            clean_up=True,
            offline=False,
            save_dir=None,
            verbose=True,
        )
    )

    train_contract_dataset.save("train_contracts.pkl")

    # Console is already imported at module level; no need to re-import it here.
    ctx = Console()
    ctx.log(f"Contract dir: {train_contract_dataset.contract_dir}")
    new_contract_dataset = ContractDataset.load(
        train_contract_dataset.contract_dir / "train_contracts.pkl"
    )
    ctx.log(f"Reloaded {len(new_contract_dataset)} training contracts")