# llamaSMS / gemma.py
"""
#git+https://github.com/huggingface/transformers
transformers==4.43.1
huggingface_hub
bitsandbytes
accelerate
langchain
torch
flask
gunicorn
twilio
baseten
spaces
"""
import os

import torch
from flask import Flask, request, jsonify
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

app = Flask(__name__)

print("Hello, welcome to Sema AI", flush=True)  # flush so the banner appears immediately in the logs
@app.route("/")
def hello():
return "hello 🤗, Welcome to Sema AI Chat Service."
# Get the Hugging Face access token from the environment
HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:
    print("Missing Hugging Face token (HF_TOKEN); gated models such as Gemma will fail to load.", flush=True)

model_id = "google/gemma-2-2b-it"
device = "cuda:0" if torch.cuda.is_available() else "cpu"  # informational only; placement is handled by device_map="auto" below
# Load the tokenizer and model, authenticating with the token
tokenizer = AutoTokenizer.from_pretrained(model_id, token=HF_TOKEN)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.float16,
    token=HF_TOKEN,
)

app_pipeline = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
)
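
# Note (a sketch, not part of the original flow): gemma-2-2b-it is an
# instruction-tuned model, so raw prompts usually behave better when wrapped
# in the model's chat template. tokenizer.apply_chat_template is the standard
# transformers API for this:
#
#     messages = [{"role": "user", "content": "Hello!"}]
#     templated_prompt = tokenizer.apply_chat_template(
#         messages, tokenize=False, add_generation_prompt=True
#     )
#     # templated_prompt can then be passed to app_pipeline as below.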
@app.route("/generate_text", methods=["POST"])
def generate_text():
data = request.json
prompt = data.get("prompt", "")
max_new_tokens = data.get("max_new_tokens", 1000)
do_sample = data.get("do_sample", True)
temperature = data.get("temperature", 0.1)
top_k = data.get("top_k", 50)
top_p = data.get("top_p", 0.95)
print(f"{prompt}: ")
try:
outputs = app_pipeline(
prompt,
max_new_tokens=max_new_tokens,
do_sample=do_sample,
temperature=temperature,
top_k=top_k,
top_p=top_p
)
response_text = outputs[0]["generated_text"]
except Exception as e:
return jsonify({"error": str(e)}), 500
return jsonify({"response": response_text})

if __name__ == "__main__":
    app.run(debug=False, port=8888)
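
# Deployment note (an assumption based on gunicorn appearing in the
# dependency list above; the module name is taken from this file, gemma.py):
# in production the app would likely be served with something like
#
#     gunicorn --bind 0.0.0.0:8888 gemma:app
#
# rather than the Flask development server started above.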