# app.py ── SHASHA AI “Hybrid” (FastAPI + Gradio + Static UI)
from __future__ import annotations

from typing import List, Tuple
import asyncio

import gradio as gr
from fastapi import FastAPI, WebSocket
from fastapi.staticfiles import StaticFiles

# ────────────────────────────────────────────────────────────────
# internal helpers (unchanged)
# ────────────────────────────────────────────────────────────────
from constants import (
    HTML_SYSTEM_PROMPT,
    TRANSFORMERS_JS_SYSTEM_PROMPT,
    AVAILABLE_MODELS,
)
from inference import generation_code

History = List[Tuple[str, str]]

# ────────────────────────────────────────────────────────────────
# 1.  Blocks-only “headless” API (no UI, just JSON over /api)
# ────────────────────────────────────────────────────────────────
with gr.Blocks(css="body{display:none}") as api_demo:  # invisible
    prompt_in   = gr.Textbox()
    file_in     = gr.File()
    url_in      = gr.Textbox()
    model_state = gr.State(AVAILABLE_MODELS[0])
    search_chk  = gr.Checkbox()
    lang_dd     = gr.Dropdown(choices=["html", "python"], value="html")
    hist_state  = gr.State([])

    code_out    = gr.Textbox()  # plain JSON
    hist_out    = gr.State()
    preview_out = gr.Textbox()
    chat_out    = gr.State()

    api_demo.load(
        generation_code,
        inputs=[prompt_in, file_in, url_in, model_state,
                search_chk, lang_dd, hist_state],
        outputs=[code_out, hist_out, preview_out, chat_out],
    )

# ────────────────────────────────────────────────────────────────
# 2.  Hybrid FastAPI server mounts:
#     • /api/*  → Gradio JSON (& websocket queue)
#     • /       → static/  (index.html, style.css, index.js …)
#       The static mount is registered LAST (see bottom of file):
#       a "/" mount is a catch-all, so registering it first would
#       shadow every /api route.
# ────────────────────────────────────────────────────────────────
app = FastAPI(title="SHASHA AI hybrid server")

# gradio API — mount_gradio_app() attaches api_demo onto `app` at the
# given path and returns the same app; there is no separate sub-app
# to pass to app.mount().
app = gr.mount_gradio_app(app, api_demo, path="/api")
# → Gradio's JSON prediction & queue endpoints now live under /api/…

# ────────────────────────────────────────────────────────────────
# 3.  Bonus: websocket streamer for lightning-fast preview
# ────────────────────────────────────────────────────────────────
@app.websocket("/api/stream")
async def stream(websocket: WebSocket) -> None:
    """
    Front-end connects, sends the same JSON payload as the predict
    endpoint, and receives chunks (string tokens) as they arrive.
    """
    await websocket.accept()
    payload = await websocket.receive_json()

    queue: asyncio.Queue[str] = asyncio.Queue()

    # spawn background generation
    async def _run() -> None:
        async for token in generation_code.stream(**payload):  # type: ignore
            await queue.put(token)
        await queue.put("__END__")

    # keep a reference so the task is not garbage-collected mid-run
    task = asyncio.create_task(_run())

    while True:
        item = await queue.get()
        if item == "__END__":
            break
        await websocket.send_text(item)

    await task  # surface any exception raised during generation
    await websocket.close()

# static assets — registered last so the catch-all "/" mount does not
# shadow the /api routes above
app.mount("/", StaticFiles(directory="static", html=True), name="static")
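
# ────────────────────────────────────────────────────────────────
# 4.  Example websocket client — a sketch only: the payload keys
#     must mirror whatever keyword arguments the (assumed)
#     `generation_code.stream()` accepts, and the `websockets`
#     package is an illustrative choice, not a project dependency.
# ────────────────────────────────────────────────────────────────
async def example_stream_client() -> None:
    import json

    import websockets  # assumed third-party dependency

    async with websockets.connect("ws://localhost:7860/api/stream") as ws:
        # hypothetical payload — field names depend on inference.py
        await ws.send(json.dumps({"query": "build me a landing page"}))
        # the server closes the socket after the final token, which
        # ends this async iteration cleanly
        async for token in ws:
            print(token, end="", flush=True)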
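
# ────────────────────────────────────────────────────────────────
# 5.  Local entry point — a minimal sketch, assuming `uvicorn` is
#     installed; host and port are illustrative defaults, not
#     settings taken from the original deployment.
# ────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    import uvicorn

    # serve the hybrid app on all interfaces; 7860 is Gradio's
    # conventional default port
    uvicorn.run("app:app", host="0.0.0.0", port=7860)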