Ali2206 commited on
Commit
6309d92
·
verified ·
1 Parent(s): 9a0af2f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +79 -80
app.py CHANGED
@@ -2,12 +2,10 @@ import os
2
  import sys
3
  import random
4
  import gradio as gr
5
- from datetime import datetime
6
 
7
  # Add `src` directory to Python path
8
  sys.path.append(os.path.join(os.path.dirname(__file__), "src"))
9
 
10
- # Import your agent class from src/txagent/txagent.py
11
  from txagent.txagent import TxAgent
12
 
13
  # ==== Environment Setup ====
@@ -63,85 +61,86 @@ question_examples = [
63
  ["A 30-year-old patient is on Prozac for depression and now diagnosed with WHIM syndrome. Is Xolremdi suitable?"]
64
  ]
65
 
66
- # === Initialize the model ===
67
- agent = TxAgent(
68
- model_name,
69
- rag_model_name,
70
- tool_files_dict=new_tool_files,
71
- force_finish=True,
72
- enable_checker=True,
73
- step_rag_num=10,
74
- seed=100,
75
- additional_default_tools=["DirectResponse", "RequireClarification"]
76
- )
77
- agent.init_model()
78
-
79
- # === Gradio interface logic ===
80
- def handle_chat(message, history, temperature, max_new_tokens, max_tokens, multi_agent, conversation, max_round):
81
- return agent.run_gradio_chat(message, history, temperature, max_new_tokens, max_tokens, multi_agent, conversation, max_round)
82
-
83
- def update_seed():
84
- seed = random.randint(0, 10000)
85
- return agent.update_parameters(seed=seed)
86
-
87
- # ✅ FIXED: handle_retry with return, no yield
88
- def handle_retry(history, retry_data, temperature, max_new_tokens, max_tokens, multi_agent, conversation, max_round):
89
- update_seed()
90
- new_history = history[:retry_data.index]
91
- previous_prompt = history[retry_data.index]["content"]
92
- result = agent.run_gradio_chat(
93
- new_history + [{"role": "user", "content": previous_prompt}],
94
- temperature, max_new_tokens, max_tokens, multi_agent, conversation, max_round
95
- )
96
- # If it's a generator, convert to list to avoid Gradio errors
97
- if hasattr(result, "__iter__") and not isinstance(result, (str, dict, list)):
98
- result = list(result)
99
- return result
100
-
101
- # ===== Build Gradio Interface =====
102
- with gr.Blocks(css=css) as demo:
103
- gr.Markdown(DESCRIPTION)
104
- gr.Markdown(INTRO)
105
-
106
- temperature = gr.Slider(0, 1, step=0.1, value=0.3, label="Temperature")
107
- max_new_tokens = gr.Slider(128, 4096, step=1, value=1024, label="Max New Tokens")
108
- max_tokens = gr.Slider(128, 32000, step=1, value=8192, label="Max Total Tokens")
109
- max_round = gr.Slider(1, 50, step=1, value=30, label="Max Rounds")
110
- multi_agent = gr.Checkbox(label="Enable Multi-agent Reasoning", value=False)
111
- conversation_state = gr.State([])
112
-
113
- chatbot = gr.Chatbot(
114
- label="TxAgent",
115
- placeholder=PLACEHOLDER,
116
- height=700,
117
- type="messages",
118
- show_copy_button=True
119
- )
120
-
121
- # ✅ Retry logic added safely
122
- chatbot.retry(
123
- handle_retry,
124
- chatbot, chatbot,
125
- temperature, max_new_tokens, max_tokens,
126
- multi_agent, conversation_state, max_round
127
- )
128
-
129
- gr.ChatInterface(
130
- fn=handle_chat,
131
- chatbot=chatbot,
132
- additional_inputs=[
133
  temperature, max_new_tokens, max_tokens,
134
  multi_agent, conversation_state, max_round
135
- ],
136
- examples=question_examples,
137
- css=chat_css,
138
- cache_examples=False,
139
- fill_height=True,
140
- fill_width=True,
141
- stop_btn=True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
  )
 
143
 
