# app.py — Hugging Face Space file by sabahat-shakeel
# Commit ef3db06 ("Create app.py", verified), 2.3 kB
import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer
# Load model and tokenizer
@st.cache_resource
def load_model_and_tokenizer():
    """Load the GPTQ-quantized Mistral-7B-Instruct v0.2 model and its tokenizer.

    Decorated with ``st.cache_resource`` so Streamlit keeps the loaded objects
    across script reruns instead of reloading the weights each time.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    # NOTE(review): loading GPTQ weights presumably requires a GPTQ backend
    # (e.g. optimum + auto-gptq) to be installed — confirm the Space requirements.
    repo_id = "TheBloke/Mistral-7B-Instruct-v0.2-GPTQ"

    tokenizer = AutoTokenizer.from_pretrained(repo_id, use_fast=True)

    # device_map="auto" lets transformers decide where the weights are placed.
    model = AutoModelForCausalLM.from_pretrained(
        repo_id,
        revision="main",
        trust_remote_code=False,
        device_map="auto",
    )
    return model, tokenizer
# Runs on every Streamlit rerun; because load_model_and_tokenizer is wrapped in
# @st.cache_resource, repeat calls return the already-loaded objects.
model, tokenizer = load_model_and_tokenizer()
# Define the prompt template
def generate_prompt(comment):
    """Wrap a user message in the Mistral-instruct ``[INST] ... [/INST]`` format.

    The fixed persona instructions come first, then the user's message, all
    inside a single instruction turn.

    Args:
        comment: The user's message; it is formatted with ``str.format``,
            so non-string values are converted the same way an f-string would.

    Returns:
        The complete prompt string to feed to the tokenizer.
    """
    # Adjacent literals are concatenated at compile time; each sentence keeps
    # the trailing space that separated it from the next one.
    persona = (
        "Virtual Psychologist, communicates with empathy and understanding, focusing on mental health support and providing advice within its expertise. "
        "It actively listens, acknowledges emotions, and avoids overly clinical or technical language unless specifically requested. "
        "It reacts to feedback with warmth and adjusts its tone to match the individual's needs, offering encouragement and validation as appropriate. "
        "Responses are tailored in length and tone to ensure a supportive and conversational experience.\n"
    )
    template = "[INST] {} \n{} \n[/INST]"
    return template.format(persona, comment)
# Define the response generator
def get_response(comment, max_new_tokens=140):
    """Generate the virtual psychologist's reply to a single user message.

    Args:
        comment: The user's message; it is wrapped by ``generate_prompt``.
        max_new_tokens: Maximum number of tokens to generate for the reply
            (defaults to 140, the previously hard-coded value).

    Returns:
        The decoded reply text only (the prompt is excluded), with
        surrounding whitespace stripped.
    """
    prompt = generate_prompt(comment)

    # A single prompt never needs padding. Asking for padding=True makes the
    # tokenizer raise ValueError when no pad token is defined (Mistral
    # tokenizers typically ship without one — confirm for this checkpoint).
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True)

    # Follow wherever device_map="auto" placed the model instead of assuming "cuda".
    input_ids = inputs["input_ids"].to(model.device)
    attention_mask = inputs["attention_mask"].to(model.device)

    outputs = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=max_new_tokens,
        # Explicit pad id for generate(); reuses the EOS token.
        pad_token_id=tokenizer.eos_token_id,
    )

    # For a decoder-only model the output row is prompt + continuation. Decode
    # only the new tokens rather than string-splitting on "[/INST]", which
    # breaks if the model itself emits that marker.
    new_tokens = outputs[0][input_ids.shape[-1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
# Streamlit app layout
# Streamlit re-executes this section top to bottom on every widget interaction.
st.title("Virtual Psychologist")
st.markdown(
    "This virtual psychologist offers empathetic responses to your comments or questions. Enter your message below."
)

user_input = st.text_input("Your Comment/Question:", placeholder="Type here...")

# An empty text box is falsy, so nothing is generated until the user types.
if user_input:
    with st.spinner("Generating response..."):
        reply = get_response(user_input)
    st.write("### Response:")
    st.write(reply)

# Footer is shown on every run, whether or not a reply was generated.
st.markdown(
    "Built with ❤️ using [Hugging Face Transformers](https://huggingface.co/transformers/) and [Streamlit](https://streamlit.io/)."
)