khajavi8056 commited on
Commit
306dc69
·
verified ·
1 Parent(s): c76d4e4

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +177 -0
app.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import yfinance as yf
3
+ import pandas as pd
4
+ import numpy as np
5
+ import matplotlib.pyplot as plt
6
+ from sklearn.model_selection import train_test_split
7
+ from tensortrade.env import TradingEnvironment
8
+ from tensortrade.data import DataFeed, Stream
9
+ from tensortrade.strategies import StableBaselinesTradingStrategy
10
+ from tensortrade.agents import PPOAgent, A2CAgent, DQNAgent
11
+ from tensortrade.instruments import USD, BTC
12
+ from tensortrade.wallets import Wallet, Portfolio
13
+ import json
14
+
15
+ # تنظیمات Streamlit
16
+ st.set_page_config(page_title="TensorTrade Pro", layout="wide")
17
+ st.title("TensorTrade Advanced Trading System")
18
+
19
+ # ======================== تنظیمات سایدبار ========================
20
+ with st.sidebar:
21
+ st.header("📊 تنظیمات داده")
22
+ data_source = st.radio("منبع داده:", ["Yahoo Finance", "آپلود فایل CSV"])
23
+
24
+ if data_source == "Yahoo Finance":
25
+ ticker = st.text_input("نماد (مثال: BTC-USD)", "BTC-USD")
26
+ start_date = st.date_input("تاریخ شروع", pd.to_datetime("2020-01-01"))
27
+ end_date = st.date_input("تاریخ پایان", pd.to_datetime("2023-01-01"))
28
+ else:
29
+ uploaded_file = st.file_uploader("فایل CSV را آپلود کنید", type="csv")
30
+ date_col = st.text_input("ستون تاریخ (Date)", "Date")
31
+ price_col = st.text_input("ستون قیمت پایانی (Close)", "Close")
32
+
33
+ st.header("⚙️ تقسیم داده")
34
+ train_size = st.slider("درصد داده آموزش", 50, 90, 70)
35
+ val_size = st.slider("درصد داده اعتبارسنجی", 5, 30, 15)
36
+ test_size = 100 - train_size - val_size
37
+ st.write(f"داده تست: {test_size}%")
38
+
39
+ st.header("🧠 پارامترهای مدل")
40
+ model_name = st.selectbox("الگوریتم", ["PPO", "A2C", "DQN"])
41
+
42
+ st.subheader("ساختار شبکه عصبی")
43
+ num_layers = st.slider("تعداد لایه‌های پنهان", 1, 5, 2)
44
+ hidden_units = st.text_input("تعداد نورون‌ها در هر لایه (جدا با کاما)", "64, 32")
45
+ activation = st.selectbox("تابع فعالسازی", ["relu", "tanh", "sigmoid"])
46
+ learning_rate = st.number_input("نرخ یادگیری", 0.0001, 0.1, 0.001, step=0.0001)
47
+
48
+ st.header("📉 پارامترهای بکتست")
49
+ initial_balance = st.number_input("موجودی اولیه (USD)", 1000, 1000000, 10000)
50
+
51
+ # ======================== پردازش داده ========================
52
+ @st.cache_data
53
+ def load_data():
54
+ if data_source == "Yahoo Finance":
55
+ data = yf.download(ticker, start=start_date, end=end_date)
56
+ data = data[['Open', 'High', 'Low', 'Close', 'Volume']]
57
+ else:
58
+ data = pd.read_csv(uploaded_file)
59
+ data[date_col] = pd.to_datetime(data[date_col])
60
+ data.set_index(date_col, inplace=True)
61
+ data = data[[price_col]]
62
+ data.columns = ['Close']
63
+
64
+ train_data, temp_data = train_test_split(data, train_size=train_size/100, shuffle=False)
65
+ val_data, test_data = train_test_split(temp_data, test_size=test_size/(test_size+val_size), shuffle=False)
66
+ return train_data, val_data, test_data
67
+
68
+ try:
69
+ train_data, val_data, test_data = load_data()
70
+ st.success("✅ داده‌ها با موفقیت بارگذاری شدند!")
71
+
72
+ st.subheader("📈 نمودار قیمت")
73
+ fig, ax = plt.subplots()
74
+ ax.plot(train_data['Close'], label='آموزش')
75
+ ax.plot(val_data['Close'], label='اعتبارسنجی')
76
+ ax.plot(test_data['Close'], label='تست')
77
+ ax.legend()
78
+ st.pyplot(fig)
79
+
80
+ # ======================== ساخت محیط معاملاتی ========================
81
+ def create_environment(data):
82
+ price = Stream('close', data['Close'].values)
83
+ feed = DataFeed([price])
84
+
85
+ portfolio = Portfolio(
86
+ USD,
87
+ wallets=[
88
+ Wallet(exchange=None, instrument=USD, balance=initial_balance),
89
+ Wallet(exchange=None, instrument=BTC, balance=0)
90
+ ]
91
+ )
92
+
93
+ return TradingEnvironment(
94
+ portfolio=portfolio,
95
+ feed=feed,
96
+ window_size=20,
97
+ action_scheme='discrete',
98
+ reward_scheme='risk-adjusted'
99
+ )
100
+
101
+ env_train = create_environment(train_data)
102
+ env_val = create_environment(val_data)
103
+ env_test = create_environment(test_data)
104
+
105
+ # ======================== آموزش مدل ========================
106
+ if st.button("🚀 شروع آموزش و بکتست"):
107
+ net_arch = [int(x.strip()) for x in hidden_units.split(',')]
108
+
109
+ agent_class = {
110
+ "PPO": PPOAgent,
111
+ "A2C": A2CAgent,
112
+ "DQN": DQNAgent
113
+ }[model_name]
114
+
115
+ strategy = StableBaselinesTradingStrategy(
116
+ environment=env_train,
117
+ agent_class=agent_class,
118
+ agent_params={
119
+ 'policy': 'MlpPolicy',
120
+ 'policy_kwargs': {
121
+ 'net_arch': net_arch,
122
+ 'activation_fn': eval(f"torch.nn.{activation}")
123
+ },
124
+ 'learning_rate': learning_rate
125
+ }
126
+ )
127
+
128
+ with st.spinner("در حال آموزش مدل..."):
129
+ strategy.train(steps=5000)
130
+ strategy.save("trained_model")
131
+
132
+ with st.spinner("در حال اعتبارسنجی..."):
133
+ val_performance = strategy.run(env_val)
134
+
135
+ with st.spinner("در حال اجرای بکتست نهایی..."):
136
+ test_performance = strategy.run(env_test)
137
+
138
+ st.subheader("📊 نتایج نهایی")
139
+ col1, col2, col3 = st.columns(3)
140
+ with col1:
141
+ st.metric("سود/زیان (Train)", f"{env_train.portfolio.performance.profit_loss:.2f}%")
142
+ with col2:
143
+ st.metric("سود/زیان (Val)", f"{val_performance['profit_loss']:.2f}%")
144
+ with col3:
145
+ st.metric("سود/زیان (Test)", f"{test_performance['profit_loss']:.2f}%")
146
+
147
+ st.subheader("📈 نمودار ارزش پرتفوی")
148
+ fig, ax = plt.subplots()
149
+ ax.plot(env_train.portfolio.performance.net_worth, label='آموزش')
150
+ ax.plot(val_performance['net_worth'], label='اعتبارسنجی')
151
+ ax.plot(test_performance['net_worth'], label='تست')
152
+ ax.legend()
153
+ st.pyplot(fig)
154
+
155
+ st.subheader("💾 ذخیره مدل")
156
+ with open("trained_model.zip", "rb") as f:
157
+ st.download_button(
158
+ label="دانلود مدل آموزش‌دیده",
159
+ data=f,
160
+ file_name="trading_model.zip",
161
+ mime="application/zip"
162
+ )
163
+
164
+ config = {
165
+ "hidden_layers": net_arch,
166
+ "activation": activation,
167
+ "learning_rate": learning_rate
168
+ }
169
+ st.download_button(
170
+ label="دانلود تنظیمات مدل",
171
+ data=json.dumps(config),
172
+ file_name="model_config.json",
173
+ mime="application/json"
174
+ )
175
+
176
+ except Exception as e:
177
+ st.error(f"❌ خطا: {str(e)}")