# Source: Hugging Face Space file view of app.py (raw / history / blame).
# Uploader: ArnabDeo - commit 2b43840 "Update app.py" (verified) - 14.8 kB
import gc
import logging
import os
import time
import warnings
from datetime import datetime, timedelta
from typing import Optional, Dict, Any, Tuple

import gradio as gr
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import torch
import yfinance as yf
# Suppress every Python warning for the lifetime of the process.
warnings.filterwarnings('ignore')

# Environment optimizations for Hugging Face Spaces: these variables are
# applied in one go, before any model library is imported lazily later on.
os.environ.update({
    'TOKENIZERS_PARALLELISM': 'false',
    'HF_HUB_DISABLE_PROGRESS_BARS': '1',
    'HF_HUB_DISABLE_TELEMETRY': '1',
    'TRANSFORMERS_VERBOSITY': 'error',
})

# Release any cached CUDA memory when a GPU is visible.
if torch.cuda.is_available():
    torch.cuda.empty_cache()

# Cap torch at four CPU threads (or one when the core count is unknown).
_available_cpus = os.cpu_count() or 1
torch.set_num_threads(min(4, _available_cpus))
class FastAIStockAnalyzer:
    """Optimized AI Stock Analyzer for Gradio.

    Downloads daily price history with yfinance and runs one of two
    pretrained forecasters (Chronos tiny or Moirai small) on CPU.

    Every method is best-effort: a failure is logged and reported to the
    caller as ``None`` instead of being raised, so the UI layer can turn it
    into a user-facing message.
    """

    # Failures are swallowed on purpose; the logger keeps them diagnosable.
    _logger = logging.getLogger(__name__)

    def __init__(self):
        self.context_length = 32     # trailing observations handed to a model
        self.prediction_length = 7   # number of steps to forecast
        self.device = "cpu"
        self.model_cache = {}        # cache key -> already loaded model

    def fetch_stock_data(self, symbol: str, period: str = "6mo") -> Tuple[Optional[pd.DataFrame], Optional[Dict]]:
        """Fetch stock data with error handling.

        Args:
            symbol: Ticker symbol passed to ``yf.Ticker``.
            period: History window passed to ``Ticker.history``.

        Returns:
            ``(history, info)`` where ``info`` has the keys ``longName``,
            ``sector`` and ``marketCap``; ``(None, None)`` when the download
            fails or returns no rows.
        """
        try:
            ticker = yf.Ticker(symbol)
            data = ticker.history(period=period, interval="1d",
                                  actions=False, auto_adjust=True,
                                  back_adjust=False, repair=False)
            if data.empty:
                return None, None
        except Exception:
            self._logger.exception("Price history download failed for %r", symbol)
            return None, None

        # Company metadata is optional: fall back to these defaults whenever
        # the lookup fails or a field is missing.
        info = {'longName': symbol, 'sector': 'Unknown', 'marketCap': 0}
        try:
            details = ticker.info  # read the property once instead of per field
            info = {key: details.get(key, fallback) for key, fallback in info.items()}
        except Exception:
            self._logger.warning("Company metadata unavailable for %r; using defaults",
                                 symbol, exc_info=True)
        return data, info

    def load_chronos_tiny(self) -> Tuple[Optional[Any], Optional[str]]:
        """Load Chronos model with caching.

        Returns:
            ``(pipeline, "chronos")``, or ``(None, None)`` when the package
            or the weights cannot be loaded.
        """
        model_key = "chronos_tiny"
        if model_key in self.model_cache:
            return self.model_cache[model_key], "chronos"
        try:
            # Imported lazily so the app can start without the package.
            from chronos import ChronosPipeline
            # NOTE(review): trust_remote_code=True permits executing code
            # shipped with the checkpoint - confirm it is actually required.
            pipeline = ChronosPipeline.from_pretrained(
                "amazon/chronos-t5-tiny",
                device_map="cpu",
                torch_dtype=torch.float32,
                low_cpu_mem_usage=True,
                trust_remote_code=True
            )
        except Exception:
            self._logger.exception("Loading the Chronos model failed")
            return None, None
        self.model_cache[model_key] = pipeline
        return pipeline, "chronos"

    def load_moirai_small(self) -> Tuple[Optional[Any], Optional[str]]:
        """Load Moirai model with caching.

        Returns:
            ``(model, "moirai")``, or ``(None, None)`` when the package or
            the weights cannot be loaded.
        """
        model_key = "moirai_small"
        if model_key in self.model_cache:
            return self.model_cache[model_key], "moirai"
        try:
            # Imported lazily so the app can start without the package.
            from uni2ts.model.moirai import MoiraiForecast, MoiraiModule
            # NOTE(review): confirm MoiraiModule.from_pretrained accepts
            # these transformers-style keyword arguments.
            module = MoiraiModule.from_pretrained(
                "Salesforce/moirai-1.0-R-small",
                low_cpu_mem_usage=True,
                device_map="cpu",
                torch_dtype=torch.float32,
                trust_remote_code=True
            )
            model = MoiraiForecast(
                module=module,
                prediction_length=self.prediction_length,
                context_length=self.context_length,
                patch_size="auto",
                num_samples=15,
                target_dim=1,
                feat_dynamic_real_dim=0,
                past_feat_dynamic_real_dim=0
            )
        except Exception:
            self._logger.exception("Loading the Moirai model failed")
            return None, None
        self.model_cache[model_key] = model
        return model, "moirai"

    def predict_chronos_fast(self, pipeline: Any, data: np.ndarray) -> Optional[Dict]:
        """Fast Chronos prediction.

        Args:
            pipeline: Object exposing ``predict`` as returned by
                ``load_chronos_tiny``.
            data: 1-D series; only the last ``context_length`` values are used.

        Returns:
            Dict of per-step arrays ``mean``, ``q10``, ``q90`` and ``std``
            computed over the sampled paths, or ``None`` on failure or when
            ``data`` is empty. ``mean`` holds the median of the samples.
        """
        try:
            context_data = np.asarray(data)[-self.context_length:]
            if context_data.size == 0:
                self._logger.warning("Chronos prediction skipped: empty price history")
                return None
            # Add the batch dimension the pipeline is called with: (1, context).
            context = torch.tensor(context_data, dtype=torch.float32).unsqueeze(0)
            with torch.no_grad():
                forecast = pipeline.predict(
                    context=context,
                    prediction_length=self.prediction_length,
                    num_samples=10,
                    temperature=1.0,
                    top_k=50,
                    top_p=1.0
                )
            # First (only) batch item; statistics are taken across axis 0.
            forecast_array = forecast[0].numpy()
            return {
                'mean': np.median(forecast_array, axis=0),
                'q10': np.quantile(forecast_array, 0.1, axis=0),
                'q90': np.quantile(forecast_array, 0.9, axis=0),
                'std': np.std(forecast_array, axis=0)
            }
        except Exception:
            self._logger.exception("Chronos prediction failed")
            return None

    def predict_moirai_fast(self, model: Any, data: np.ndarray) -> Optional[Dict]:
        """Fast Moirai prediction.

        Args:
            model: Object exposing ``create_predictor`` as returned by
                ``load_moirai_small``.
            data: 1-D series; only the last ``context_length`` values are used.

        Returns:
            Dict with ``mean``, ``q10``, ``q90`` and ``std`` taken from the
            first forecast, or ``None`` on failure or when ``data`` is empty.
        """
        try:
            from gluonts.dataset.common import ListDataset
            history = np.asarray(data)[-self.context_length:]
            if history.size == 0:
                self._logger.warning("Moirai prediction skipped: empty price history")
                return None
            # The start date is a fixed placeholder; real timestamps of the
            # series are not forwarded to the model.
            dataset = ListDataset([{
                "item_id": "stock",
                "start": "2023-01-01",
                "target": history.tolist()
            }], freq='D')
            try:
                predictor = model.create_predictor(
                    batch_size=1,
                    num_parallel_samples=10
                )
            except TypeError:
                # NOTE(review): uni2ts' create_predictor appears to accept only
                # batch_size/device; if so the sample count comes from the
                # num_samples given to MoiraiForecast. Confirm against the
                # installed version.
                predictor = model.create_predictor(batch_size=1)
            forecasts = list(predictor.predict(dataset))
            forecast = forecasts[0]
            return {
                'mean': forecast.mean,
                'q10': forecast.quantile(0.1),
                'q90': forecast.quantile(0.9),
                'std': np.std(forecast.samples, axis=0)
            }
        except Exception:
            self._logger.exception("Moirai prediction failed")
            return None
