Update src/txagent/txagent.py
Browse files- src/txagent/txagent.py +21 -64
 
    	
        src/txagent/txagent.py
    CHANGED
    
    | 
         @@ -761,17 +761,17 @@ Generate **one summarized sentence** about "function calls' responses" with nece 
     | 
|
| 761 | 
         
             
                    return updated_attributes
         
     | 
| 762 | 
         | 
| 763 | 
         
             
                def run_gradio_chat(self, message: str,
         
     | 
| 764 | 
         
            -
             
     | 
| 765 | 
         
            -
             
     | 
| 766 | 
         
            -
             
     | 
| 767 | 
         
            -
             
     | 
| 768 | 
         
            -
             
     | 
| 769 | 
         
            -
             
     | 
| 770 | 
         
            -
             
     | 
| 771 | 
         
            -
             
     | 
| 772 | 
         
            -
             
     | 
| 773 | 
         
            -
             
     | 
| 774 | 
         
            -
             
     | 
| 775 | 
         
             
                    """
         
     | 
| 776 | 
         
             
                    Generate a streaming response using the llama3-8b model.
         
     | 
| 777 | 
         
             
                    Args:
         
     | 
| 
         @@ -784,29 +784,9 @@ Generate **one summarized sentence** about "function calls' responses" with nece 
     | 
|
| 784 | 
         
             
                    """
         
     | 
| 785 | 
         
             
                    print("\033[1;32;40mstart\033[0m")
         
     | 
| 786 | 
         
             
                    print("len(message)", len(message))
         
     | 
| 787 | 
         
            -
             
     | 
| 788 | 
         
            -
                    if uploaded_files:
         
     | 
| 789 | 
         
            -
                        file_content_blocks = []
         
     | 
| 790 | 
         
            -
                        for file in uploaded_files:
         
     | 
| 791 | 
         
            -
                            if file.name.endswith(".pdf"):
         
     | 
| 792 | 
         
            -
                                try:
         
     | 
| 793 | 
         
            -
                                    import fitz  # PyMuPDF
         
     | 
| 794 | 
         
            -
                                    doc = fitz.open(file.name)
         
     | 
| 795 | 
         
            -
                                    text = ""
         
     | 
| 796 | 
         
            -
                                    for page in doc:
         
     | 
| 797 | 
         
            -
                                        text += page.get_text()
         
     | 
| 798 | 
         
            -
                                    if text.strip():
         
     | 
| 799 | 
         
            -
                                        file_content_blocks.append(f"[FILE CONTENT]\n{text.strip()}\n")
         
     | 
| 800 | 
         
            -
                                except Exception as e:
         
     | 
| 801 | 
         
            -
                                    print(f"Error reading PDF: {e}")
         
     | 
| 802 | 
         
            -
             
     | 
| 803 | 
         
            -
                        if file_content_blocks:
         
     | 
| 804 | 
         
            -
                            message = "\n".join(file_content_blocks) + "\n[USER PROMPT]\n" + message
         
     | 
| 805 | 
         
            -
             
     | 
| 806 | 
         
             
                    if len(message) <= 10:
         
     | 
| 807 | 
         
             
                        yield "Hi, I am TxAgent, an assistant for answering biomedical questions. Please provide a valid message with a string longer than 10 characters."
         
     | 
| 808 | 
         
             
                        return "Please provide a valid message."
         
     | 
| 809 | 
         
            -
             
     | 
| 810 | 
         
             
                    outputs = []
         
     | 
| 811 | 
         
             
                    outputs_str = ''
         
     | 
| 812 | 
         
             
                    last_outputs = []
         
     | 
| 
         @@ -845,7 +825,6 @@ Generate **one summarized sentence** about "function calls' responses" with nece 
     | 
|
| 845 | 
         
             
                                    temperature=temperature)
         
     | 
| 846 | 
         
             
                                history.extend(current_gradio_history)
         
     | 
| 847 | 
         
             
                                if special_tool_call == 'Finish':
         
     | 
| 848 | 
         
            -
                                    yield history
         
     | 
| 849 | 
         
             
                                    next_round = False
         
     | 
