Ali2206 commited on
Commit
9a0af2f
Β·
verified Β·
1 Parent(s): 682a29e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +71 -76
app.py CHANGED
@@ -4,11 +4,11 @@ import random
4
  import gradio as gr
5
  from datetime import datetime
6
 
7
- # Add src to Python path
8
  sys.path.append(os.path.join(os.path.dirname(__file__), "src"))
9
 
10
- # Adjust to match your file structure
11
- from txagent.txagent import TxAgent # e.g., src/txagent/txagent.py
12
 
13
  # ==== Environment Setup ====
14
  current_dir = os.path.dirname(os.path.abspath(__file__))
@@ -63,90 +63,85 @@ question_examples = [
63
  ["A 30-year-old patient is on Prozac for depression and now diagnosed with WHIM syndrome. Is Xolremdi suitable?"]
64
  ]
65
 
66
- # ====== Main Application Entrypoint ======
67
- if __name__ == "__main__":
68
- # === Initialize the model (inside __main__) ===
69
- agent = TxAgent(
70
- model_name,
71
- rag_model_name,
72
- tool_files_dict=new_tool_files,
73
- force_finish=True,
74
- enable_checker=True,
75
- step_rag_num=10,
76
- seed=100,
77
- additional_default_tools=["DirectResponse", "RequireClarification"]
78
- )
79
- agent.init_model()
80
-
81
- # === Gradio interface logic ===
82
- def handle_chat(message, history, temperature, max_new_tokens, max_tokens, multi_agent, conversation, max_round):
83
- return agent.run_gradio_chat(message, history, temperature, max_new_tokens, max_tokens, multi_agent, conversation, max_round)
84
-
85
- def update_seed():
86
- seed = random.randint(0, 10000)
87
- return agent.update_parameters(seed=seed)
88
-
89
- # βœ… FIXED: retry must return, not yield
90
  def handle_retry(history, retry_data, temperature, max_new_tokens, max_tokens, multi_agent, conversation, max_round):
91
  update_seed()
92
  new_history = history[:retry_data.index]
93
  previous_prompt = history[retry_data.index]["content"]
94
-
95
- # βœ… This MUST return, not yield
96
  result = agent.run_gradio_chat(
97
  new_history + [{"role": "user", "content": previous_prompt}],
98
  temperature, max_new_tokens, max_tokens, multi_agent, conversation, max_round
99
  )
100
-
101
- # If your agent returns a generator, consume it into a list or string
102
  if hasattr(result, "__iter__") and not isinstance(result, (str, dict, list)):
103
  result = list(result)
104
-
105
  return result
106
 
107
- # ===== Build Gradio Interface =====
108
- with gr.Blocks(css=css) as demo:
109
- gr.Markdown(DESCRIPTION)
110
- gr.Markdown(INTRO)
111
-
112
- temperature = gr.Slider(0, 1, step=0.1, value=0.3, label="Temperature")
113
- max_new_tokens = gr.Slider(128, 4096, step=1, value=1024, label="Max New Tokens")
114
- max_tokens = gr.Slider(128, 32000, step=1, value=8192, label="Max Total Tokens")
115
- max_round = gr.Slider(1, 50, step=1, value=30, label="Max Rounds")
116
- multi_agent = gr.Checkbox(label="Enable Multi-agent Reasoning", value=False)
117
- conversation_state = gr.State([])
118
-
119
- chatbot = gr.Chatbot(
120
- label="TxAgent",
121
- placeholder=PLACEHOLDER,
122
- height=700,
123
- type="messages",
124
- show_copy_button=True
125
- )
126
-
127
- # βœ… Retry now fixed
128
- chatbot.retry(
129
- handle_retry,
130
- chatbot, chatbot,
 
 
 
 
 
 
 
 
131
  temperature, max_new_tokens, max_tokens,
132
  multi_agent, conversation_state, max_round
133
- )
134
-
135
- gr.ChatInterface(
136
- fn=handle_chat,
137
- chatbot=chatbot,
138
- additional_inputs=[
139
- temperature, max_new_tokens, max_tokens,
140
- multi_agent, conversation_state, max_round
141
- ],
142
- examples=question_examples,
143
- css=chat_css,
144
- cache_examples=False,
145
- fill_height=True,
146
- fill_width=True,
147
- stop_btn=True
148
- )
149
-
150
- gr.Markdown(LICENSE)
151
-
152
- demo.launch()
 
