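"""Virtual Psychologist: a Streamlit app that serves empathetic,
mental-health-oriented replies from a GPTQ-quantized Mistral-7B-Instruct model."""
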
import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
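
# Assumed runtime dependencies (not pinned here): streamlit, torch, transformers,
# plus optimum and auto-gptq for loading GPTQ-quantized checkpoints.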

# Load the model and tokenizer once and cache them across Streamlit reruns
@st.cache_resource
def load_model_and_tokenizer():
    model_name = "TheBloke/Mistral-7B-Instruct-v0.2-GPTQ"
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto",  # place the model on GPU automatically if available
        trust_remote_code=False,
        revision="main",
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
    # Mistral's tokenizer ships without a pad token; reuse EOS for padding
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    return model, tokenizer

model, tokenizer = load_model_and_tokenizer()

# Wrap a user comment in Mistral's [INST] ... [/INST] instruction format,
# prefixed with the assistant's persona instructions
def generate_prompt(comment):
    instructions = (
        "Virtual Psychologist communicates with empathy and understanding, "
        "focusing on mental health support and providing advice within its "
        "expertise. It actively listens, acknowledges emotions, and avoids "
        "overly clinical or technical language unless specifically requested. "
        "It reacts to feedback with warmth and adjusts its tone to match the "
        "individual's needs, offering encouragement and validation as "
        "appropriate. Responses are tailored in length and tone to ensure a "
        "supportive and conversational experience."
    )
    return f"[INST] {instructions} \n{comment} \n[/INST]"

# Generate a model response for the user's comment and strip the echoed prompt
def get_response(comment):
    prompt = generate_prompt(comment)
    inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
    # Move inputs to the GPU when one is available, matching device_map="auto"
    device = "cuda" if torch.cuda.is_available() else "cpu"
    outputs = model.generate(
        input_ids=inputs["input_ids"].to(device),
        attention_mask=inputs["attention_mask"].to(device),
        max_new_tokens=140,
        pad_token_id=tokenizer.pad_token_id,  # explicit pad id avoids warnings
    )
    # The decoded output echoes the prompt; keep only the text after [/INST]
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return response.split("[/INST]")[-1].strip()

# Streamlit app layout
st.title("Virtual Psychologist")
st.markdown("This virtual psychologist offers empathetic responses to your comments or questions. Enter your message below.")
user_input = st.text_input("Your Comment/Question:", placeholder="Type here...")

if user_input:
    # Show a spinner while the model generates a reply
    with st.spinner("Generating response..."):
        response = get_response(user_input)
    st.write("### Response:")
    st.write(response)

st.markdown("Built with ❤️ using [Hugging Face Transformers](https://huggingface.co/transformers/) and [Streamlit](https://streamlit.io/).")