| 850 | 
         
             
                                    conversation.extend(function_call_messages)
         
     | 
| 851 | 
         
             
                                    return function_call_messages[0]['content']
         
     | 
| 
         @@ -853,7 +832,6 @@ Generate **one summarized sentence** about "function calls' responses" with nece 
     | 
|
| 853 | 
         
             
                                    history.append(
         
     | 
| 854 | 
         
             
                                        ChatMessage(role="assistant", content=history[-1].content))
         
     | 
| 855 | 
         
             
                                    yield history
         
     | 
| 856 | 
         
            -
                                    next_round = False
         
     | 
| 857 | 
         
             
                                    return history[-1].content
         
     | 
| 858 | 
         
             
                                if (self.enable_summary or token_overflow) and not call_agent:
         
     | 
| 859 | 
         
             
                                    if token_overflow:
         
     | 
| 
         @@ -864,9 +842,8 @@ Generate **one summarized sentence** about "function calls' responses" with nece 
     | 
|
| 864 | 
         
             
                                    enable_summary=enable_summary)
         
     | 
| 865 | 
         
             
                                if function_call_messages is not None:
         
     | 
| 866 | 
         
             
                                    conversation.extend(function_call_messages)
         
     | 
| 867 | 
         
            -
                                     
     | 
| 868 | 
         
            -
             
     | 
| 869 | 
         
            -
                                    yield history
         
     | 
| 870 | 
         
             
                                else:
         
     | 
| 871 | 
         
             
                                    next_round = False
         
     | 
| 872 | 
         
             
                                    conversation.extend(
         
     | 
| 
         @@ -893,18 +870,12 @@ Generate **one summarized sentence** about "function calls' responses" with nece 
     | 
|
| 893 | 
         
             
                                if each.metadata is not None:
         
     | 
| 894 | 
         
             
                                    each.metadata['status'] = 'done'
         
     | 
| 895 | 
         
             
                            if '[FinalAnswer]' in last_thought:
         
     | 
| 896 | 
         
            -
                                 
     | 
| 897 | 
         
            -
                                    '[FinalAnswer]')
         
     | 
| 898 | 
         
             
                                history.append(
         
     | 
| 899 | 
         
            -
                                    ChatMessage(role="assistant",
         
     | 
| 900 | 
         
            -
                                                content=final_thought.strip())
         
     | 
| 901 | 
         
            -
                                )
         
     | 
| 902 | 
         
            -
                                yield history
         
     | 
| 903 | 
         
            -
                                history.append(
         
     | 
| 904 | 
         
            -
                                    ChatMessage(
         
     | 
| 905 | 
         
            -
                                        role="assistant", content="**Answer**:\n"+final_answer.strip())
         
     | 
| 906 | 
         
             
                                )
         
     | 
| 907 | 
         
             
                                yield history
         
     | 
| 
         | 
|
| 908 | 
         
             
                            else:
         
     | 
| 909 | 
         
             
                                history.append(ChatMessage(
         
     | 
| 910 | 
         
             
                                    role="assistant", content=last_thought))
         
     | 
| 
         @@ -920,16 +891,9 @@ Generate **one summarized sentence** about "function calls' responses" with nece 
     | 
|
| 920 | 
         
             
                                    if each.metadata is not None:
         
     | 
| 921 | 
         
             
                                        each.metadata['status'] = 'done'
         
     | 
| 922 | 
         
             
                                if '[FinalAnswer]' in last_thought:
         
     | 
| 923 | 
         
            -
                                     
     | 
| 924 | 
         
            -
                                        '[FinalAnswer]')
         
     | 
| 925 | 
         
            -
                                    history.append(
         
     | 
| 926 | 
         
            -
                                        ChatMessage(role="assistant",
         
     | 
| 927 | 
         
            -
                                                    content=final_thought.strip())
         
     | 
| 928 | 
         
            -
                                    )
         
     | 
| 929 | 
         
            -
                                    yield history
         
     | 
| 930 | 
         
             
                                    history.append(
         
     | 
| 931 | 
         
            -
                                        ChatMessage(
         
     | 
| 932 | 
         
            -
                                            role="assistant", content="**Answer**:\n"+final_answer.strip())
         
     | 
| 933 | 
         
             
                                    )
         
     | 