144
- gr.Markdown(LICENSE)
145
-
146
- # ✅ Ensure launch works on Hugging Face Spaces
147
- demo.launch()
 
2
  import sys
3
  import random
4
  import gradio as gr
 
5
 
6
  # Add `src` directory to Python path
7
  sys.path.append(os.path.join(os.path.dirname(__file__), "src"))
8
 
 
9
  from txagent.txagent import TxAgent
10
 
11
  # ==== Environment Setup ====
 
61
  ["A 30-year-old patient is on Prozac for depression and now diagnosed with WHIM syndrome. Is Xolremdi suitable?"]
62
  ]
63
 
64
+ # Initialize agent placeholder
65
+ agent = None
66
+
67
+ # ===== Build Gradio UI =====
68
+ def create_ui():
69
+ with gr.Blocks(css=css) as demo:
70
+ gr.Markdown(DESCRIPTION)
71
+ gr.Markdown(INTRO)
72
+
73
+ temperature = gr.Slider(0, 1, step=0.1, value=0.3, label="Temperature")
74
+ max_new_tokens = gr.Slider(128, 4096, step=1, value=1024, label="Max New Tokens")
75
+ max_tokens = gr.Slider(128, 32000, step=1, value=8192, label="Max Total Tokens")
76
+ max_round = gr.Slider(1, 50, step=1, value=30, label="Max Rounds")
77
+ multi_agent = gr.Checkbox(label="Enable Multi-agent Reasoning", value=False)
78
+ conversation_state = gr.State([])
79
+
80
+ chatbot = gr.Chatbot(
81
+ label="TxAgent",
82
+ placeholder=PLACEHOLDER,
83
+ height=700,
84
+ type="messages",
85
+ show_copy_button=True
86
+ )
87
+
88
+ # Retry logic
89
+ def handle_retry(history, retry_data, temperature, max_new_tokens, max_tokens, multi_agent, conversation, max_round):
90
+ agent.update_parameters(seed=random.randint(0, 10000))
91
+ new_history = history[:retry_data.index]
92
+ prompt = history[retry_data.index]["content"]
93
+ result = agent.run_gradio_chat(
94
+ new_history + [{"role": "user", "content": prompt}],
95
+ temperature, max_new_tokens, max_tokens,
96
+ multi_agent, conversation, max_round
97
+ )
98
+ if hasattr(result, "__iter__") and not isinstance(result, (str, dict, list)):
99
+ result = list(result)
100
+ return result
101
+
102
+ chatbot.retry(
103
+ handle_retry,
104
+ chatbot, chatbot,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
  temperature, max_new_tokens, max_tokens,
106
  multi_agent, conversation_state, max_round
107
+ )
108
+
109
+ # Main handler
110
+ def handle_chat(message, history, temperature, max_new_tokens, max_tokens, multi_agent, conversation, max_round):
111
+ return agent.run_gradio_chat(message, history, temperature, max_new_tokens, max_tokens, multi_agent, conversation, max_round)
112
+
113
+ gr.ChatInterface(
114
+ fn=handle_chat,
115
+ chatbot=chatbot,
116
+ additional_inputs=[
117
+ temperature, max_new_tokens, max_tokens,
118
+ multi_agent, conversation_state, max_round
119
+ ],
120
+ examples=question_examples,
121
+ css=chat_css,
122
+ cache_examples=False,
123
+ fill_height=True,
124
+ fill_width=True,
125
+ stop_btn=True
126
+ )
127
+
128
+ gr.Markdown(LICENSE)
129
+ return demo
130
+
131
+ # === VLLM-safe entry point ===
132
+ if __name__ == "__main__":
133
+ agent = TxAgent(
134
+ model_name,
135
+ rag_model_name,
136
+ tool_files_dict=new_tool_files,
137
+ force_finish=True,
138
+ enable_checker=True,
139
+ step_rag_num=10,
140
+ seed=100,
141
+ additional_default_tools=["DirectResponse", "RequireClarification"]
142
  )
143
+ agent.init_model()
144
 
145
+ demo = create_ui()
146
+ demo.launch()