# Hugging Face Spaces app (the Space's status was "Sleeping" when this copy was captured).
import gc
import logging
import os
import time
import warnings
from datetime import datetime, timedelta
from typing import Optional, Dict, Any, Tuple

import gradio as gr
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import torch
import yfinance as yf
# Suppress every Python warning for the lifetime of the process.
warnings.filterwarnings('ignore')

# Environment optimizations for Hugging Face Spaces
os.environ.update({
    'TOKENIZERS_PARALLELISM': 'false',
    'HF_HUB_DISABLE_PROGRESS_BARS': '1',
    'HF_HUB_DISABLE_TELEMETRY': '1',
    'TRANSFORMERS_VERBOSITY': 'error',
})

# Release any cached CUDA memory when a GPU is present.
if torch.cuda.is_available():
    torch.cuda.empty_cache()

# Cap PyTorch at four CPU threads (fewer when the machine has fewer cores;
# falls back to one when the core count cannot be determined).
torch.set_num_threads(min(4, os.cpu_count() or 1))
class FastAIStockAnalyzer:
    """Optimized AI Stock Analyzer for Gradio.

    Downloads daily price history through yfinance and forecasts the next
    ``prediction_length`` steps from the last ``context_length`` values using
    a pretrained Chronos or Moirai model.  Loaded models are stored in
    ``model_cache`` so each one is built at most once per analyzer instance.

    The fetch, load and predict methods never raise: on failure they log the
    exception and return ``None`` (or a tuple of ``None``).
    """

    _logger = logging.getLogger(__name__)

    def __init__(self, context_length: int = 32, prediction_length: int = 7):
        """Create an analyzer.

        Args:
            context_length: Number of trailing observations given to a model.
            prediction_length: Number of future steps to forecast.

        Raises:
            ValueError: If either length is smaller than 1.
        """
        if context_length < 1 or prediction_length < 1:
            raise ValueError("context_length and prediction_length must be >= 1")
        self.context_length = context_length
        self.prediction_length = prediction_length
        self.device = "cpu"
        self.model_cache = {}

    def fetch_stock_data(self, symbol: str, period: str = "6mo") -> Tuple[Optional[pd.DataFrame], Optional[Dict]]:
        """Fetch daily price history and basic company info for ``symbol``.

        Args:
            symbol: Ticker symbol passed to ``yf.Ticker``.
            period: History window passed to ``Ticker.history``.

        Returns:
            ``(data, info)`` where ``info`` has the keys ``longName``,
            ``sector`` and ``marketCap``; ``(None, None)`` when the download
            fails or returns no rows.
        """
        try:
            ticker = yf.Ticker(symbol)
            data = ticker.history(period=period, interval="1d",
                                  actions=False, auto_adjust=True,
                                  back_adjust=False, repair=False)
        except Exception:
            self._logger.exception("Could not fetch price history for %r", symbol)
            return None, None

        if data.empty:
            return None, None

        # Company metadata is optional: start from the fallbacks and overwrite
        # whatever the lookup provides.
        info = {'longName': symbol, 'sector': 'Unknown', 'marketCap': 0}
        try:
            raw_info = ticker.info  # read once instead of once per key
            for key, fallback in list(info.items()):
                value = raw_info.get(key, fallback)
                info[key] = fallback if value is None else value
        except Exception:
            self._logger.warning("Company info unavailable for %r; using defaults",
                                 symbol, exc_info=True)
            info = {'longName': symbol, 'sector': 'Unknown', 'marketCap': 0}
        return data, info

    def load_chronos_tiny(self) -> Tuple[Optional[Any], Optional[str]]:
        """Load the Chronos model, reusing the cached pipeline when present.

        Returns:
            ``(pipeline, "chronos")`` on success, ``(None, None)`` on failure.
        """
        model_key = "chronos_tiny"
        cached = self.model_cache.get(model_key)
        if cached is not None:
            return cached, "chronos"
        try:
            # Imported lazily so the app can start without the package.
            from chronos import ChronosPipeline
            pipeline = ChronosPipeline.from_pretrained(
                "amazon/chronos-t5-tiny",
                device_map=self.device,
                torch_dtype=torch.float32,
                low_cpu_mem_usage=True,
                trust_remote_code=True
            )
        except Exception:
            self._logger.exception("Failed to load the Chronos model")
            return None, None
        self.model_cache[model_key] = pipeline
        return pipeline, "chronos"

    def load_moirai_small(self) -> Tuple[Optional[Any], Optional[str]]:
        """Load the Moirai model, reusing the cached forecaster when present.

        Returns:
            ``(model, "moirai")`` on success, ``(None, None)`` on failure.
        """
        model_key = "moirai_small"
        cached = self.model_cache.get(model_key)
        if cached is not None:
            return cached, "moirai"
        try:
            # Imported lazily so the app can start without the package.
            from uni2ts.model.moirai import MoiraiForecast, MoiraiModule
            # NOTE(review): confirm MoiraiModule.from_pretrained accepts these
            # transformers-style keyword arguments; a rejection is only
            # visible through the log entry below.
            module = MoiraiModule.from_pretrained(
                "Salesforce/moirai-1.0-R-small",
                low_cpu_mem_usage=True,
                device_map=self.device,
                torch_dtype=torch.float32,
                trust_remote_code=True
            )
            model = MoiraiForecast(
                module=module,
                prediction_length=self.prediction_length,
                context_length=self.context_length,
                patch_size="auto",
                num_samples=15,
                target_dim=1,
                feat_dynamic_real_dim=0,
                past_feat_dynamic_real_dim=0
            )
        except Exception:
            self._logger.exception("Failed to load the Moirai model")
            return None, None
        self.model_cache[model_key] = model
        return model, "moirai"

    def predict_chronos_fast(self, pipeline: Any, data: np.ndarray) -> Optional[Dict]:
        """Forecast with a Chronos pipeline.

        Args:
            pipeline: Object returned by :meth:`load_chronos_tiny`.
            data: 1-D series of values; only the last ``context_length``
                entries are used.

        Returns:
            Dict with per-step arrays under ``mean``, ``q10``, ``q90`` and
            ``std``, or ``None`` on failure.  The ``mean`` entry holds the
            median across samples (the key name is kept for callers).
        """
        try:
            context_data = np.asarray(data)[-self.context_length:]
            if context_data.size == 0:
                return None
            # Add a leading batch dimension of size one.
            context = torch.tensor(context_data, dtype=torch.float32).unsqueeze(0)
            with torch.no_grad():
                forecast = pipeline.predict(
                    context=context,
                    prediction_length=self.prediction_length,
                    num_samples=10,
                    temperature=1.0,
                    top_k=50,
                    top_p=1.0
                )
            # First (only) batch item; presumably shaped
            # (num_samples, prediction_length) - statistics reduce axis 0.
            forecast_array = forecast[0].numpy()
            return {
                'mean': np.median(forecast_array, axis=0),
                'q10': np.quantile(forecast_array, 0.1, axis=0),
                'q90': np.quantile(forecast_array, 0.9, axis=0),
                'std': np.std(forecast_array, axis=0)
            }
        except Exception:
            self._logger.exception("Chronos prediction failed")
            return None

    def predict_moirai_fast(self, model: Any, data: np.ndarray) -> Optional[Dict]:
        """Forecast with a Moirai model.

        Args:
            model: Object returned by :meth:`load_moirai_small`.
            data: 1-D series of values; only the last ``context_length``
                entries are used.

        Returns:
            Dict with ``mean``, ``q10``, ``q90`` and ``std`` entries, or
            ``None`` on failure.
        """
        try:
            from gluonts.dataset.common import ListDataset
            # The start date is a fixed placeholder, not the real date of the
            # first observation.
            dataset = ListDataset([{
                "item_id": "stock",
                "start": "2023-01-01",
                "target": data[-self.context_length:].tolist()
            }], freq='D')
            # NOTE(review): verify that create_predictor accepts
            # num_parallel_samples in the installed uni2ts version.
            predictor = model.create_predictor(
                batch_size=1,
                num_parallel_samples=10
            )
            forecast = list(predictor.predict(dataset))[0]
            return {
                'mean': forecast.mean,
                'q10': forecast.quantile(0.1),
                'q90': forecast.quantile(0.9),
                'std': np.std(forecast.samples, axis=0)
            }
        except Exception:
            self._logger.exception("Moirai prediction failed")
            return None