| 934 | 
         
             
                                    yield history
         
     | 
| 935 | 
         
             
                            else:
         
     | 
| 
         @@ -948,19 +912,12 @@ Generate **one summarized sentence** about "function calls' responses" with nece 
     | 
|
| 948 | 
         
             
                                    each.metadata['status'] = 'done'
         
     | 
| 949 | 
         
             
                            if '[FinalAnswer]' in last_thought or '"name": "Finish",' in last_outputs_str:
         
     | 
| 950 | 
         
             
                                if '[FinalAnswer]' in last_thought:
         
     | 
| 951 | 
         
            -
                                     
     | 
| 952 | 
         
             
                                else:
         
     | 
| 953 | 
         
            -
                                    final_thought = ""
         
     | 
| 954 | 
         
             
                                    final_answer = last_thought
         
     | 
| 955 | 
         
             
                                history.append(
         
     | 
| 956 | 
         
            -
                                    ChatMessage(role="assistant",
         
     | 
| 957 | 
         
            -
                                                content=final_thought.strip())
         
     | 
| 958 | 
         
            -
                                )
         
     | 
| 959 | 
         
            -
                                yield history
         
     | 
| 960 | 
         
            -
                                history.append(
         
     | 
| 961 | 
         
            -
                                    ChatMessage(
         
     | 
| 962 | 
         
            -
                                        role="assistant", content="**Answer**:\n" + final_answer.strip())
         
     | 
| 963 | 
         
             
                                )
         
     | 
| 964 | 
         
             
                                yield history
         
     | 
| 965 | 
         
             
                        else:
         
     | 
| 966 | 
         
            -
                            return None
         
     | 
| 
         | 
|
| 761 | 
         
             
                    return updated_attributes
         
     | 
| 762 | 
         | 
| 763 | 
         
             
                def run_gradio_chat(self, message: str,
         
     | 
| 764 | 
         
            +
                                        history: list,
         
     | 
| 765 | 
         
            +
                                        temperature: float,
         
     | 
| 766 | 
         
            +
                                        max_new_tokens: int,
         
     | 
| 767 | 
         
            +
                                        max_token: int,
         
     | 
| 768 | 
         
            +
                                        call_agent: bool,
         
     | 
| 769 | 
         
            +
                                        conversation: gr.State,
         
     | 
| 770 | 
         
            +
                                        max_round: int = 20,
         
     | 
| 771 | 
         
            +
                                        seed: int = None,
         
     | 
| 772 | 
         
            +
                                        call_agent_level: int = 0,
         
     | 
| 773 | 
         
            +
                                        sub_agent_task: str = None,
         
     | 
| 774 | 
         
            +
                                        uploaded_files: list = None) -> str:
         
     | 
| 775 | 
         
             
                    """
         
     | 
| 776 | 
         
             
                    Generate a streaming response using the llama3-8b model.
         
     | 
| 777 | 
         
             
                    Args:
         
     | 
| 
         | 
|
| 784 | 
         
             
                    """
         
     | 
| 785 | 
         
             
                    print("\033[1;32;40mstart\033[0m")
         
     | 
| 786 | 
         
             
                    print("len(message)", len(message))
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 787 | 
         
             
                    if len(message) <= 10:
         
     | 
| 788 | 
         
             
                        yield "Hi, I am TxAgent, an assistant for answering biomedical questions. Please provide a valid message with a string longer than 10 characters."
         
     | 
| 789 | 
         
             
                        return "Please provide a valid message."
         
     | 
| 
         | 
|
| 790 | 
         
             
                    outputs = []
         
     | 
| 791 | 
         
             
                    outputs_str = ''
         
     | 
| 792 | 
         
             
                    last_outputs = []
         
     | 
| 
         | 
|
| 825 | 
         
             
                                    temperature=temperature)
         
     | 
| 826 | 
         
             
                                history.extend(current_gradio_history)
         
     | 
| 827 | 
         
             
                                if special_tool_call == 'Finish':
         
     | 
| 
         | 
