mgbam commited on
Commit
f7cf3be
Β·
verified Β·
1 Parent(s): da49c48

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +77 -185
app.py CHANGED
@@ -1,201 +1,93 @@
1
- # app.py
2
 
3
- """
4
- Main application file for SHASHA AI, a Gradio-based AI code generation tool.
5
-
6
- Provides a UI for generating code in many languages using various AI models.
7
- Supports text prompts, file uploads, website scraping, optional web search,
8
- and live previews of HTML output.
9
- """
10
 
 
11
  import gradio as gr
12
- from typing import Optional, Dict, List, Tuple, Any
 
13
 
14
- # --- Local module imports ---
 
 
15
  from constants import (
16
  HTML_SYSTEM_PROMPT,
17
  TRANSFORMERS_JS_SYSTEM_PROMPT,
18
  AVAILABLE_MODELS,
19
- DEMO_LIST,
20
- )
21
- from hf_client import get_inference_client
22
- from tavily_search import enhance_query_with_search
23
- from utils import (
24
- extract_text_from_file,
25
- extract_website_content,
26
- apply_search_replace_changes,
27
- history_to_messages,
28
- history_to_chatbot_messages,
29
- remove_code_block,
30
- parse_transformers_js_output,
31
- format_transformers_js_output,
32
  )
33
- from deploy import send_to_sandbox
34
 
35
# --- Type aliases ---
# Chat history as (user_message, assistant_reply) pairs, held in gr.State.
History = List[Tuple[str, str]]
# One entry of AVAILABLE_MODELS; code below reads its "name" and "id" keys.
Model = Dict[str, Any]

# --- Supported languages for dropdown ---
# Values feed the "Target Language" gr.Dropdown; the sql-* variants
# presumably mirror editor mode names — confirm against gr.Code's
# accepted `language` values before extending this list.
SUPPORTED_LANGUAGES = [
    "python", "c", "cpp", "markdown", "latex", "json", "html", "css",
    "javascript", "jinja2", "typescript", "yaml", "dockerfile", "shell",
    "r", "sql", "sql-msSQL", "sql-mySQL", "sql-mariaDB", "sql-sqlite",
    "sql-cassandra", "sql-plSQL", "sql-hive", "sql-pgSQL", "sql-gql",
    "sql-gpSQL", "sql-sparkSQL", "sql-esper"
]
47
-
48
def get_model_details(name: str) -> Optional[Model]:
    """Return the entry of AVAILABLE_MODELS whose "name" equals *name*, or None."""
    return next((entry for entry in AVAILABLE_MODELS if entry["name"] == name), None)
53
-
54
def generation_code(
    query: Optional[str],
    file: Optional[str],
    website_url: Optional[str],
    current_model: Model,
    enable_search: bool,
    language: str,
    history: Optional[History],
) -> Tuple[str, History, str, List[Dict[str, str]]]:
    """
    Run one generation round and return updated UI payloads.

    Args:
        query: Free-text prompt from the user (may be None/empty).
        file: Filepath of an uploaded file, or None.
        website_url: URL to scrape for extra context, or None.
        current_model: Selected AVAILABLE_MODELS entry; its "id" is used.
        enable_search: Whether to augment the query via web search.
        language: Target output language; selects the system prompt.
        history: Prior (user, assistant) turns, or None for a fresh session.

    Returns:
        (code, new_history, preview_html, chatbot_messages). On error the
        code/preview slots are empty strings and the error is appended to
        history as the assistant turn.
    """
    query = query or ""
    history = history or []
    try:
        # Choose system prompt based on language; only html and
        # transformers.js have dedicated prompts in constants.
        if language == "html":
            system_prompt = HTML_SYSTEM_PROMPT
        elif language == "transformers.js":
            system_prompt = TRANSFORMERS_JS_SYSTEM_PROMPT
        else:
            # Generic fallback prompt
            system_prompt = (
                f"You are an expert {language} developer. "
                f"Write clean, idiomatic {language} code based on the user's request."
            )

        model_id = current_model["id"]
        # Determine provider from the model-id prefix; "auto" defers the
        # choice to get_inference_client.
        if model_id.startswith("openai/") or model_id in {"gpt-4", "gpt-3.5-turbo"}:
            provider = "openai"
        elif model_id.startswith("gemini/") or model_id.startswith("google/"):
            provider = "gemini"
        elif model_id.startswith("fireworks-ai/"):
            provider = "fireworks-ai"
        else:
            provider = "auto"

        # Build message history, then append the (possibly augmented) query.
        msgs = history_to_messages(history, system_prompt)
        context = query
        if file:
            # Truncate attachments to keep the request inside context limits.
            ftext = extract_text_from_file(file)
            context += f"\n\n[Attached file]\n{ftext[:5000]}"
        if website_url:
            wtext = extract_website_content(website_url)
            # extract_website_content signals failure via an "Error" prefix;
            # skip the content rather than feeding the error to the model.
            if not wtext.startswith("Error"):
                context += f"\n\n[Website content]\n{wtext[:8000]}"
        final_q = enhance_query_with_search(context, enable_search)
        msgs.append({"role": "user", "content": final_q})

        # Call the model (blocking, non-streaming).
        client = get_inference_client(model_id, provider)
        resp = client.chat.completions.create(
            model=model_id,
            messages=msgs,
            max_tokens=16000,
            temperature=0.1
        )
        content = resp.choices[0].message.content

    except Exception as e:
        # NOTE: mutates `history` in place here (vs. copying on success below)
        # so the error turn is visible to the caller's state object.
        err = f"❌ **Error:**\n```\n{e}\n```"
        history.append((query, err))
        return "", history, "", history_to_chatbot_messages(history)

    # Process model output
    if language == "transformers.js":
        # Multi-file output: reassemble files and preview index.html only.
        files = parse_transformers_js_output(content)
        code = format_transformers_js_output(files)
        preview = send_to_sandbox(files.get("index.html", ""))
    else:
        cleaned = remove_code_block(content)
        # If a previous successful turn exists, treat the model output as
        # search/replace edits against it; the "❌" prefix marks error turns.
        if history and history[-1][1] and not history[-1][1].startswith("❌"):
            code = apply_search_replace_changes(history[-1][1], cleaned)
        else:
            code = cleaned
        preview = send_to_sandbox(code) if language == "html" else ""

    # Success path returns a NEW history list (original left unmodified).
    new_hist = history + [(query, code)]
    chat = history_to_chatbot_messages(new_hist)
    return code, new_hist, preview, chat
