Spaces:
Build error
Build error
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import yfinance as yf
|
3 |
+
import pandas as pd
|
4 |
+
import numpy as np
|
5 |
+
import matplotlib.pyplot as plt
|
6 |
+
from sklearn.model_selection import train_test_split
|
7 |
+
from tensortrade.env import TradingEnvironment
|
8 |
+
from tensortrade.data import DataFeed, Stream
|
9 |
+
from tensortrade.strategies import StableBaselinesTradingStrategy
|
10 |
+
from tensortrade.agents import PPOAgent, A2CAgent, DQNAgent
|
11 |
+
from tensortrade.instruments import USD, BTC
|
12 |
+
from tensortrade.wallets import Wallet, Portfolio
|
13 |
+
import json
|
14 |
+
|
15 |
+
# تنظیمات Streamlit
|
16 |
+
st.set_page_config(page_title="TensorTrade Pro", layout="wide")
|
17 |
+
st.title("TensorTrade Advanced Trading System")
|
18 |
+
|
19 |
+
# ======================== تنظیمات سایدبار ========================
|
20 |
+
with st.sidebar:
|
21 |
+
st.header("📊 تنظیمات داده")
|
22 |
+
data_source = st.radio("منبع داده:", ["Yahoo Finance", "آپلود فایل CSV"])
|
23 |
+
|
24 |
+
if data_source == "Yahoo Finance":
|
25 |
+
ticker = st.text_input("نماد (مثال: BTC-USD)", "BTC-USD")
|
26 |
+
start_date = st.date_input("تاریخ شروع", pd.to_datetime("2020-01-01"))
|
27 |
+
end_date = st.date_input("تاریخ پایان", pd.to_datetime("2023-01-01"))
|
28 |
+
else:
|
29 |
+
uploaded_file = st.file_uploader("فایل CSV را آپلود کنید", type="csv")
|
30 |
+
date_col = st.text_input("ستون تاریخ (Date)", "Date")
|
31 |
+
price_col = st.text_input("ستون قیمت پایانی (Close)", "Close")
|
32 |
+
|
33 |
+
st.header("⚙️ تقسیم داده")
|
34 |
+
train_size = st.slider("درصد داده آموزش", 50, 90, 70)
|
35 |
+
val_size = st.slider("درصد داده اعتبارسنجی", 5, 30, 15)
|
36 |
+
test_size = 100 - train_size - val_size
|
37 |
+
st.write(f"داده تست: {test_size}%")
|
38 |
+
|
39 |
+
st.header("🧠 پارامترهای مدل")
|
40 |
+
model_name = st.selectbox("الگوریتم", ["PPO", "A2C", "DQN"])
|
41 |
+
|
42 |
+
st.subheader("ساختار شبکه عصبی")
|
43 |
+
num_layers = st.slider("تعداد لایههای پنهان", 1, 5, 2)
|
44 |
+
hidden_units = st.text_input("تعداد نورونها در هر لایه (جدا با کاما)", "64, 32")
|
45 |
+
activation = st.selectbox("تابع فعالسازی", ["relu", "tanh", "sigmoid"])
|
46 |
+
learning_rate = st.number_input("نرخ یادگیری", 0.0001, 0.1, 0.001, step=0.0001)
|
47 |
+
|
48 |
+
st.header("📉 پارامترهای بکتست")
|
49 |
+
initial_balance = st.number_input("موجودی اولیه (USD)", 1000, 1000000, 10000)
|
50 |
+
|
51 |
+
# ======================== پردازش داده ========================
|
52 |
+
@st.cache_data
|
53 |
+
def load_data():
|
54 |
+
if data_source == "Yahoo Finance":
|
55 |
+
data = yf.download(ticker, start=start_date, end=end_date)
|
56 |
+
data = data[['Open', 'High', 'Low', 'Close', 'Volume']]
|
57 |
+
else:
|
58 |
+
data = pd.read_csv(uploaded_file)
|
59 |
+
data[date_col] = pd.to_datetime(data[date_col])
|
60 |
+
data.set_index(date_col, inplace=True)
|
61 |
+
data = data[[price_col]]
|
62 |
+
data.columns = ['Close']
|
63 |
+
|
64 |
+
train_data, temp_data = train_test_split(data, train_size=train_size/100, shuffle=False)
|
65 |
+
val_data, test_data = train_test_split(temp_data, test_size=test_size/(test_size+val_size), shuffle=False)
|
66 |
+
return train_data, val_data, test_data
|
67 |
+
|
68 |
+
try:
|
69 |
+
train_data, val_data, test_data = load_data()
|
70 |
+
st.success("✅ دادهها با موفقیت بارگذاری شدند!")
|
71 |
+
|
72 |
+
st.subheader("📈 نمودار قیمت")
|
73 |
+
fig, ax = plt.subplots()
|
74 |
+
ax.plot(train_data['Close'], label='آموزش')
|
75 |
+
ax.plot(val_data['Close'], label='اعتبارسنجی')
|
76 |
+
ax.plot(test_data['Close'], label='تست')
|
77 |
+
ax.legend()
|
78 |
+
st.pyplot(fig)
|
79 |
+
|
80 |
+
# ======================== ساخت محیط معاملاتی ========================
|
81 |
+
def create_environment(data):
|
82 |
+
price = Stream('close', data['Close'].values)
|
83 |
+
feed = DataFeed([price])
|
84 |
+
|
85 |
+
portfolio = Portfolio(
|
86 |
+
USD,
|
87 |
+
wallets=[
|
88 |
+
Wallet(exchange=None, instrument=USD, balance=initial_balance),
|
89 |
+
Wallet(exchange=None, instrument=BTC, balance=0)
|
90 |
+
]
|
91 |
+
)
|
92 |
+
|
93 |
+
return TradingEnvironment(
|
94 |
+
portfolio=portfolio,
|
95 |
+
feed=feed,
|
96 |
+
window_size=20,
|
97 |
+
action_scheme='discrete',
|
98 |
+
reward_scheme='risk-adjusted'
|
99 |
+
)
|
100 |
+
|
101 |
+
env_train = create_environment(train_data)
|
102 |
+
env_val = create_environment(val_data)
|
103 |
+
env_test = create_environment(test_data)
|
104 |
+
|
105 |
+
# ======================== آموزش مدل ========================
|
106 |
+
if st.button("🚀 شروع آموزش و بکتست"):
|
107 |
+
net_arch = [int(x.strip()) for x in hidden_units.split(',')]
|
108 |
+
|
109 |
+
agent_class = {
|
110 |
+
"PPO": PPOAgent,
|
111 |
+
"A2C": A2CAgent,
|
112 |
+
"DQN": DQNAgent
|
113 |
+
}[model_name]
|
114 |
+
|
115 |
+
strategy = StableBaselinesTradingStrategy(
|
116 |
+
environment=env_train,
|
117 |
+
agent_class=agent_class,
|
118 |
+
agent_params={
|
119 |
+
'policy': 'MlpPolicy',
|
120 |
+
'policy_kwargs': {
|
121 |
+
'net_arch': net_arch,
|
122 |
+
'activation_fn': eval(f"torch.nn.{activation}")
|
123 |
+
},
|
124 |
+
'learning_rate': learning_rate
|
125 |
+
}
|
126 |
+
)
|
127 |
+
|
128 |
+
with st.spinner("در حال آموزش مدل..."):
|
129 |
+
strategy.train(steps=5000)
|
130 |
+
strategy.save("trained_model")
|
131 |
+
|
132 |
+
with st.spinner("در حال اعتبارسنجی..."):
|
133 |
+
val_performance = strategy.run(env_val)
|
134 |
+
|
135 |
+
with st.spinner("در حال اجرای بکتست نهایی..."):
|
136 |
+
test_performance = strategy.run(env_test)
|
137 |
+
|
138 |
+
st.subheader("📊 نتایج نهایی")
|
139 |
+
col1, col2, col3 = st.columns(3)
|
140 |
+
with col1:
|
141 |
+
st.metric("سود/زیان (Train)", f"{env_train.portfolio.performance.profit_loss:.2f}%")
|
142 |
+
with col2:
|
143 |
+
st.metric("سود/زیان (Val)", f"{val_performance['profit_loss']:.2f}%")
|
144 |
+
with col3:
|
145 |
+
st.metric("سود/زیان (Test)", f"{test_performance['profit_loss']:.2f}%")
|
146 |
+
|
147 |
+
st.subheader("📈 نمودار ارزش پرتفوی")
|
148 |
+
fig, ax = plt.subplots()
|
149 |
+
ax.plot(env_train.portfolio.performance.net_worth, label='آموزش')
|
150 |
+
ax.plot(val_performance['net_worth'], label='اعتبارسنجی')
|
151 |
+
ax.plot(test_performance['net_worth'], label='تست')
|
152 |
+
ax.legend()
|
153 |
+
st.pyplot(fig)
|
154 |
+
|
155 |
+
st.subheader("💾 ذخیره مدل")
|
156 |
+
with open("trained_model.zip", "rb") as f:
|
157 |
+
st.download_button(
|
158 |
+
label="دانلود مدل آموزشدیده",
|
159 |
+
data=f,
|
160 |
+
file_name="trading_model.zip",
|
161 |
+
mime="application/zip"
|
162 |
+
)
|
163 |
+
|
164 |
+
config = {
|
165 |
+
"hidden_layers": net_arch,
|
166 |
+
"activation": activation,
|
167 |
+
"learning_rate": learning_rate
|
168 |
+
}
|
169 |
+
st.download_button(
|
170 |
+
label="دانلود تنظیمات مدل",
|
171 |
+
data=json.dumps(config),
|
172 |
+
file_name="model_config.json",
|
173 |
+
mime="application/json"
|
174 |
+
)
|
175 |
+
|
176 |
+
except Exception as e:
|
177 |
+
st.error(f"❌ خطا: {str(e)}")
|