mgbam committed
Commit a8704ad · verified · 1 Parent(s): 056e1f7

Update app.py

Files changed (1)
  1. app.py +55 -76
app.py CHANGED
@@ -1,16 +1,14 @@
 # app.py
 """
-Main application file for SHASHA AI, aGradio‑based AI code‑generation tool.
-
-Provides a UI for generating code in many languages using various AI models.
-Supports text prompts, file uploads, website scraping, optional web search,
-and live previews of HTML output.
+Main application file for SHASHA AI, a Gradio‑based AI code‑generation tool.
+Adds a small banner logo (assets/logo.png) above the page title.
 """
 
 import gradio as gr
 from typing import Optional, Dict, List, Tuple, Any
+import os
 
-# ─── Local module imports ────────────────────────────────────────────────
+# --- Local module imports ---
 from constants import (
     HTML_SYSTEM_PROMPT,
     TRANSFORMERS_JS_SYSTEM_PROMPT,
@@ -31,17 +29,16 @@ from utils import (
 )
 from deploy import send_to_sandbox
 
-# ─── Type aliases ───────────────────────────────────────────────────────
+# --- Type aliases ---
 History = List[Tuple[str, str]]
 Model = Dict[str, Any]
 
-# ─── Supported languages (dropdown) ─────────────────────────────────────
+# --- Supported languages for dropdown ---
 SUPPORTED_LANGUAGES = [
-    "python", "c", "cpp", "markdown", "latex", "json", "html", "css",
-    "javascript", "jinja2", "typescript", "yaml", "dockerfile", "shell",
-    "r", "sql", "sql-msSQL", "sql-mySQL", "sql-mariaDB", "sql-sqlite",
-    "sql-cassandra", "sql-plSQL", "sql-hive", "sql-pgSQL", "sql-gql",
-    "sql-gpSQL", "sql-sparkSQL", "sql-esper",
+    "python","c","cpp","markdown","latex","json","html","css","javascript","jinja2",
+    "typescript","yaml","dockerfile","shell","r","sql","sql-msSQL","sql-mySQL",
+    "sql-mariaDB","sql-sqlite","sql-cassandra","sql-plSQL","sql-hive","sql-pgSQL",
+    "sql-gql","sql-gpSQL","sql-sparkSQL","sql-esper"
 ]
 
 def get_model_details(name: str) -> Optional[Model]:
@@ -50,7 +47,6 @@ def get_model_details(name: str) -> Optional[Model]:
             return m
     return None
 
-# ─── Core generation function ───────────────────────────────────────────
 def generation_code(
     query: Optional[str],
     file: Optional[str],
@@ -60,10 +56,10 @@ def generation_code(
     language: str,
     history: Optional[History],
 ) -> Tuple[str, History, str, List[Dict[str, str]]]:
-    query = query or ""
-    history = history or []
+    query = query or ""
+    history = history or []
     try:
-        # Choose system prompt based on language
+        # Choose system prompt
         if language == "html":
             system_prompt = HTML_SYSTEM_PROMPT
         elif language == "transformers.js":
@@ -75,36 +71,28 @@ def generation_code(
         )
 
         model_id = current_model["id"]
-        # Determine provider
-        if model_id.startswith("openai/") or model_id in {"gpt-4", "gpt-3.5-turbo"}:
-            provider = "openai"
-        elif model_id.startswith("gemini/") or model_id.startswith("google/"):
-            provider = "gemini"
-        elif model_id.startswith("fireworks-ai/"):
-            provider = "fireworks-ai"
-        else:
-            provider = "auto"
+        provider = (
+            "openai" if model_id.startswith("openai/") or model_id in {"gpt-4","gpt-3.5-turbo"}
+            else "gemini" if model_id.startswith(("gemini/","google/"))
+            else "fireworks-ai" if model_id.startswith("fireworks-ai/")
+            else "auto"
+        )
 
         # Build message history
         msgs = history_to_messages(history, system_prompt)
         context = query
         if file:
-            ftext = extract_text_from_file(file)
-            context += f"\n\n[Attached file]\n{ftext[:5000]}"
+            context += f"\n\n[Attached file]\n{extract_text_from_file(file)[:5000]}"
         if website_url:
             wtext = extract_website_content(website_url)
             if not wtext.startswith("Error"):
                 context += f"\n\n[Website content]\n{wtext[:8000]}"
-        final_q = enhance_query_with_search(context, enable_search)
-        msgs.append({"role": "user", "content": final_q})
+        msgs.append({"role":"user","content":enhance_query_with_search(context, enable_search)})
 
         # Call the model
-        client = get_inference_client(model_id, provider)
-        resp = client.chat.completions.create(
-            model=model_id,
-            messages=msgs,
-            max_tokens=16000,
-            temperature=0.1,
+        client = get_inference_client(model_id, provider)
+        resp = client.chat.completions.create(
+            model=model_id, messages=msgs, max_tokens=16000, temperature=0.1
         )
         content = resp.choices[0].message.content
 
@@ -113,80 +101,72 @@ def generation_code(
         history.append((query, err))
         return "", history, "", history_to_chatbot_messages(history)
 
-    # Process model output
+    # Process output
    if language == "transformers.js":
         files = parse_transformers_js_output(content)
         code = format_transformers_js_output(files)
-        preview = send_to_sandbox(files.get("index.html", ""))
+        preview = send_to_sandbox(files.get("index.html",""))
     else:
         cleaned = remove_code_block(content)
-        if history and history[-1][1] and not history[-1][1].startswith("❌"):
-            code = apply_search_replace_changes(history[-1][1], cleaned)
-        else:
-            code = cleaned
+        code = apply_search_replace_changes(history[-1][1], cleaned) if history and not history[-1][1].startswith("❌") else cleaned
         preview = send_to_sandbox(code) if language == "html" else ""
 
     new_hist = history + [(query, code)]
-    chat = history_to_chatbot_messages(new_hist)
-    return code, new_hist, preview, chat
+    return code, new_hist, preview, history_to_chatbot_messages(new_hist)
 
-# ─── Custom CSS (added #logo rule) ───────────────────────────────────────
+# --- Custom CSS ---
 CUSTOM_CSS = """
-body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif; }
-#logo { display:block; margin:20px auto; max-height:80px; }
-#main_title{ text-align:center; font-size:2.5rem; margin-top:0.5rem; }
-#subtitle { text-align:center; color:#4a5568; margin-bottom:2.0rem; }
-.gradio-container { background-color:#f7fafc; }
-#gen_btn { box-shadow:0 4px 6px rgba(0,0,0,0.1); }
+body{font-family:-apple-system,BlinkMacSystemFont,'Segoe UI',Roboto,sans-serif;}
+#main_title{text-align:center;font-size:2.5rem;margin-top:0.5rem;}
+#subtitle{text-align:center;color:#4a5568;margin-bottom:2rem;}
+.gradio-container{background-color:#f7fafc;}
+#gen_btn{box-shadow:0 4px 6px rgba(0,0,0,0.1);}
 """
 
-# ─── Gradio UI ───────────────────────────────────────────────────────────
+LOGO_PATH = "assets/logo.png"
+logo_exists = os.path.exists(LOGO_PATH)
+
 with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue"),
                css=CUSTOM_CSS,
-               title="Shasha AI") as demo:
-
+               title="Shasha AI") as demo:
     history_state = gr.State([])
     initial_model = AVAILABLE_MODELS[0]
     model_state = gr.State(initial_model)
 
-    # Logo • Title • Subtitle
-    gr.Image("assets/logo.png", elem_id="logo", show_label=False)
-    gr.Markdown("# 🚀 Shasha AI", elem_id="main_title")
-    gr.Markdown("Your AI partner for generating, modifying, and understanding code.",
-                elem_id="subtitle")
+    if logo_exists:
+        # Small banner logo (height ~70px) centred
+        gr.Image(value=LOGO_PATH, height=70, show_label=False,
+                 container=False, elem_id="banner_logo")
+
+    gr.Markdown("# 🚀 Shasha AI", elem_id="main_title")
+    gr.Markdown("Your AI partner for generating, modifying, and understanding code.", elem_id="subtitle")
 
     with gr.Row():
-        # ── Left column (inputs)
         with gr.Column(scale=1):
-            gr.Markdown("### 1. Select Model")
+            gr.Markdown("### 1. Select Model")
             model_dd = gr.Dropdown(
                 choices=[m["name"] for m in AVAILABLE_MODELS],
                 value=initial_model["name"],
-                label="AI Model",
+                label="AI Model"
             )
 
-            gr.Markdown("### 2. Provide Context")
+            gr.Markdown("### 2. Provide Context")
             with gr.Tabs():
                 with gr.Tab("📝 Prompt"):
-                    prompt_in = gr.Textbox(lines=7,
-                                           placeholder="Describe your request…",
-                                           show_label=False)
+                    prompt_in = gr.Textbox(lines=7, placeholder="Describe your request...", show_label=False)
                 with gr.Tab("📄 File"):
-                    file_in = gr.File(type="filepath")
+                    file_in = gr.File(type="filepath")
                 with gr.Tab("🌐 Website"):
-                    url_in = gr.Textbox(placeholder="https://example.com")
+                    url_in = gr.Textbox(placeholder="https://example.com")
 
-            gr.Markdown("### 3. Configure Output")
-            lang_dd = gr.Dropdown(SUPPORTED_LANGUAGES,
-                                  value="html",
-                                  label="Target Language")
-            search_chk = gr.Checkbox(label="Enable Web Search")
+            gr.Markdown("### 3. Configure Output")
+            lang_dd = gr.Dropdown(SUPPORTED_LANGUAGES, value="html", label="Target Language")
+            search_chk= gr.Checkbox(label="Enable Web Search")
 
             with gr.Row():
                 clr_btn = gr.Button("Clear Session", variant="secondary")
-                gen_btn = gr.Button("Generate Code", variant="primary", elem_id="gen_btn")
+                gen_btn = gr.Button("Generate Code", variant="primary", elem_id="gen_btn")
 
-        # ── Right column (outputs)
         with gr.Column(scale=2):
            with gr.Tabs():
                 with gr.Tab("💻 Code"):
@@ -196,7 +176,6 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue"),
                 with gr.Tab("📜 History"):
                     chat_out = gr.Chatbot(type="messages")
 
-    # ── Callbacks
     model_dd.change(lambda n: get_model_details(n) or initial_model,
                     inputs=[model_dd], outputs=[model_state])
 
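The diff ends just after the model-selection callback; the click handlers for gen_btn and clr_btn sit outside the changed hunks and are not shown. For orientation only, here is a hedged sketch (not part of commit a8704ad) of how those callbacks could be wired to generation_code. The output components code_out and preview_out and the exact argument order are assumptions inferred from the function signature and return tuple visible above.

    # Illustrative sketch only — not part of this commit.
    # Assumed generation_code parameter order:
    # (query, file, website_url, model, enable_search, language, history)
    gen_btn.click(
        fn=generation_code,
        inputs=[prompt_in, file_in, url_in, model_state, search_chk, lang_dd, history_state],
        outputs=[code_out, history_state, preview_out, chat_out],  # code_out / preview_out are assumed names
    )
    clr_btn.click(
        fn=lambda: ("", [], "", []),  # reset code, history state, preview, and chat display
        inputs=None,
        outputs=[code_out, history_state, preview_out, chat_out],
    )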