# Initialize analyzer globally for caching: a single module-level instance
# means models stored in its model_cache are reused by later requests.
analyzer = FastAIStockAnalyzer()
def _analysis_failure(message):
    """Return the five Gradio outputs for a run that could not finish."""
    return (message, None, None, "N/A", "N/A")


def _recommend(week_change):
    """Map the predicted percentage change to a (decision, explanation) pair."""
    if week_change > 5:
        return "🟢 STRONG BUY", "AI expects significant gains!"
    if week_change > 2:
        return "🟢 BUY", "AI expects moderate gains"
    if week_change < -5:
        return "🔴 STRONG SELL", "AI expects significant losses"
    if week_change < -2:
        return "🔴 SELL", "AI expects losses"
    return "⚪ HOLD", "AI expects stable prices"


def analyze_stock(stock_symbol, model_choice, investment_amount, progress=gr.Progress()):
    """Main analysis function for Gradio.

    Args:
        stock_symbol: Ticker chosen or typed by the user; surrounding
            whitespace is stripped and the symbol is upper-cased.
        model_choice: Radio label; any label containing "Chronos" selects
            Chronos, everything else selects Moirai.
        investment_amount: Dollar amount used for the profit/loss scenario.
        progress: Gradio progress callback.

    Returns:
        ``(analysis_markdown, figure, decision, current_metrics,
        prediction_metrics)``. On failure the first item is an error message,
        the next two are ``None`` and the last two are ``"N/A"``.
    """
    no_data_message = "❌ Error: Insufficient data for analysis. Please check the stock symbol."

    progress(0.1, desc="Fetching stock data...")
    # The dropdown accepts free text, so normalize before hitting the network.
    symbol = str(stock_symbol or "").strip().upper()
    if not symbol:
        return _analysis_failure(no_data_message)

    stock_data, stock_info = analyzer.fetch_stock_data(symbol)
    if stock_data is None:
        return _analysis_failure(no_data_message)
    # Rows without a close would turn every derived number into NaN.
    closes = stock_data['Close'].dropna()
    if len(closes) < 50:
        return _analysis_failure(no_data_message)
    current_price = closes.iloc[-1]
    if not current_price > 0:
        # Percentages and share counts below divide by the current price.
        return _analysis_failure(no_data_message)
    company_name = stock_info.get('longName', symbol) if stock_info else symbol

    progress(0.3, desc="Loading AI model...")
    if "Chronos" in (model_choice or ""):
        model_name = "Amazon Chronos Tiny"
        load_model = analyzer.load_chronos_tiny
        run_forecast = analyzer.predict_chronos_fast
    else:
        model_name = "Salesforce Moirai Small"
        load_model = analyzer.load_moirai_small
        run_forecast = analyzer.predict_moirai_fast
    model, _ = load_model()
    if model is None:
        return _analysis_failure("❌ Error: Failed to load AI model. Please try again.")

    progress(0.6, desc="Generating AI predictions...")
    predictions = run_forecast(model, closes.values)
    if predictions is None:
        return _analysis_failure("❌ Error: Prediction failed. Please try again.")
    mean_pred = np.asarray(predictions['mean'])
    if mean_pred.size == 0:
        return _analysis_failure("❌ Error: Prediction failed. Please try again.")

    progress(0.8, desc="Calculating investment scenarios...")
    final_pred = mean_pred[-1]
    week_change = ((final_pred - current_price) / current_price) * 100
    decision, explanation = _recommend(week_change)

    # Heuristic display value: 70% plus 1.5 points per percent of predicted
    # move, clamped to the 50-100 range.
    confidence = min(100, max(50, 70 + abs(week_change) * 1.5))
    shares = investment_amount / current_price
    profit = (final_pred - current_price) * shares
    predicted_value = investment_amount + profit

    analysis_text = f"""
# 🎯 {company_name} ({symbol}) Analysis
## 🤖 AI RECOMMENDATION: {decision}
**{explanation}**
*Powered by {model_name}*
## 📊 Key Metrics
- **Current Price**: ${current_price:.2f}
- **7-Day Prediction**: ${final_pred:.2f} ({week_change:+.2f}%)
- **AI Confidence**: {confidence:.0f}%
- **Model Used**: {model_name}
## 💰 Investment Scenario (${investment_amount:,.0f})
- **Shares**: {shares:.2f}
- **Predicted Value**: ${predicted_value:,.2f}
- **Profit/Loss**: ${profit:+,.2f} ({week_change:+.2f}%)
⚠️ **DISCLAIMER**: This is AI-generated analysis for educational purposes only. Not financial advice.
"""

    progress(0.9, desc="Creating charts...")
    fig = go.Figure()
    # Historical data (last 30 rows)
    recent = closes.tail(30)
    fig.add_trace(go.Scatter(
        x=recent.index,
        y=recent,
        mode='lines',
        name='Historical Price',
        line=dict(color='blue', width=2)
    ))
    # One calendar-day tick per forecast step, starting the day after the
    # last observation; sized from the forecast so x and y always line up.
    future_dates = pd.date_range(
        start=closes.index[-1] + pd.Timedelta(days=1),
        periods=len(mean_pred),
        freq='D'
    )
    fig.add_trace(go.Scatter(
        x=future_dates,
        y=mean_pred,
        mode='lines+markers',
        name='AI Prediction',
        line=dict(color='red', width=3),
        marker=dict(size=8)
    ))
    # Confidence band: trace the upper bound forwards, then the lower bound
    # backwards, so 'toself' fills the area between them.
    if 'q10' in predictions and 'q90' in predictions:
        upper = np.asarray(predictions['q90'])
        lower = np.asarray(predictions['q10'])
        fig.add_trace(go.Scatter(
            x=future_dates.tolist() + future_dates[::-1].tolist(),
            y=upper.tolist() + lower[::-1].tolist(),
            fill='toself',
            fillcolor='rgba(255,0,0,0.1)',
            line=dict(color='rgba(255,255,255,0)'),
            name='Confidence Range',
            showlegend=True
        ))
    fig.update_layout(
        title=f"{symbol} - AI Stock Forecast",
        xaxis_title="Date",
        yaxis_title="Price ($)",
        height=500,
        showlegend=True,
        template="plotly_white"
    )

    progress(1.0, desc="Analysis complete!")
    # Summary metrics: change of the latest close against the previous one.
    previous_close = closes.iloc[-2]
    day_change_pct = ((current_price - previous_close) / previous_close) * 100
    current_metrics = f"${current_price:.2f} ({day_change_pct:+.2f}%)"
    prediction_metrics = f"${final_pred:.2f} ({week_change:+.2f}%)"
    return (
        analysis_text,
        fig,
        decision,
        current_metrics,
        prediction_metrics
    )
