therpist2 / app.py
hackergeek98's picture
Update app.py
e3819c9 verified
raw
history blame
1.58 kB
import gradio as gr
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import login
import os
# Hugging Face access token, read from the environment (on Spaces, repository
# secrets are exposed as environment variables).
hf_token = os.getenv("HF_TOKEN")
# Only log in when a token is actually configured: login() without a token
# falls back to an interactive prompt, which nobody can answer on a headless
# server. Without a token, loading a gated checkpoint fails with a clear error.
if hf_token:
    login(token=hf_token)

# Hub repositories for the base checkpoint and the fine-tuned PEFT adapter.
BASE_MODEL_ID = "google/gemma-3-1b-pt"
ADAPTER_ID = "hackergeek98/gemma-finetuned"

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID)
# Load the base model with CPU-friendly options: bfloat16 weights take half
# the memory of float32, and low_cpu_mem_usage keeps peak RAM down while the
# checkpoint is being loaded.
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_ID,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
)
# Wrap the base model with the fine-tuned adapter weights.
model = PeftModel.from_pretrained(base_model, ADAPTER_ID)
model = model.to("cpu")  # Ensure inference runs on CPU
# Chatbot function
def chat(message, history=None, system_message=None, *, max_new_tokens=100):
    """Generate the model's reply to a single user message.

    This is the ``fn`` callback for ``gr.ChatInterface``, which invokes it as
    ``fn(message, history, *additional_inputs)`` and shows the returned string
    as the bot's reply.

    Args:
        message: The user's latest message. It is the entire prompt.
        history: Previous turns supplied by Gradio. Not used in the prompt and
            never mutated; ChatInterface keeps the transcript itself.
        system_message: Value of the "System message" textbox passed through
            ``additional_inputs``. Accepted so the interface can call this
            function, but not currently fed to the model.
            NOTE(review): the base checkpoint is a "-pt" (pretrained, not
            instruction-tuned) Gemma, and the prompt format the adapter was
            trained with is not visible here. Decide deliberately before
            prepending this text to the prompt.
        max_new_tokens: Upper bound on generated tokens, not counting the
            prompt (``max_length`` would count the prompt too, starving long
            messages of output).

    Returns:
        The decoded continuation only, with the prompt and special tokens
        stripped.
    """
    # Tokenize as a batch of one and move the tensors to the model's device.
    # Passing the whole encoding also hands generate() the attention mask.
    inputs = tokenizer(message, return_tensors="pt").to(model.device)
    with torch.no_grad():  # Inference only: skip autograd bookkeeping
        output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
    # For a causal LM, generate() returns prompt + continuation. Drop the
    # prompt tokens so the reply does not echo the user's message.
    prompt_len = inputs["input_ids"].shape[-1]
    return tokenizer.decode(output_ids[0, prompt_len:], skip_special_tokens=True)
# ---- Gradio user interface ----
# Chat transcript display, 400px tall.
_chat_display = gr.Chatbot(height=400)
# Extra input rendered by ChatInterface; its value is passed to `chat` as the
# argument after (message, history).
_system_box = gr.Textbox(value="Welcome to the chatbot!", label="System message")

demo = gr.ChatInterface(
    fn=chat,
    chatbot=_chat_display,
    additional_inputs=[_system_box],
    title="Fine-Tuned Gemma Chatbot",
    description="This chatbot is fine-tuned on Persian text using Gemma.",
)

if __name__ == "__main__":
    # Start the Gradio server only when this file is executed directly.
    demo.launch()