# Single module-level analyzer so models loaded for one request stay cached
# for the following ones.
analyzer = FastAIStockAnalyzer()
def _error_result(message):
    """Return the five UI outputs shown when the analysis cannot complete."""
    return (message, None, None, "N/A", "N/A")


def _recommend(week_change):
    """Map the predicted percentage move to a (decision, explanation) pair."""
    if week_change > 5:
        return "π’ STRONG BUY", "AI expects significant gains!"
    if week_change > 2:
        return "π’ BUY", "AI expects moderate gains"
    if week_change < -5:
        return "π΄ STRONG SELL", "AI expects significant losses"
    if week_change < -2:
        return "π΄ SELL", "AI expects losses"
    return "βͺ HOLD", "AI expects stable prices"


def _build_forecast_figure(stock_symbol, closes, mean_pred, predictions):
    """Plot the last 30 closes, the forecast line and its q10-q90 band."""
    fig = go.Figure()

    # Historical data (last 30 days)
    recent = closes.tail(30)
    fig.add_trace(go.Scatter(
        x=recent.index,
        y=recent,
        mode='lines',
        name='Historical Price',
        line=dict(color='blue', width=2)
    ))

    # One calendar-day step per forecast value, starting the day after the
    # last observation; the horizon follows the forecast length.
    future_dates = pd.date_range(
        start=closes.index[-1] + pd.Timedelta(days=1),
        periods=len(mean_pred),
        freq='D'
    )
    fig.add_trace(go.Scatter(
        x=future_dates,
        y=mean_pred,
        mode='lines+markers',
        name='AI Prediction',
        line=dict(color='red', width=3),
        marker=dict(size=8)
    ))

    # Confidence band: upper bound left-to-right, then lower bound
    # right-to-left, so the outline closes into a fillable polygon.
    q10 = predictions.get('q10')
    q90 = predictions.get('q90')
    if q10 is not None and q90 is not None:
        lower = np.asarray(q10)
        upper = np.asarray(q90)
        fig.add_trace(go.Scatter(
            x=future_dates.tolist() + future_dates[::-1].tolist(),
            y=upper.tolist() + lower[::-1].tolist(),
            fill='toself',
            fillcolor='rgba(255,0,0,0.1)',
            line=dict(color='rgba(255,255,255,0)'),
            name='Confidence Range',
            showlegend=True
        ))

    fig.update_layout(
        title=f"{stock_symbol} - AI Stock Forecast",
        xaxis_title="Date",
        yaxis_title="Price ($)",
        height=500,
        showlegend=True,
        template="plotly_white"
    )
    return fig


