Spaces:
Build error
Build error
import streamlit as st | |
import yfinance as yf | |
import pandas as pd | |
import numpy as np | |
import matplotlib.pyplot as plt | |
from sklearn.model_selection import train_test_split | |
from tensortrade.env import TradingEnvironment | |
from tensortrade.data import DataFeed, Stream | |
from tensortrade.strategies import StableBaselinesTradingStrategy | |
from tensortrade.agents import PPOAgent, A2CAgent, DQNAgent | |
from tensortrade.instruments import USD, BTC | |
from tensortrade.wallets import Wallet, Portfolio | |
import json | |
# تنظیمات Streamlit | |
st.set_page_config(page_title="TensorTrade Pro", layout="wide") | |
st.title("TensorTrade Advanced Trading System") | |
# ======================== تنظیمات سایدبار ======================== | |
with st.sidebar: | |
st.header("📊 تنظیمات داده") | |
data_source = st.radio("منبع داده:", ["Yahoo Finance", "آپلود فایل CSV"]) | |
if data_source == "Yahoo Finance": | |
ticker = st.text_input("نماد (مثال: BTC-USD)", "BTC-USD") | |
start_date = st.date_input("تاریخ شروع", pd.to_datetime("2020-01-01")) | |
end_date = st.date_input("تاریخ پایان", pd.to_datetime("2023-01-01")) | |
else: | |
uploaded_file = st.file_uploader("فایل CSV را آپلود کنید", type="csv") | |
date_col = st.text_input("ستون تاریخ (Date)", "Date") | |
price_col = st.text_input("ستون قیمت پایانی (Close)", "Close") | |
st.header("⚙️ تقسیم داده") | |
train_size = st.slider("درصد داده آموزش", 50, 90, 70) | |
val_size = st.slider("درصد داده اعتبارسنجی", 5, 30, 15) | |
test_size = 100 - train_size - val_size | |
st.write(f"داده تست: {test_size}%") | |
st.header("🧠 پارامترهای مدل") | |
model_name = st.selectbox("الگوریتم", ["PPO", "A2C", "DQN"]) | |
st.subheader("ساختار شبکه عصبی") | |
num_layers = st.slider("تعداد لایههای پنهان", 1, 5, 2) | |
hidden_units = st.text_input("تعداد نورونها در هر لایه (جدا با کاما)", "64, 32") | |
activation = st.selectbox("تابع فعالسازی", ["relu", "tanh", "sigmoid"]) | |
learning_rate = st.number_input("نرخ یادگیری", 0.0001, 0.1, 0.001, step=0.0001) | |
st.header("📉 پارامترهای بکتست") | |
initial_balance = st.number_input("موجودی اولیه (USD)", 1000, 1000000, 10000) | |
# ======================== پردازش داده ======================== | |
def load_data(): | |
if data_source == "Yahoo Finance": | |
data = yf.download(ticker, start=start_date, end=end_date) | |
data = data[['Open', 'High', 'Low', 'Close', 'Volume']] | |
else: | |
data = pd.read_csv(uploaded_file) | |
data[date_col] = pd.to_datetime(data[date_col]) | |
data.set_index(date_col, inplace=True) | |
data = data[[price_col]] | |
data.columns = ['Close'] | |
train_data, temp_data = train_test_split(data, train_size=train_size/100, shuffle=False) | |
val_data, test_data = train_test_split(temp_data, test_size=test_size/(test_size+val_size), shuffle=False) | |
return train_data, val_data, test_data | |
try: | |
train_data, val_data, test_data = load_data() | |
st.success("✅ دادهها با موفقیت بارگذاری شدند!") | |
st.subheader("📈 نمودار قیمت") | |
fig, ax = plt.subplots() | |
ax.plot(train_data['Close'], label='آموزش') | |
ax.plot(val_data['Close'], label='اعتبارسنجی') | |
ax.plot(test_data['Close'], label='تست') | |
ax.legend() | |
st.pyplot(fig) | |
# ======================== ساخت محیط معاملاتی ======================== | |
def create_environment(data): | |
price = Stream('close', data['Close'].values) | |
feed = DataFeed([price]) | |
portfolio = Portfolio( | |
USD, | |
wallets=[ | |
Wallet(exchange=None, instrument=USD, balance=initial_balance), | |
Wallet(exchange=None, instrument=BTC, balance=0) | |
] | |
) | |
return TradingEnvironment( | |
portfolio=portfolio, | |
feed=feed, | |
window_size=20, | |
action_scheme='discrete', | |
reward_scheme='risk-adjusted' | |
) | |
env_train = create_environment(train_data) | |
env_val = create_environment(val_data) | |
env_test = create_environment(test_data) | |
# ======================== آموزش مدل ======================== | |
if st.button("🚀 شروع آموزش و بکتست"): | |
net_arch = [int(x.strip()) for x in hidden_units.split(',')] | |
agent_class = { | |
"PPO": PPOAgent, | |
"A2C": A2CAgent, | |
"DQN": DQNAgent | |
}[model_name] | |
strategy = StableBaselinesTradingStrategy( | |
environment=env_train, | |
agent_class=agent_class, | |
agent_params={ | |
'policy': 'MlpPolicy', | |
'policy_kwargs': { | |
'net_arch': net_arch, | |
'activation_fn': eval(f"torch.nn.{activation}") | |
}, | |
'learning_rate': learning_rate | |
} | |
) | |
with st.spinner("در حال آموزش مدل..."): | |
strategy.train(steps=5000) | |
strategy.save("trained_model") | |
with st.spinner("در حال اعتبارسنجی..."): | |
val_performance = strategy.run(env_val) | |
with st.spinner("در حال اجرای بکتست نهایی..."): | |
test_performance = strategy.run(env_test) | |
st.subheader("📊 نتایج نهایی") | |
col1, col2, col3 = st.columns(3) | |
with col1: | |
st.metric("سود/زیان (Train)", f"{env_train.portfolio.performance.profit_loss:.2f}%") | |
with col2: | |
st.metric("سود/زیان (Val)", f"{val_performance['profit_loss']:.2f}%") | |
with col3: | |
st.metric("سود/زیان (Test)", f"{test_performance['profit_loss']:.2f}%") | |
st.subheader("📈 نمودار ارزش پرتفوی") | |
fig, ax = plt.subplots() | |
ax.plot(env_train.portfolio.performance.net_worth, label='آموزش') | |
ax.plot(val_performance['net_worth'], label='اعتبارسنجی') | |
ax.plot(test_performance['net_worth'], label='تست') | |
ax.legend() | |
st.pyplot(fig) | |
st.subheader("💾 ذخیره مدل") | |
with open("trained_model.zip", "rb") as f: | |
st.download_button( | |
label="دانلود مدل آموزشدیده", | |
data=f, | |
file_name="trading_model.zip", | |
mime="application/zip" | |
) | |
config = { | |
"hidden_layers": net_arch, | |
"activation": activation, | |
"learning_rate": learning_rate | |
} | |
st.download_button( | |
label="دانلود تنظیمات مدل", | |
data=json.dumps(config), | |
file_name="model_config.json", | |
mime="application/json" | |
) | |
except Exception as e: | |
st.error(f"❌ خطا: {str(e)}") |