File size: 3,354 Bytes
00d2b52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import streamlit as st
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, MarianMTModel, MarianTokenizer

# Load models and tokenizers
@st.cache_resource
def load_healthscribe_model():
    """Load the HealthScribe clinical-note generator and its tokenizer.

    The function is wrapped in ``st.cache_resource`` so the loaded objects
    can be reused instead of being rebuilt on every call.

    Returns:
        A ``(model, tokenizer)`` tuple, both loaded from the
        ``har1/HealthScribe-Clinical_Note_Generator`` checkpoint.
    """
    checkpoint = "har1/HealthScribe-Clinical_Note_Generator"
    # The tokenizer is fetched first, then the seq2seq model; the returned
    # tuple is ordered (model, tokenizer).
    note_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    return AutoModelForSeq2SeqLM.from_pretrained(checkpoint), note_tokenizer

@st.cache_resource
def load_translation_model(model_name):
    """Load a MarianMT translation model together with its tokenizer.

    The function is wrapped in ``st.cache_resource`` so a checkpoint that
    was already loaded can be reused on later calls.

    Args:
        model_name: Checkpoint identifier handed to ``from_pretrained``,
            e.g. ``"Helsinki-NLP/opus-mt-en-fr"``.

    Returns:
        A ``(model, tokenizer)`` tuple for the requested checkpoint.
    """
    # Tuple elements are evaluated left to right: model first, then tokenizer.
    return (
        MarianMTModel.from_pretrained(model_name),
        MarianTokenizer.from_pretrained(model_name),
    )

# Load the clinical-note model as soon as the script starts. The loader is
# decorated with st.cache_resource, so this call can reuse earlier results.
_healthscribe_pair = load_healthscribe_model()
healthscribe_model, healthscribe_tokenizer = _healthscribe_pair

# Language selection options.
# Every supported language is paired with English in both directions. Each
# UI label maps to a (source_code, target_code) tuple, and labels are inserted
# in the order "English to X", "X to English" for each language in turn.
_PARTNER_LANGUAGES = (
    ("French", "fr"),
    ("Spanish", "es"),
    ("German", "de"),
    ("Italian", "it"),
)

language_options = {}
for _language_name, _language_code in _PARTNER_LANGUAGES:
    language_options[f"English to {_language_name}"] = ("en", _language_code)
    language_options[f"{_language_name} to English"] = (_language_code, "en")

# Streamlit UI setup
# Task labels are named once so the selectbox options and the branch
# conditions below cannot drift apart.
TASK_CLINICAL_NOTE = "Generate Clinical Note"
TASK_TRANSLATE = "Translate Text"


def _generate_clinical_note(source_text):
    """Return the clinical note the HealthScribe model generates for a text.

    The input is truncated to 512 tokens; decoding uses a 5-beam search with
    early stopping and an output cap of 512 tokens.
    """
    encoded = healthscribe_tokenizer(
        source_text, return_tensors="pt", max_length=512, truncation=True
    )
    # Hand the tokenizer's attention mask to generate() alongside the ids,
    # consistent with the translation path, which forwards the full encoding.
    # ``.get`` yields None if the tokenizer produced no mask, in which case
    # generate() falls back to inferring one.
    generated = healthscribe_model.generate(
        encoded["input_ids"],
        attention_mask=encoded.get("attention_mask"),
        max_length=512,
        num_beams=5,
        early_stopping=True,
    )
    return healthscribe_tokenizer.decode(generated[0], skip_special_tokens=True)


def _translate(source_text, model, tokenizer):
    """Return ``source_text`` translated by the given model/tokenizer pair."""
    encoded = tokenizer(
        source_text, return_tensors="pt", padding=True, truncation=True
    )
    generated = model.generate(**encoded)
    # A single input string yields a single output sequence.
    return tokenizer.decode(generated[0], skip_special_tokens=True)


st.title("Multifunctional Text Processing App")
st.write("This app can generate clinical notes or translate text between languages.")

# Choose task
task = st.selectbox("Select a task:", [TASK_CLINICAL_NOTE, TASK_TRANSLATE])

if task == TASK_CLINICAL_NOTE:
    st.subheader("Clinical Note Generator")
    input_text = st.text_area("Enter patient information or medical notes:", height=200)

    if st.button("Generate Clinical Note"):
        if input_text.strip():
            generated_note = _generate_clinical_note(input_text)

            # Display the result
            st.subheader("Generated Clinical Note")
            st.write(generated_note)
        else:
            # Whitespace-only input is treated the same as empty input.
            st.warning("Please enter some text to generate a clinical note.")

elif task == TASK_TRANSLATE:
    st.subheader("Translation Tool")
    language_pair = st.selectbox("Select language pair", list(language_options))
    src_lang, tgt_lang = language_options[language_pair]

    # Load the corresponding translation model and tokenizer; the checkpoint
    # name is derived from the selected source/target language codes.
    model_name = f"Helsinki-NLP/opus-mt-{src_lang}-{tgt_lang}"
    translation_model, translation_tokenizer = load_translation_model(model_name)

    # Input text to translate
    text = st.text_area("Enter text to translate:")

    if st.button("Translate"):
        if text.strip():
            translated_text = _translate(text, translation_model, translation_tokenizer)

            # Display translation
            st.write("**Original Text**:", text)
            st.write("**Translated Text**:", translated_text)
        else:
            # Whitespace-only input is treated the same as empty input.
            st.warning("Please enter some text to translate.")