4
  import gradio as gr
5
  from datetime import datetime
6
 
7
+ # Add `src` directory to Python path
8
  sys.path.append(os.path.join(os.path.dirname(__file__), "src"))
9
 
10
+ # Import your agent class from src/txagent/txagent.py
11
+ from txagent.txagent import TxAgent
12
 
13
  # ==== Environment Setup ====
14
  current_dir = os.path.dirname(os.path.abspath(__file__))
 
63
  ["A 30-year-old patient is on Prozac for depression and now diagnosed with WHIM syndrome. Is Xolremdi suitable?"]
64
  ]
65
 
66
+ # === Initialize the model ===
67
+ agent = TxAgent(
68
+ model_name,
69
+ rag_model_name,
70
+ tool_files_dict=new_tool_files,
71
+ force_finish=True,
72
+ enable_checker=True,
73
+ step_rag_num=10,
74
+ seed=100,
75
+ additional_default_tools=["DirectResponse", "RequireClarification"]
76
+ )
77
+ agent.init_model()
78
+
79
+ # === Gradio interface logic ===
80
+ def handle_chat(message, history, temperature, max_new_tokens, max_tokens, multi_agent, conversation, max_round):
81
+ return agent.run_gradio_chat(message, history, temperature, max_new_tokens, max_tokens, multi_agent, conversation, max_round)
82
+
83
+ def update_seed():
84
+ seed = random.randint(0, 10000)
85
+ return agent.update_parameters(seed=seed)
86
+
87
+ # βœ… FIXED: handle_retry with return, no yield
 
 
88
  def handle_retry(history, retry_data, temperature, max_new_tokens, max_tokens, multi_agent, conversation, max_round):
89
  update_seed()
90
  new_history = history[:retry_data.index]
91
  previous_prompt = history[retry_data.index]["content"]
 
 
92
  result = agent.run_gradio_chat(
93
  new_history + [{"role": "user", "content": previous_prompt}],
94
  temperature, max_new_tokens, max_tokens, multi_agent, conversation, max_round
95
  )
96
+ # If it's a generator, convert to list to avoid Gradio errors
 
97
  if hasattr(result, "__iter__") and not isinstance(result, (str, dict, list)):
98
  result = list(result)
 
99
  return result
100
 
101
+ # ===== Build Gradio Interface =====
102
+ with gr.Blocks(css=css) as demo:
103
+ gr.Markdown(DESCRIPTION)
104
+ gr.Markdown(INTRO)
105
+
106
+ temperature = gr.Slider(0, 1, step=0.1, value=0.3, label="Temperature")
107
+ max_new_tokens = gr.Slider(128, 4096, step=1, value=1024, label="Max New Tokens")
108
+ max_tokens = gr.Slider(128, 32000, step=1, value=8192, label="Max Total Tokens")
109
+ max_round = gr.Slider(1, 50, step=1, value=30, label="Max Rounds")
110
+ multi_agent = gr.Checkbox(label="Enable Multi-agent Reasoning", value=False)
111
+ conversation_state = gr.State([])
112
+
113
+ chatbot = gr.Chatbot(
114
+ label="TxAgent",
115
+ placeholder=PLACEHOLDER,
116
+ height=700,
117
+ type="messages",
118
+ show_copy_button=True
119
+ )
120
+
121
+ # βœ… Retry logic added safely
122
+ chatbot.retry(
123
+ handle_retry,
124
+ chatbot, chatbot,
125
+ temperature, max_new_tokens, max_tokens,
126
+ multi_agent, conversation_state, max_round
127
+ )
128
+
129
+ gr.ChatInterface(
130
+ fn=handle_chat,
131
+ chatbot=chatbot,
132
+ additional_inputs=[
133
  temperature, max_new_tokens, max_tokens,
134
  multi_agent, conversation_state, max_round
135
+ ],
136
+ examples=question_examples,
137
+ css=chat_css,
138
+ cache_examples=False,
139
+ fill_height=True,
140
+ fill_width=True,
141
+ stop_btn=True
142
+ )
143
+
144
+ gr.Markdown(LICENSE)
145
+
146
+ # βœ… Ensure launch works on Hugging Face Spaces
147
+ demo.launch()