Spaces:
Runtime error
Runtime error
Fangrui Liu
commited on
Commit
Β·
d5a4cb4
1
Parent(s):
526644e
fix callback
Browse files- callbacks/arxiv_callbacks.py +10 -3
callbacks/arxiv_callbacks.py
CHANGED
|
@@ -2,6 +2,7 @@ import streamlit as st
|
|
| 2 |
from typing import Dict, Any
|
| 3 |
from sql_formatter.core import format_sql
|
| 4 |
from langchain.callbacks.streamlit.streamlit_callback_handler import StreamlitCallbackHandler
|
|
|
|
| 5 |
|
| 6 |
class ChatDataSelfSearchCallBackHandler(StreamlitCallbackHandler):
|
| 7 |
def __init__(self) -> None:
|
|
@@ -62,8 +63,14 @@ class ChatDataSQLSearchCallBackHandler(StreamlitCallbackHandler):
|
|
| 62 |
def on_llm_start(self, serialized, prompts, **kwargs) -> None:
|
| 63 |
pass
|
| 64 |
|
| 65 |
-
def
|
| 66 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
st.write('We generated Vector SQL for you:')
|
| 68 |
st.markdown(f'''```sql\n{format_sql(text, max_len=80)}\n```''')
|
| 69 |
print(f"Vector SQL: {text}")
|
|
@@ -83,4 +90,4 @@ class ChatDataSQLAskCallBackHandler(ChatDataSQLSearchCallBackHandler):
|
|
| 83 |
self.progress_bar = st.progress(value=0.0, text='Writing SQL...')
|
| 84 |
self.status_bar = st.empty()
|
| 85 |
self.prog_value = 0
|
| 86 |
-
self.prog_interval = 0.1
|
|
|
|
| 2 |
from typing import Dict, Any
|
| 3 |
from sql_formatter.core import format_sql
|
| 4 |
from langchain.callbacks.streamlit.streamlit_callback_handler import StreamlitCallbackHandler
|
| 5 |
+
from langchain.schema.output import LLMResult
|
| 6 |
|
| 7 |
class ChatDataSelfSearchCallBackHandler(StreamlitCallbackHandler):
|
| 8 |
def __init__(self) -> None:
|
|
|
|
| 63 |
def on_llm_start(self, serialized, prompts, **kwargs) -> None:
|
| 64 |
pass
|
| 65 |
|
| 66 |
+
def on_llm_end(
|
| 67 |
+
self,
|
| 68 |
+
response: LLMResult,
|
| 69 |
+
*args,
|
| 70 |
+
**kwargs,
|
| 71 |
+
):
|
| 72 |
+
text = response.generations[0][0].text
|
| 73 |
+
if text.replace(' ', '').upper().startswith('SELECT'):
|
| 74 |
st.write('We generated Vector SQL for you:')
|
| 75 |
st.markdown(f'''```sql\n{format_sql(text, max_len=80)}\n```''')
|
| 76 |
print(f"Vector SQL: {text}")
|
|
|
|
| 90 |
self.progress_bar = st.progress(value=0.0, text='Writing SQL...')
|
| 91 |
self.status_bar = st.empty()
|
| 92 |
self.prog_value = 0
|
| 93 |
+
self.prog_interval = 0.1
|