Spaces:
Build error
Build error
| import streamlit as st | |
| import yfinance as yf | |
| import pandas as pd | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| from sklearn.model_selection import train_test_split | |
| from tensortrade.env import TradingEnvironment | |
| from tensortrade.data import DataFeed, Stream | |
| from tensortrade.strategies import StableBaselinesTradingStrategy | |
| from tensortrade.agents import PPOAgent, A2CAgent, DQNAgent | |
| from tensortrade.instruments import USD, BTC | |
| from tensortrade.wallets import Wallet, Portfolio | |
| import json | |
| # تنظیمات Streamlit | |
| st.set_page_config(page_title="TensorTrade Pro", layout="wide") | |
| st.title("TensorTrade Advanced Trading System") | |
| # ======================== تنظیمات سایدبار ======================== | |
| with st.sidebar: | |
| st.header("📊 تنظیمات داده") | |
| data_source = st.radio("منبع داده:", ["Yahoo Finance", "آپلود فایل CSV"]) | |
| if data_source == "Yahoo Finance": | |
| ticker = st.text_input("نماد (مثال: BTC-USD)", "BTC-USD") | |
| start_date = st.date_input("تاریخ شروع", pd.to_datetime("2020-01-01")) | |
| end_date = st.date_input("تاریخ پایان", pd.to_datetime("2023-01-01")) | |
| else: | |
| uploaded_file = st.file_uploader("فایل CSV را آپلود کنید", type="csv") | |
| date_col = st.text_input("ستون تاریخ (Date)", "Date") | |
| price_col = st.text_input("ستون قیمت پایانی (Close)", "Close") | |
| st.header("⚙️ تقسیم داده") | |
| train_size = st.slider("درصد داده آموزش", 50, 90, 70) | |
| val_size = st.slider("درصد داده اعتبارسنجی", 5, 30, 15) | |
| test_size = 100 - train_size - val_size | |
| st.write(f"داده تست: {test_size}%") | |
| st.header("🧠 پارامترهای مدل") | |
| model_name = st.selectbox("الگوریتم", ["PPO", "A2C", "DQN"]) | |
| st.subheader("ساختار شبکه عصبی") | |
| num_layers = st.slider("تعداد لایههای پنهان", 1, 5, 2) | |
| hidden_units = st.text_input("تعداد نورونها در هر لایه (جدا با کاما)", "64, 32") | |
| activation = st.selectbox("تابع فعالسازی", ["relu", "tanh", "sigmoid"]) | |
| learning_rate = st.number_input("نرخ یادگیری", 0.0001, 0.1, 0.001, step=0.0001) | |
| st.header("📉 پارامترهای بکتست") | |
| initial_balance = st.number_input("موجودی اولیه (USD)", 1000, 1000000, 10000) | |
| # ======================== پردازش داده ======================== | |
| def load_data(): | |
| if data_source == "Yahoo Finance": | |
| data = yf.download(ticker, start=start_date, end=end_date) | |
| data = data[['Open', 'High', 'Low', 'Close', 'Volume']] | |
| else: | |
| data = pd.read_csv(uploaded_file) | |
| data[date_col] = pd.to_datetime(data[date_col]) | |
| data.set_index(date_col, inplace=True) | |
| data = data[[price_col]] | |
| data.columns = ['Close'] | |
| train_data, temp_data = train_test_split(data, train_size=train_size/100, shuffle=False) | |
| val_data, test_data = train_test_split(temp_data, test_size=test_size/(test_size+val_size), shuffle=False) | |
| return train_data, val_data, test_data | |
| try: | |
| train_data, val_data, test_data = load_data() | |
| st.success("✅ دادهها با موفقیت بارگذاری شدند!") | |
| st.subheader("📈 نمودار قیمت") | |
| fig, ax = plt.subplots() | |
| ax.plot(train_data['Close'], label='آموزش') | |
| ax.plot(val_data['Close'], label='اعتبارسنجی') | |
| ax.plot(test_data['Close'], label='تست') | |
| ax.legend() | |
| st.pyplot(fig) | |
| # ======================== ساخت محیط معاملاتی ======================== | |
| def create_environment(data): | |
| price = Stream('close', data['Close'].values) | |
| feed = DataFeed([price]) | |
| portfolio = Portfolio( | |
| USD, | |
| wallets=[ | |
| Wallet(exchange=None, instrument=USD, balance=initial_balance), | |
| Wallet(exchange=None, instrument=BTC, balance=0) | |
| ] | |
| ) | |
| return TradingEnvironment( | |
| portfolio=portfolio, | |
| feed=feed, | |
| window_size=20, | |
| action_scheme='discrete', | |
| reward_scheme='risk-adjusted' | |
| ) | |
| env_train = create_environment(train_data) | |
| env_val = create_environment(val_data) | |
| env_test = create_environment(test_data) | |
| # ======================== آموزش مدل ======================== | |
| if st.button("🚀 شروع آموزش و بکتست"): | |
| net_arch = [int(x.strip()) for x in hidden_units.split(',')] | |
| agent_class = { | |
| "PPO": PPOAgent, | |
| "A2C": A2CAgent, | |
| "DQN": DQNAgent | |
| }[model_name] | |
| strategy = StableBaselinesTradingStrategy( | |
| environment=env_train, | |
| agent_class=agent_class, | |
| agent_params={ | |
| 'policy': 'MlpPolicy', | |
| 'policy_kwargs': { | |
| 'net_arch': net_arch, | |
| 'activation_fn': eval(f"torch.nn.{activation}") | |
| }, | |
| 'learning_rate': learning_rate | |
| } | |
| ) | |
| with st.spinner("در حال آموزش مدل..."): | |
| strategy.train(steps=5000) | |
| strategy.save("trained_model") | |
| with st.spinner("در حال اعتبارسنجی..."): | |
| val_performance = strategy.run(env_val) | |
| with st.spinner("در حال اجرای بکتست نهایی..."): | |
| test_performance = strategy.run(env_test) | |
| st.subheader("📊 نتایج نهایی") | |
| col1, col2, col3 = st.columns(3) | |
| with col1: | |
| st.metric("سود/زیان (Train)", f"{env_train.portfolio.performance.profit_loss:.2f}%") | |
| with col2: | |
| st.metric("سود/زیان (Val)", f"{val_performance['profit_loss']:.2f}%") | |
| with col3: | |
| st.metric("سود/زیان (Test)", f"{test_performance['profit_loss']:.2f}%") | |
| st.subheader("📈 نمودار ارزش پرتفوی") | |
| fig, ax = plt.subplots() | |
| ax.plot(env_train.portfolio.performance.net_worth, label='آموزش') | |
| ax.plot(val_performance['net_worth'], label='اعتبارسنجی') | |
| ax.plot(test_performance['net_worth'], label='تست') | |
| ax.legend() | |
| st.pyplot(fig) | |
| st.subheader("💾 ذخیره مدل") | |
| with open("trained_model.zip", "rb") as f: | |
| st.download_button( | |
| label="دانلود مدل آموزشدیده", | |
| data=f, | |
| file_name="trading_model.zip", | |
| mime="application/zip" | |
| ) | |
| config = { | |
| "hidden_layers": net_arch, | |
| "activation": activation, | |
| "learning_rate": learning_rate | |
| } | |
| st.download_button( | |
| label="دانلود تنظیمات مدل", | |
| data=json.dumps(config), | |
| file_name="model_config.json", | |
| mime="application/json" | |
| ) | |
| except Exception as e: | |
| st.error(f"❌ خطا: {str(e)}") |