|
| 828 | 
         
             
                                    next_round = False
         
     | 
| 829 | 
         
             
                                    conversation.extend(function_call_messages)
         
     | 
| 830 | 
         
             
                                    return function_call_messages[0]['content']
         
     | 
| 
         | 
|
| 832 | 
         
             
                                    history.append(
         
     | 
| 833 | 
         
             
                                        ChatMessage(role="assistant", content=history[-1].content))
         
     | 
| 834 | 
         
             
                                    yield history
         
     | 
| 
         | 
|
| 835 | 
         
             
                                    return history[-1].content
         
     | 
| 836 | 
         
             
                                if (self.enable_summary or token_overflow) and not call_agent:
         
     | 
| 837 | 
         
             
                                    if token_overflow:
         
     | 
| 
         | 
|
| 842 | 
         
             
                                    enable_summary=enable_summary)
         
     | 
| 843 | 
         
             
                                if function_call_messages is not None:
         
     | 
| 844 | 
         
             
                                    conversation.extend(function_call_messages)
         
     | 
| 845 | 
         
            +
                                    # Hide intermediate output
         
     | 
| 846 | 
         
            +
                                    pass
         
     | 
| 
         | 
|
| 847 | 
         
             
                                else:
         
     | 
| 848 | 
         
             
                                    next_round = False
         
     | 
| 849 | 
         
             
                                    conversation.extend(
         
     | 
| 
         | 
|
| 870 | 
         
             
                                if each.metadata is not None:
         
     | 
| 871 | 
         
             
                                    each.metadata['status'] = 'done'
         
     | 
| 872 | 
         
             
                            if '[FinalAnswer]' in last_thought:
         
     | 
| 873 | 
         
            +
                                _, final_answer = last_thought.split('[FinalAnswer]', 1)
         
     | 
| 
         | 
|
| 874 | 
         
             
                                history.append(
         
     | 
| 875 | 
         
            +
                                    ChatMessage(role="assistant", content=final_answer.strip())
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 876 | 
         
             
                                )
         
     | 
| 877 | 
         
             
                                yield history
         
     | 
| 878 | 
         
            +
                                return
         
     | 
| 879 | 
         
             
                            else:
         
     | 
| 880 | 
         
             
                                history.append(ChatMessage(
         
     | 
| 881 | 
         
             
                                    role="assistant", content=last_thought))
         
     | 
| 
         | 
|
| 891 | 
         
             
                                    if each.metadata is not None:
         
     | 
| 892 | 
         
             
                                        each.metadata['status'] = 'done'
         
     | 
| 893 | 
         
             
                                if '[FinalAnswer]' in last_thought:
         
     | 
| 894 | 
         
            +
                                    _, final_answer = last_thought.split('[FinalAnswer]', 1)
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 895 | 
         
             
                                    history.append(
         
     | 
| 896 | 
         
            +
                                        ChatMessage(role="assistant", content=final_answer.strip())
         
     | 
| 
         | 
|
| 897 | 
         
             
                                    )
         
     | 
| 898 | 
         
             
                                    yield history
         
     | 
| 899 | 
         
             
                            else:
         
     | 
| 
         | 
|
| 912 | 
         
             
                                    each.metadata['status'] = 'done'
         
     | 
| 913 | 
         
             
                            if '[FinalAnswer]' in last_thought or '"name": "Finish",' in last_outputs_str:
         
     | 
| 914 | 
         
             
                                if '[FinalAnswer]' in last_thought:
         
     | 
| 915 | 
         
            +
                                    _, final_answer = last_thought.split('[FinalAnswer]', 1)
         
     | 
| 916 | 
         
             
                                else:
         
     | 
| 
         | 
|
| 917 | 
         
             
                                    final_answer = last_thought
         
     | 
| 918 | 
         
             
                                history.append(
         
     | 
| 919 | 
         
            +
                                    ChatMessage(role="assistant", content=final_answer.strip())
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 920 | 
         
             
                                )
         
     | 
| 921 | 
         
             
                                yield history
         
     | 
| 922 | 
         
             
                        else:
         
     | 
| 923 | 
         
            +
                            return None
         
     |