# Create Gradio interface: a configuration column on the left, headline
# results on the right, then the full report and the chart.
with gr.Blocks(
    theme=gr.themes.Soft(),
    title="⚡ Fast AI Stock Predictor",
    css="footer {visibility: hidden}"
) as demo:
    gr.HTML("""
    <div style="text-align: center; padding: 20px;">
    <h1>⚡ Fast AI Stock Predictor</h1>
    <p><strong>🤖 Powered by Amazon Chronos & Salesforce Moirai</strong></p>
    <p style="color: #666; font-size: 14px;">⚠️ Educational use only - Not financial advice</p>
    </div>
    """)
    with gr.Row():
        with gr.Column(scale=1):
            gr.HTML("<h3>🎯 Configuration</h3>")
            stock_input = gr.Dropdown(
                choices=["AAPL", "GOOGL", "MSFT", "TSLA", "AMZN", "META", "NFLX", "NVDA"],
                value="AAPL",
                label="Select Stock",
                allow_custom_value=True,
                info="Choose from popular stocks or enter custom symbol"
            )
            # analyze_stock selects the model by looking for "Chronos" in
            # the chosen label, so keep that word in the first choice.
            model_input = gr.Radio(
                choices=["🚀 Chronos (Fast)", "🎯 Moirai (Accurate)"],
                value="🚀 Chronos (Fast)",
                label="AI Model",
                info="Chronos: Faster | Moirai: More accurate"
            )
            investment_input = gr.Slider(
                minimum=500,
                maximum=50000,
                value=5000,
                step=500,
                label="Investment Amount ($)",
                info="Amount to analyze for profit/loss scenarios"
            )
            analyze_btn = gr.Button(
                "🚀 Analyze Stock",
                variant="primary",
                size="lg"
            )
        with gr.Column(scale=2):
            gr.HTML("<h3>📊 Results</h3>")
            with gr.Row():
                current_price_display = gr.Textbox(
                    label="Current Price",
                    interactive=False,
                    container=True
                )
                prediction_display = gr.Textbox(
                    label="7-Day Prediction",
                    interactive=False,
                    container=True
                )
                decision_display = gr.Textbox(
                    label="AI Decision",
                    interactive=False,
                    container=True
                )
    with gr.Row():
        analysis_output = gr.Markdown(
            label="Analysis Report",
            value="Click 'Analyze Stock' to generate AI-powered analysis..."
        )
    with gr.Row():
        chart_output = gr.Plot(
            label="Price Chart & Predictions",
            container=True
        )
    # Event handlers: the outputs list must stay in the same order as the
    # tuple returned by analyze_stock.
    analyze_btn.click(
        fn=analyze_stock,
        inputs=[stock_input, model_input, investment_input],
        outputs=[
            analysis_output,
            chart_output,
            decision_display,
            current_price_display,
            prediction_display
        ]
    )
    # Examples: the model labels here must match the Radio choices exactly.
    gr.Examples(
        examples=[
            ["AAPL", "🚀 Chronos (Fast)", 5000],
            ["TSLA", "🎯 Moirai (Accurate)", 10000],
            ["GOOGL", "🚀 Chronos (Fast)", 2500],
        ],
        inputs=[stock_input, model_input, investment_input],
        label="Try these examples:"
    )
# Start the Gradio server only when this file is run as a script.
if __name__ == "__main__":
    # NOTE(review): share=True requests a public share link from Gradio;
    # confirm that is intended for every environment this is launched in.
    demo.launch(share=True)