"""Streamlit app that answers FDA NDA submission questions with a PEFT-tuned Zephyr model."""

import os

import streamlit as st
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Set page configuration as the first Streamlit command
st.set_page_config(
    page_title="FDA NDA Submission Assistant",
    layout="centered",
    initial_sidebar_state="auto",
)

# Token budget applied to the raw question when the tokenizer has no chat template.
MAX_PROMPT_TOKENS = 45
# Default cap on the number of tokens generated per answer.
MAX_NEW_TOKENS = 500


# Apply custom CSS for retro 80s green theme
def apply_custom_css():
    """Inject the contents of ``style.css`` into the page.

    Shows a Streamlit warning and keeps the default styles when the file is
    missing from the current working directory.
    """
    try:
        with open("style.css", encoding="utf-8") as f:
            css = f.read()
    except FileNotFoundError:
        st.warning("style.css not found. Using default styles.")
        return
    # The stylesheet has to be wrapped in a <style> element to take effect;
    # emitting an empty string (as before) silently discarded the file.
    st.markdown(f"<style>{css}</style>", unsafe_allow_html=True)


@st.cache_resource
def load_model():
    """Load the tokenizer, the base model and the PEFT adapter.

    Decorated with ``st.cache_resource`` so the load is not repeated on every
    script rerun. The Hugging Face token is read from the ``HF_API_TOKEN``
    environment variable (``None`` when unset).

    Returns:
        A ``(tokenizer, model)`` tuple; the model is in eval mode with the
        adapter loaded.

    On any loading error the message is shown with ``st.error`` and the
    script run is halted with ``st.stop()``.
    """
    model_path = "HuggingFaceH4/zephyr-7b-beta"
    peft_model_path = "yitzashapiro/FDA-guidance-zephyr-7b-beta-PEFT"
    try:
        hf_api_token = os.getenv("HF_API_TOKEN")

        st.write("Loading tokenizer...")
        # NOTE(review): recent transformers releases prefer `token=` over
        # `use_auth_token=`; kept as-is to avoid breaking older installs.
        tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            trust_remote_code=True,
            use_auth_token=hf_api_token,  # Use token for private models
        )

        st.write("Loading model...")
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            device_map="auto",
            torch_dtype=torch.float16,
            trust_remote_code=True,
            use_auth_token=hf_api_token,  # Use token for private models
        ).eval()

        st.write("Loading PEFT adapter...")
        model.load_adapter(peft_model_path)
        st.success("Model loaded successfully.")
    except Exception as e:
        st.error(f"Error loading model: {e}")
        st.stop()
    return tokenizer, model


def generate_response(tokenizer, model, user_input, max_new_tokens=MAX_NEW_TOKENS):
    """Generate the model's answer to a single user question.

    Args:
        tokenizer: Tokenizer returned by ``load_model``.
        model: Causal LM returned by ``load_model``.
        user_input: The question text typed by the user.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The decoded completion without the prompt tokens, or a fixed error
        message if anything raised (the exception is also shown via
        ``st.error``).
    """
    messages = [{"role": "user", "content": user_input}]
    try:
        if hasattr(tokenizer, "apply_chat_template"):
            # No max_length here: without truncation=True it had no effect,
            # and truncating a templated prompt could cut off the generation
            # prompt appended at the end.
            input_ids = tokenizer.apply_chat_template(
                conversation=messages,
                tokenize=True,
                add_generation_prompt=True,
                return_tensors="pt",
            )
        else:
            input_ids = tokenizer(
                user_input,
                return_tensors="pt",
                truncation=True,
                max_length=MAX_PROMPT_TOKENS,
            )["input_ids"]

        # A single prompt is never padded, so every position must be attended
        # to. Deriving the mask from `input_ids != pad_token_id` would hide
        # real prompt tokens whenever the pad id (or the former fallback 0)
        # also occurs in the prompt, e.g. when EOS doubles as the pad token.
        attention_mask = torch.ones_like(input_ids)

        # Only max_new_tokens is passed: supplying max_length as well is
        # contradictory, and generate() honours max_new_tokens in that case.
        output_ids = model.generate(
            input_ids.to(model.device),
            attention_mask=attention_mask.to(model.device),
            max_new_tokens=max_new_tokens,
        )

        # Strip the echoed prompt so only newly generated tokens are decoded.
        prompt_length = input_ids.shape[1]
        return tokenizer.decode(
            output_ids[0][prompt_length:], skip_special_tokens=True
        )
    except Exception as e:
        st.error(f"Error generating response: {e}")
        return "An error occurred while generating the response."


def main():
    """Render the page: title, question box, and the generated answer."""
    apply_custom_css()
    st.title("FDA NDA Submission Assistant")
    st.write("Ask the model about submitting an NDA to the FDA.")

    tokenizer, model = load_model()

    user_input = st.text_input(
        "Enter your question:",
        "What's the best way to submit an NDA to the FDA?",
    )

    if st.button("Generate Response"):
        if not user_input.strip():
            st.error("Please enter a valid question.")
        else:
            try:
                with st.spinner("Generating response..."):
                    response = generate_response(tokenizer, model, user_input)
                st.success("Response:")
                st.write(response)
            except Exception as e:
                st.error(f"An error occurred: {e}")


if __name__ == "__main__":
    main()