|
import streamlit as st |
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
import torch |
|
|
|
|
|
@st.cache_resource
def load_model_and_tokenizer():
    """Load the GPTQ-quantized Mistral instruct model and its tokenizer.

    Decorated with ``st.cache_resource`` so that Streamlit script reruns reuse
    the already-loaded objects instead of loading the weights again.

    Returns:
        A ``(model, tokenizer)`` tuple. If the tokenizer defines no padding
        token, its EOS token is installed as the padding token.
    """
    repo_id = "TheBloke/Mistral-7B-Instruct-v0.2-GPTQ"
    load_kwargs = dict(device_map="auto", trust_remote_code=False, revision="main")
    llm = AutoModelForCausalLM.from_pretrained(repo_id, **load_kwargs)

    tok = AutoTokenizer.from_pretrained(repo_id, use_fast=True)
    # get_response() passes pad_token_id=tokenizer.pad_token_id to generate();
    # reuse EOS as the pad token when none is defined so that id is populated.
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token

    return llm, tok


# Bind the shared model/tokenizer at module level; after the first run this is
# served from the st.cache_resource cache.
model, tokenizer = load_model_and_tokenizer()
|
|
|
|
|
def generate_prompt(comment):
    """Wrap a user comment in a Mistral-instruct prompt with the persona instructions.

    Args:
        comment: The user's message, inserted verbatim after the instructions.

    Returns:
        The prompt string: ``[INST]``, the persona instructions, the comment on
        its own line, and finally ``[/INST]`` on its own line.
    """
    # Built from adjacent string literals instead of a triple-quoted f-string
    # with backslash continuations: the f-prefix had no placeholders, and in the
    # old form any trailing whitespace after a backslash or any indentation on
    # the following line would silently end up inside the prompt text.
    instructions = (
        "Virtual Psychologist, communicates with empathy and understanding, "
        "focusing on mental health support and providing advice within its "
        "expertise. "
        "It actively listens, acknowledges emotions, and avoids overly clinical "
        "or technical language unless specifically requested. "
        "It reacts to feedback with warmth and adjusts its tone to match the "
        "individual's needs, offering encouragement and validation as "
        "appropriate. "
        "Responses are tailored in length and tone to ensure a supportive and "
        "conversational experience.\n"
    )
    return f"[INST] {instructions} \n{comment} \n[/INST]"
|
|
|
|
|
def get_response(comment, max_new_tokens=140):
    """Generate the model's reply to a single user comment.

    Args:
        comment: Raw text entered by the user; wrapped via generate_prompt().
        max_new_tokens: Upper bound on the number of generated tokens
            (defaults to 140, the value that used to be hard-coded).

    Returns:
        The decoded reply text, without special tokens and with surrounding
        whitespace stripped.
    """
    prompt = generate_prompt(comment)
    inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)

    # The model was loaded with device_map="auto", so it decides where its
    # weights live. Send the inputs to the model's own device rather than
    # guessing "cuda"/"cpu", which can disagree with the actual placement.
    device = model.device
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device)

    outputs = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=max_new_tokens,
        pad_token_id=tokenizer.pad_token_id,
    )

    # For a causal LM, generate() returns prompt tokens followed by the
    # continuation. Decode only the continuation instead of splitting the full
    # text on "[/INST]", which truncates the reply if the model emits that marker.
    new_tokens = outputs[0, input_ids.shape[-1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
|
|
|
|
|
st.title("Virtual Psychologist")

st.markdown("This virtual psychologist offers empathetic responses to your comments or questions. Enter your message below.")

user_input = st.text_input("Your Comment/Question:", placeholder="Type here...")

# A whitespace-only entry is truthy, so check the stripped text; otherwise the
# expensive model call would run on an effectively empty message.
if user_input and user_input.strip():
    with st.spinner("Generating response..."):
        response = get_response(user_input)
    st.write("### Response:")
    st.write(response)

st.markdown("Built with ❤️ using [Hugging Face Transformers](https://huggingface.co/transformers/) and [Streamlit](https://streamlit.io/).")
|
|