def analyze_stock(stock_symbol, model_choice, investment_amount, progress=gr.Progress()):
    """Main analysis function for Gradio.

    Fetches price history for ``stock_symbol``, forecasts with the model
    named in ``model_choice`` and summarises the outcome for an investment
    of ``investment_amount`` dollars.

    Args:
        stock_symbol: Ticker symbol; surrounding whitespace is ignored and
            the symbol is upper-cased.
        model_choice: UI label; Chronos is used when it contains "Chronos",
            Moirai otherwise.
        investment_amount: Amount used for the profit/loss scenario.
        progress: Gradio progress tracker.

    Returns:
        A 5-tuple ``(analysis_markdown, figure, decision, current_metrics,
        prediction_metrics)``.  On failure the first item is an error
        message, the figure and decision are ``None`` and both metrics are
        ``"N/A"``.
    """
    insufficient = "β Error: Insufficient data for analysis. Please check the stock symbol."

    progress(0.1, desc="Fetching stock data...")
    stock_symbol = str(stock_symbol or "").strip().upper()
    if not stock_symbol:
        return _error_result(insufficient)

    # Fetch data
    stock_data, stock_info = analyzer.fetch_stock_data(stock_symbol)
    if stock_data is None:
        return _error_result(insufficient)
    # Drop rows without a close so NaN cannot leak into the forecast.
    closes = stock_data['Close'].dropna()
    if len(closes) < 50:
        return _error_result(insufficient)

    current_price = float(closes.iloc[-1])
    if current_price <= 0:
        # Percentage changes below would divide by zero.
        return _error_result(insufficient)
    company_name = (stock_info or {}).get('longName') or stock_symbol

    progress(0.3, desc="Loading AI model...")
    # Load model
    use_chronos = "Chronos" in str(model_choice or "")
    if use_chronos:
        model, _ = analyzer.load_chronos_tiny()
        model_name = "Amazon Chronos Tiny"
    else:
        model, _ = analyzer.load_moirai_small()
        model_name = "Salesforce Moirai Small"
    if model is None:
        return _error_result("β Error: Failed to load AI model. Please try again.")

    progress(0.6, desc="Generating AI predictions...")
    # Generate predictions
    if use_chronos:
        predictions = analyzer.predict_chronos_fast(model, closes.values)
    else:
        predictions = analyzer.predict_moirai_fast(model, closes.values)
    prediction_failed = "β Error: Prediction failed. Please try again."
    if predictions is None:
        return _error_result(prediction_failed)

    mean_pred = np.asarray(predictions['mean'], dtype=float)
    if mean_pred.size == 0 or not np.isfinite(mean_pred[-1]):
        return _error_result(prediction_failed)

    progress(0.8, desc="Calculating investment scenarios...")
    # Analysis results
    final_pred = float(mean_pred[-1])
    week_change = ((final_pred - current_price) / current_price) * 100
    decision, explanation = _recommend(week_change)

    # Heuristic score: 70 plus 1.5 per percentage point of predicted move,
    # clamped to the 50-100 range.
    confidence = min(100, max(50, 70 + abs(week_change) * 1.5))
    shares = investment_amount / current_price
    profit = (final_pred - current_price) * shares
    predicted_value = investment_amount + profit

    # Create analysis text (blank lines separate Markdown paragraphs)
    analysis_text = "\n".join([
        "",
        f"# π― {company_name} ({stock_symbol}) Analysis",
        "",
        f"## π€ AI RECOMMENDATION: {decision}",
        "",
        f"**{explanation}**",
        "",
        f"*Powered by {model_name}*",
        "",
        "## π Key Metrics",
        "",
        f"- **Current Price**: ${current_price:.2f}",
        f"- **7-Day Prediction**: ${final_pred:.2f} ({week_change:+.2f}%)",
        f"- **AI Confidence**: {confidence:.0f}%",
        f"- **Model Used**: {model_name}",
        "",
        f"## π° Investment Scenario (${investment_amount:,.0f})",
        "",
        f"- **Shares**: {shares:.2f}",
        f"- **Predicted Value**: ${predicted_value:,.2f}",
        f"- **Profit/Loss**: ${profit:+,.2f} ({week_change:+.2f}%)",
        "",
        "β οΈ **DISCLAIMER**: This is AI-generated analysis for educational purposes only. Not financial advice.",
        "",
    ])

    progress(0.9, desc="Creating charts...")
    fig = _build_forecast_figure(stock_symbol, closes, mean_pred, predictions)

    progress(1.0, desc="Analysis complete!")
    # Create summary metrics
    previous_close = float(closes.iloc[-2])
    day_change_pct = ((current_price - previous_close) / previous_close) * 100
    current_metrics = f"${current_price:.2f} ({day_change_pct:+.2f}%)"
    prediction_metrics = f"${final_pred:.2f} ({week_change:+.2f}%)"

    return (
        analysis_text,
        fig,
        decision,
        current_metrics,
        prediction_metrics
    )
