mgbam commited on
Commit
9ab6f28
·
verified ·
1 Parent(s): a8704ad

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -55
app.py CHANGED
@@ -1,7 +1,8 @@
1
  # app.py
2
  """
3
- Main application file for SHASHA AI, a Gradio‑based AI code‑generation tool.
4
- Adds a small banner logo (assets/logo.png) above the page title.
 
5
  """
6
 
7
  import gradio as gr
@@ -29,11 +30,10 @@ from utils import (
29
  )
30
  from deploy import send_to_sandbox
31
 
32
- # --- Type aliases ---
33
  History = List[Tuple[str, str]]
34
  Model = Dict[str, Any]
35
 
36
- # --- Supported languages for dropdown ---
37
  SUPPORTED_LANGUAGES = [
38
  "python","c","cpp","markdown","latex","json","html","css","javascript","jinja2",
39
  "typescript","yaml","dockerfile","shell","r","sql","sql-msSQL","sql-mySQL",
@@ -41,59 +41,45 @@ SUPPORTED_LANGUAGES = [
41
  "sql-gql","sql-gpSQL","sql-sparkSQL","sql-esper"
42
  ]
43
 
44
- def get_model_details(name: str) -> Optional[Model]:
45
- for m in AVAILABLE_MODELS:
46
- if m["name"] == name:
47
- return m
48
- return None
49
 
50
  def generation_code(
51
- query: Optional[str],
52
- file: Optional[str],
53
- website_url: Optional[str],
54
- current_model: Model,
55
- enable_search: bool,
56
- language: str,
57
- history: Optional[History],
58
- ) -> Tuple[str, History, str, List[Dict[str, str]]]:
59
  query = query or ""
60
  history = history or []
61
  try:
62
- # Choose system prompt
63
- if language == "html":
64
- system_prompt = HTML_SYSTEM_PROMPT
65
- elif language == "transformers.js":
66
- system_prompt = TRANSFORMERS_JS_SYSTEM_PROMPT
67
- else:
68
- system_prompt = (
69
- f"You are an expert {language} developer. "
70
- f"Write clean, idiomatic {language} code based on the user's request."
71
- )
72
-
73
  model_id = current_model["id"]
74
  provider = (
75
- "openai" if model_id.startswith("openai/") or model_id in {"gpt-4","gpt-3.5-turbo"}
76
- else "gemini" if model_id.startswith(("gemini/","google/"))
77
  else "fireworks-ai" if model_id.startswith("fireworks-ai/")
78
  else "auto"
79
  )
80
 
81
- # Build message history
82
  msgs = history_to_messages(history, system_prompt)
83
  context = query
84
- if file:
85
- context += f"\n\n[Attached file]\n{extract_text_from_file(file)[:5000]}"
86
  if website_url:
87
  wtext = extract_website_content(website_url)
88
  if not wtext.startswith("Error"):
89
  context += f"\n\n[Website content]\n{wtext[:8000]}"
90
  msgs.append({"role":"user","content":enhance_query_with_search(context, enable_search)})
91
 
92
- # Call the model
93
  client = get_inference_client(model_id, provider)
94
- resp = client.chat.completions.create(
95
- model=model_id, messages=msgs, max_tokens=16000, temperature=0.1
96
- )
97
  content = resp.choices[0].message.content
98
 
99
  except Exception as e:
@@ -101,30 +87,28 @@ def generation_code(
101
  history.append((query, err))
102
  return "", history, "", history_to_chatbot_messages(history)
103
 
104
- # Process output
105
- if language == "transformers.js":
106
  files = parse_transformers_js_output(content)
107
  code = format_transformers_js_output(files)
108
  preview = send_to_sandbox(files.get("index.html",""))
109
  else:
110
  cleaned = remove_code_block(content)
111
  code = apply_search_replace_changes(history[-1][1], cleaned) if history and not history[-1][1].startswith("❌") else cleaned
112
- preview = send_to_sandbox(code) if language == "html" else ""
113
 
114
  new_hist = history + [(query, code)]
115
  return code, new_hist, preview, history_to_chatbot_messages(new_hist)
116
 
117
- # --- Custom CSS ---
118
  CUSTOM_CSS = """
119
  body{font-family:-apple-system,BlinkMacSystemFont,'Segoe UI',Roboto,sans-serif;}
120
- #main_title{text-align:center;font-size:2.5rem;margin-top:0.5rem;}
121
  #subtitle{text-align:center;color:#4a5568;margin-bottom:2rem;}
122
  .gradio-container{background-color:#f7fafc;}
123
  #gen_btn{box-shadow:0 4px 6px rgba(0,0,0,0.1);}
124
  """
125
 
126
  LOGO_PATH = "assets/logo.png"
127
- logo_exists = os.path.exists(LOGO_PATH)
128
 
129
  with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue"),
130
  css=CUSTOM_CSS,
@@ -133,10 +117,10 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue"),
133
  initial_model = AVAILABLE_MODELS[0]
134
  model_state = gr.State(initial_model)
135
 
136
- if logo_exists:
137
- # Small banner logo (height ~70px) centred
138
- gr.Image(value=LOGO_PATH, height=70, show_label=False,
139
- container=False, elem_id="banner_logo")
140
 
141
  gr.Markdown("# 🚀 Shasha AI", elem_id="main_title")
142
  gr.Markdown("Your AI partner for generating, modifying, and understanding code.", elem_id="subtitle")
@@ -144,11 +128,8 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue"),
144
  with gr.Row():
145
  with gr.Column(scale=1):
146
  gr.Markdown("### 1. Select Model")
147
- model_dd = gr.Dropdown(
148
- choices=[m["name"] for m in AVAILABLE_MODELS],
149
- value=initial_model["name"],
150
- label="AI Model"
151
- )
152
 
153
  gr.Markdown("### 2. Provide Context")
154
  with gr.Tabs():
@@ -160,9 +141,8 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue"),
160
  url_in = gr.Textbox(placeholder="https://example.com")
161
 
162
  gr.Markdown("### 3. Configure Output")
163
- lang_dd = gr.Dropdown(SUPPORTED_LANGUAGES, value="html", label="Target Language")
164
- search_chk= gr.Checkbox(label="Enable Web Search")
165
-
166
  with gr.Row():
167
  clr_btn = gr.Button("Clear Session", variant="secondary")
168
  gen_btn = gr.Button("Generate Code", variant="primary", elem_id="gen_btn")
 
1
  # app.py
2
  """
3
+ Main application file for SHASHA AI, a Gradio‑based AI code‑generation tool.
4
+
5
+ Only change: reduce logo width so the banner isn’t full‑width.
6
  """
7
 
8
  import gradio as gr
 
30
  )
31
  from deploy import send_to_sandbox
32
 
33
+ # --- Aliases ---
34
  History = List[Tuple[str, str]]
35
  Model = Dict[str, Any]
36
 
 
37
  SUPPORTED_LANGUAGES = [
38
  "python","c","cpp","markdown","latex","json","html","css","javascript","jinja2",
39
  "typescript","yaml","dockerfile","shell","r","sql","sql-msSQL","sql-mySQL",
 
41
  "sql-gql","sql-gpSQL","sql-sparkSQL","sql-esper"
42
  ]
43
 
44
+ def get_model_details(name:str)->Optional[Model]:
45
+ return next((m for m in AVAILABLE_MODELS if m["name"]==name), None)
 
 
 
46
 
47
  def generation_code(
48
+ query:Optional[str],
49
+ file:Optional[str],
50
+ website_url:Optional[str],
51
+ current_model:Model,
52
+ enable_search:bool,
53
+ language:str,
54
+ history:Optional[History],
55
+ )->Tuple[str,History,str,List[Dict[str,str]]]:
56
  query = query or ""
57
  history = history or []
58
  try:
59
+ system_prompt = (
60
+ HTML_SYSTEM_PROMPT if language=="html" else
61
+ TRANSFORMERS_JS_SYSTEM_PROMPT if language=="transformers.js"
62
+ else f"You are an expert {language} developer. Write clean, idiomatic {language} code."
63
+ )
 
 
 
 
 
 
64
  model_id = current_model["id"]
65
  provider = (
66
+ "openai" if model_id.startswith("openai/") or model_id in {"gpt-4","gpt-3.5-turbo"}
67
+ else "gemini" if model_id.startswith(("gemini/","google/"))
68
  else "fireworks-ai" if model_id.startswith("fireworks-ai/")
69
  else "auto"
70
  )
71
 
 
72
  msgs = history_to_messages(history, system_prompt)
73
  context = query
74
+ if file: context += f"\n\n[Attached file]\n{extract_text_from_file(file)[:5000]}"
 
75
  if website_url:
76
  wtext = extract_website_content(website_url)
77
  if not wtext.startswith("Error"):
78
  context += f"\n\n[Website content]\n{wtext[:8000]}"
79
  msgs.append({"role":"user","content":enhance_query_with_search(context, enable_search)})
80
 
 
81
  client = get_inference_client(model_id, provider)
82
+ resp = client.chat.completions.create(model=model_id, messages=msgs, max_tokens=16000, temperature=0.1)
 
 
83
  content = resp.choices[0].message.content
84
 
85
  except Exception as e:
 
87
  history.append((query, err))
88
  return "", history, "", history_to_chatbot_messages(history)
89
 
90
+ if language=="transformers.js":
 
91
  files = parse_transformers_js_output(content)
92
  code = format_transformers_js_output(files)
93
  preview = send_to_sandbox(files.get("index.html",""))
94
  else:
95
  cleaned = remove_code_block(content)
96
  code = apply_search_replace_changes(history[-1][1], cleaned) if history and not history[-1][1].startswith("❌") else cleaned
97
+ preview = send_to_sandbox(code) if language=="html" else ""
98
 
99
  new_hist = history + [(query, code)]
100
  return code, new_hist, preview, history_to_chatbot_messages(new_hist)
101
 
102
+ # --- CSS ---
103
  CUSTOM_CSS = """
104
  body{font-family:-apple-system,BlinkMacSystemFont,'Segoe UI',Roboto,sans-serif;}
105
+ #main_title{text-align:center;font-size:2.5rem;margin-top:.5rem;}
106
  #subtitle{text-align:center;color:#4a5568;margin-bottom:2rem;}
107
  .gradio-container{background-color:#f7fafc;}
108
  #gen_btn{box-shadow:0 4px 6px rgba(0,0,0,0.1);}
109
  """
110
 
111
  LOGO_PATH = "assets/logo.png"
 
112
 
113
  with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue"),
114
  css=CUSTOM_CSS,
 
117
  initial_model = AVAILABLE_MODELS[0]
118
  model_state = gr.State(initial_model)
119
 
120
+ if os.path.exists(LOGO_PATH):
121
+ gr.Image(value=LOGO_PATH, height=70, width=70,
122
+ show_label=False, container=False,
123
+ elem_id="banner_logo")
124
 
125
  gr.Markdown("# 🚀 Shasha AI", elem_id="main_title")
126
  gr.Markdown("Your AI partner for generating, modifying, and understanding code.", elem_id="subtitle")
 
128
  with gr.Row():
129
  with gr.Column(scale=1):
130
  gr.Markdown("### 1. Select Model")
131
+ model_dd = gr.Dropdown([m["name"] for m in AVAILABLE_MODELS],
132
+ value=initial_model["name"], label="AI Model")
 
 
 
133
 
134
  gr.Markdown("### 2. Provide Context")
135
  with gr.Tabs():
 
141
  url_in = gr.Textbox(placeholder="https://example.com")
142
 
143
  gr.Markdown("### 3. Configure Output")
144
+ lang_dd = gr.Dropdown(SUPPORTED_LANGUAGES, value="html", label="Target Language")
145
+ search_chk = gr.Checkbox(label="Enable Web Search")
 
146
  with gr.Row():
147
  clr_btn = gr.Button("Clear Session", variant="secondary")
148
  gen_btn = gr.Button("Generate Code", variant="primary", elem_id="gen_btn")