Phoenix21's picture
Update app.py
b4957d9 verified
raw
history blame
2.07 kB
import os
import torch
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from peft import PeftModel, PeftConfig
import uvicorn
from huggingface_hub import login
# Authenticate with the Hugging Face Hub before any model download is attempted.
# The access token is read from the HF_TOKEN environment variable; start-up is
# aborted with a ValueError when the variable is unset or empty.
HF_TOKEN = os.environ.get("HF_TOKEN")
if not HF_TOKEN:
    raise ValueError("Hugging Face token not found. Please set the HF_TOKEN environment variable.")
login(token=HF_TOKEN)
# Request body schema for the /generate endpoint (validated by pydantic).
class Query(BaseModel):
    # The user's question; it is inserted verbatim into the prompt template.
    text: str
# Application object that the route decorators in this module attach to.
app = FastAPI(
    title="Financial Chatbot API",
)
# Base checkpoint that the fine-tuned adapter is applied on top of.
# Update this if the adapter was trained against a different base model.
base_model_name = "meta-llama/Llama-3.2-3B"

# Options forwarded to from_pretrained: device placement is delegated to
# device_map="auto".
# NOTE(review): trust_remote_code=True permits running code shipped in the
# model repository; confirm it is actually required for this checkpoint.
_base_model_options = {
    "device_map": "auto",
    "trust_remote_code": True,
}
model = AutoModelForCausalLM.from_pretrained(base_model_name, **_base_model_options)
# Adapter checkpoint holding the fine-tuned weights, loaded with a workaround
# for the 'eva_config' entry in its configuration.
peft_model_id = "Phoenix21/llama-3-2-3b-finetuned-finance_checkpoint2"

# Load the adapter configuration on its own so it can be adjusted before the
# adapter weights are attached to the base model.
peft_config = PeftConfig.from_pretrained(peft_model_id)

# Drop any 'eva_config' settings carried by the checkpoint. The field is reset
# in place instead of round-tripping the configuration through a dict:
# PeftConfig offers to_dict() but no documented from_dict() constructor, so
# rebuilding the object that way would fail with AttributeError.
# getattr() keeps this a no-op for config classes that have no such field.
if getattr(peft_config, "eva_config", None) is not None:
    peft_config.eva_config = None

# Attach the adapter to the base model using the adjusted configuration.
model = PeftModel.from_pretrained(model, peft_model_id, config=peft_config)
# The tokenizer is taken from the base checkpoint, not from the adapter repo.
tokenizer = AutoTokenizer.from_pretrained(
    base_model_name,
    trust_remote_code=True,
)
# Reuse the end-of-sequence token as the padding token.
tokenizer.pad_token = tokenizer.eos_token
# Text-generation pipeline shared by all requests.
# do_sample=True is set explicitly: temperature and top_p only take effect
# when sampling is enabled, and without this flag that would depend on the
# generation defaults shipped with whichever base checkpoint is configured.
chat_pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=256,  # upper bound on generated tokens per request
    do_sample=True,
    temperature=0.7,
    top_p=0.95,
)
@app.post("/generate")
def generate(query: Query):
    """Generate an answer to the submitted question.

    The question is wrapped in a "Question: ... / Answer: " prompt and passed
    to the shared text-generation pipeline.

    Args:
        query: Validated request body; ``query.text`` is the user's question.

    Returns:
        A dict with a single ``"response"`` key holding the generated answer.
    """
    prompt = f"Question: {query.text}\nAnswer: "
    # By default the pipeline returns the prompt followed by the completion;
    # return_full_text=False keeps only the newly generated text so the
    # prompt scaffold is not echoed back to the client.
    outputs = chat_pipe(prompt, return_full_text=False)
    return {"response": outputs[0]["generated_text"]}
# Script entry point: serve the API with uvicorn on all network interfaces.
# The port is read from the PORT environment variable and defaults to 7860.
if __name__ == "__main__":
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=int(os.getenv("PORT", "7860")),
    )