133
-
134
# --- Custom CSS ---
# Injected via gr.Blocks(css=...); the #main_title, #subtitle and #gen_btn
# ids are assigned below through each component's elem_id argument.
CUSTOM_CSS = """
body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif; }
#main_title { text-align: center; font-size: 2.5rem; margin-top: 1.5rem; }
#subtitle { text-align: center; color: #4a5568; margin-bottom: 2.5rem; }
.gradio-container { background-color: #f7fafc; }
#gen_btn { box-shadow: 0 4px 6px rgba(0,0,0,0.1); }
"""
142
 
143
# UI layout and event wiring for the interactive app.
with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue"), css=CUSTOM_CSS, title="Shasha AI") as demo:
    # Per-session state: chat history and the currently selected model dict.
    history_state = gr.State([])
    initial_model = AVAILABLE_MODELS[0]
    model_state = gr.State(initial_model)

    gr.Markdown("# πŸš€ Shasha AI", elem_id="main_title")
    gr.Markdown("Your AI partner for generating, modifying, and understanding code.", elem_id="subtitle")

    with gr.Row():
        # Left column: model choice, context inputs, output config, actions.
        with gr.Column(scale=1):
            gr.Markdown("### 1. Select Model")
            model_dd = gr.Dropdown(
                choices=[m["name"] for m in AVAILABLE_MODELS],
                value=initial_model["name"],
                label="AI Model"
            )

            gr.Markdown("### 2. Provide Context")
            with gr.Tabs():
                with gr.Tab("πŸ“ Prompt"):
                    prompt_in = gr.Textbox(lines=7, placeholder="Describe your request...", show_label=False)
                with gr.Tab("πŸ“„ File"):
                    # type="filepath" so generation_code receives a path string.
                    file_in = gr.File(type="filepath")
                with gr.Tab("🌐 Website"):
                    url_in = gr.Textbox(placeholder="https://example.com")

            gr.Markdown("### 3. Configure Output")
            lang_dd = gr.Dropdown(SUPPORTED_LANGUAGES, value="html", label="Target Language")
            search_chk = gr.Checkbox(label="Enable Web Search")

            with gr.Row():
                clr_btn = gr.Button("Clear Session", variant="secondary")
                gen_btn = gr.Button("Generate Code", variant="primary", elem_id="gen_btn")

        # Right column: generated code, live HTML preview, chat history.
        with gr.Column(scale=2):
            with gr.Tabs():
                with gr.Tab("πŸ’» Code"):
                    code_out = gr.Code(language="html", interactive=True)
                with gr.Tab("πŸ‘οΈ Live Preview"):
                    preview_out = gr.HTML()
                with gr.Tab("πŸ“œ History"):
                    chat_out = gr.Chatbot(type="messages")

    # Keep model_state in sync with the dropdown; fall back to the initial
    # model if the selected name cannot be resolved.
    model_dd.change(lambda n: get_model_details(n) or initial_model, inputs=[model_dd], outputs=[model_state])

    # Main action: run generation and fan results out to code/preview/history.
    gen_btn.click(
        fn=generation_code,
        inputs=[prompt_in, file_in, url_in, model_state, search_chk, lang_dd, history_state],
        outputs=[code_out, history_state, preview_out, chat_out],
    )

    # Reset all inputs and outputs; queue=False so the clear is immediate
    # and does not wait behind queued generation jobs.
    clr_btn.click(
        lambda: ("", None, "", [], "", "", []),
        outputs=[prompt_in, file_in, url_in, history_state, code_out, preview_out, chat_out],
        queue=False,
    )