# Create Gradio interface
# NOTE(review): indentation was missing from this copy of the file; the
# nesting of rows and columns below is a reconstruction - confirm the layout.
# NOTE(review): the emoji inside the labels look mis-encoded in this copy;
# restore them from the original file if it is available.
with gr.Blocks(
    theme=gr.themes.Soft(),
    title="β‘ Fast AI Stock Predictor",
    css="footer {visibility: hidden}"
) as demo:
    # Page header
    gr.HTML("""
    <div style="text-align: center; padding: 20px;">
    <h1>β‘ Fast AI Stock Predictor</h1>
    <p><strong>π€ Powered by Amazon Chronos & Salesforce Moirai</strong></p>
    <p style="color: #666; font-size: 14px;">β οΈ Educational use only - Not financial advice</p>
    </div>
    """)

    with gr.Row():
        # Left column: inputs
        with gr.Column(scale=1):
            gr.HTML("<h3>π― Configuration</h3>")
            stock_input = gr.Dropdown(
                choices=["AAPL", "GOOGL", "MSFT", "TSLA", "AMZN", "META", "NFLX", "NVDA"],
                value="AAPL",
                label="Select Stock",
                allow_custom_value=True,
                info="Choose from popular stocks or enter custom symbol"
            )
            # analyze_stock picks the model by looking for "Chronos" in the
            # selected label, so keep that word in the Chronos choice.
            model_input = gr.Radio(
                choices=["π Chronos (Fast)", "π― Moirai (Accurate)"],
                value="π Chronos (Fast)",
                label="AI Model",
                info="Chronos: Faster | Moirai: More accurate"
            )
            investment_input = gr.Slider(
                minimum=500,
                maximum=50000,
                value=5000,
                step=500,
                label="Investment Amount ($)",
                info="Amount to analyze for profit/loss scenarios"
            )
            analyze_btn = gr.Button(
                "π Analyze Stock",
                variant="primary",
                size="lg"
            )

        # Right column: headline results
        with gr.Column(scale=2):
            gr.HTML("<h3>π Results</h3>")
            with gr.Row():
                current_price_display = gr.Textbox(
                    label="Current Price",
                    interactive=False,
                    container=True
                )
                prediction_display = gr.Textbox(
                    label="7-Day Prediction",
                    interactive=False,
                    container=True
                )
                decision_display = gr.Textbox(
                    label="AI Decision",
                    interactive=False,
                    container=True
                )

    with gr.Row():
        analysis_output = gr.Markdown(
            label="Analysis Report",
            value="Click 'Analyze Stock' to generate AI-powered analysis..."
        )

    with gr.Row():
        chart_output = gr.Plot(
            label="Price Chart & Predictions",
            container=True
        )

    # Event handlers
    # The outputs list must stay in the order of analyze_stock's return tuple.
    analyze_btn.click(
        fn=analyze_stock,
        inputs=[stock_input, model_input, investment_input],
        outputs=[
            analysis_output,
            chart_output,
            decision_display,
            current_price_display,
            prediction_display
        ]
    )

    # Examples
    gr.Examples(
        examples=[
            ["AAPL", "π Chronos (Fast)", 5000],
            ["TSLA", "π― Moirai (Accurate)", 10000],
            ["GOOGL", "π Chronos (Fast)", 2500],
        ],
        inputs=[stock_input, model_input, investment_input],
        label="Try these examples:"
    )
# Start the web server only when the file is run as a script.
if __name__ == "__main__":
    # share=True asks Gradio for a public share link.
    demo.launch(share=True)