if __name__ == "__main__":
    demo.queue().launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py ── SHASHAΒ AI β€œHybrid” (FastAPIΒ +Β GradioΒ +Β Static UI)
2
 
3
from __future__ import annotations

import asyncio
from typing import Any, Dict, List, Optional, Tuple

import gradio as gr
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.staticfiles import StaticFiles
10
 
11
+ # ────────────────────────────────────────────────────────────────
12
+ # internal helpers (unchanged)
13
+ # ────────────────────────────────────────────────────────────────
14
  from constants import (
15
  HTML_SYSTEM_PROMPT,
16
  TRANSFORMERS_JS_SYSTEM_PROMPT,
17
  AVAILABLE_MODELS,
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  )
19
+ from inference import generation_code
20
 
 
21
  History = List[Tuple[str, str]]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
# ────────────────────────────────────────────────────────────────
# 1.  Blocks-only "headless" API (no UI, just /api/predict JSON)
# ────────────────────────────────────────────────────────────────
with gr.Blocks(css="body{display:none}") as api_demo:  # invisible
    # Inputs mirroring generation_code's parameter order.
    prompt_in = gr.Textbox()
    file_in = gr.File()
    url_in = gr.Textbox()
    model_state = gr.State(AVAILABLE_MODELS[0])
    search_chk = gr.Checkbox()
    lang_dd = gr.Dropdown(choices=["html", "python"], value="html")
    hist_state = gr.State([])

    # Outputs matching generation_code's 4-tuple return.
    code_out = gr.Textbox()  # plain JSON
    hist_out = gr.State()
    preview_out = gr.Textbox()
    chat_out = gr.State()

    # NOTE(review): Blocks.load fires when a browser opens the page, not per
    # API request — confirm the front-end actually reaches generation_code
    # through this event's JSON route; a hidden Button with an api_name may
    # be what was intended here.
    api_demo.load(
        generation_code,
        inputs=[prompt_in, file_in, url_in, model_state,
                search_chk, lang_dd, hist_state],
        outputs=[code_out, hist_out, preview_out, chat_out],
    )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
# ────────────────────────────────────────────────────────────────
# 2.  Hybrid FastAPI server mounts:
#     • /       → static/ (index.html, style.css, index.js …)
#     • /api/*  → Gradio JSON (& websocket queue)
# ────────────────────────────────────────────────────────────────
app = FastAPI(title="SHASHA AI hybrid server")

# Gradio API first. gr.mount_gradio_app attaches api_demo onto `app` at the
# given path and returns that same app — re-mounting its return value under
# app.mount("/api", ...) would nest the application inside itself.
app = gr.mount_gradio_app(app, api_demo, path="/api")

# Static assets LAST: Starlette matches mounts in registration order, and a
# StaticFiles mount at "/" is a catch-all that would shadow /api if it were
# registered first.
app.mount(
    "/", StaticFiles(directory="static", html=True), name="static"
)
65
 
66
# ────────────────────────────────────────────────────────────────
# 3.  Bonus: WebSocket streamer for lightning-fast preview
# ────────────────────────────────────────────────────────────────
@app.websocket("/api/stream")
async def stream(websocket: WebSocket) -> None:
    """
    Stream generation output to the front-end over a WebSocket.

    The client connects, sends the same JSON payload as /predict, and
    receives string chunks (tokens) as they arrive. The `websocket`
    parameter MUST be annotated as WebSocket for FastAPI to inject the
    connection object.
    """
    await websocket.accept()
    queue: asyncio.Queue[str] = asyncio.Queue()
    producer: asyncio.Task | None = None
    try:
        payload = await websocket.receive_json()

        async def _run() -> None:
            # NOTE(review): assumes generation_code exposes an async
            # `.stream(**payload)` generator — confirm against inference.py.
            try:
                async for token in generation_code.stream(**payload):  # type: ignore[attr-defined]
                    await queue.put(token)
            finally:
                # Always deliver the sentinel, even when generation raises,
                # so the consumer loop below can never hang forever.
                await queue.put("__END__")

        # Spawn background generation; keep the task so errors surface.
        producer = asyncio.create_task(_run())

        while True:
            item = await queue.get()
            if item == "__END__":
                break
            await websocket.send_text(item)

        # Re-raise any exception the producer hit (it was swallowed by the
        # finally above only for sentinel-delivery purposes).
        await producer
    except WebSocketDisconnect:
        # Client went away mid-stream; fall through to cleanup.
        pass
    finally:
        if producer is not None and not producer.done():
            producer.cancel()
        try:
            await websocket.close()
        except RuntimeError:
            pass  # connection already closed