# app.py -- header captured from the web file view (not part of the program):
# last commit "Update app.py" (2a9ec14, verified) by omeryentur; file size 3.1 kB
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from peft import PeftModel
from typing import Dict, Any
class LlamaInterface:
    """Text-generation wrapper around a base causal LM plus a LoRA adapter.

    The instance loads the tokenizer and model once in ``__init__``,
    answers single prompts through :meth:`generate_response`, and exposes
    that method as a Gradio text-in/text-out UI via
    :meth:`create_interface`.
    """

    # Markers that delimit the answer inside the generated sequence.
    HEADER_END_TOKEN = "<|end_header_id|>"
    TURN_END_TOKEN = "<|eot_id|>"

    # Prompt is truncated to this many tokens.
    MAX_INPUT_TOKENS = 512
    # Total (prompt + generated) budget the original configuration used.
    MAX_TOTAL_TOKENS = 512
    # Lower bound on generated tokens, so a prompt that fills the whole
    # budget still leaves room for an answer.
    MIN_NEW_TOKENS = 32

    def __init__(
        self,
        base_model_name: str = "meta-llama/Llama-3.2-1B",
        lora_model_name: str = "Anlam-Lab/Llama-3.2-1B-it-anlamlab-SA-Chatgpt4mini",
    ):
        """Load tokenizer, base model and LoRA adapter.

        Args:
            base_model_name: Hub id or path of the base causal LM; also
                used to load the tokenizer.
            lora_model_name: Hub id or path of the LoRA adapter applied on
                top of the base model.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        self.tokenizer = AutoTokenizer.from_pretrained(base_model_name)
        # ``padding=True`` below requires a pad token; fall back to EOS
        # when the tokenizer does not define one.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # Half precision is only a safe default on GPU; use float32 on CPU.
        weights_dtype = torch.float16 if self.device == "cuda" else torch.float32
        self.model = AutoModelForCausalLM.from_pretrained(
            base_model_name,
            device_map="auto",
            torch_dtype=weights_dtype,
        )
        self.model = PeftModel.from_pretrained(self.model, lora_model_name)
        self.model.eval()

        # Resolve the marker ids once; ``None`` means "marker unknown to
        # this tokenizer", in which case no trimming happens for it.
        self._header_end_id = self._marker_id(self.HEADER_END_TOKEN)
        self._turn_end_id = self._marker_id(self.TURN_END_TOKEN)

    def _marker_id(self, token: str):
        """Return the vocabulary id of ``token``, or ``None`` if unknown."""
        token_id = self.tokenizer.convert_tokens_to_ids(token)
        unknown_id = getattr(self.tokenizer, "unk_token_id", None)
        if token_id is None or token_id == unknown_id:
            return None
        return token_id

    @staticmethod
    def _answer_token_ids(token_ids, header_end_id, turn_end_id) -> list:
        """Cut a generated id sequence down to the answer span.

        Keeps what follows the *last* ``header_end_id`` and precedes the
        first ``turn_end_id`` after it -- the id-level equivalent of
        ``text.split(header)[-1].split(turn_end)[0]``. A marker that is
        ``None`` or absent from the sequence leaves that side untrimmed.
        """
        ids = list(token_ids)
        if header_end_id is not None and header_end_id in ids:
            last_header = len(ids) - 1 - ids[::-1].index(header_end_id)
            ids = ids[last_header + 1:]
        if turn_end_id is not None and turn_end_id in ids:
            ids = ids[:ids.index(turn_end_id)]
        return ids

    @classmethod
    def _new_token_budget(cls, prompt_length: int) -> int:
        """Return how many tokens may be generated for a prompt.

        Matches the former ``max_length=512`` total for short prompts and
        never drops below ``MIN_NEW_TOKENS`` for long ones.
        """
        return max(cls.MIN_NEW_TOKENS, cls.MAX_TOTAL_TOKENS - prompt_length)

    def generate_response(self, input_text: str) -> str:
        """Generate the model's answer for ``input_text``.

        Args:
            input_text: Raw prompt text from the user.

        Returns:
            The decoded answer with surrounding whitespace removed, or a
            message starting with ``"Error"`` when the input is blank or
            generation fails (exceptions are reported, not raised, so the
            UI always gets a string).
        """
        if not input_text or not input_text.strip():
            return "Error: Please provide valid input text."

        try:
            inputs = self.tokenizer(
                input_text,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=self.MAX_INPUT_TOKENS,
            ).to(self.device)

            prompt_length = inputs["input_ids"].shape[-1]
            generation_config: Dict[str, Any] = {
                # ``max_new_tokens`` instead of ``max_length``: the latter
                # counts the prompt, so a 512-token prompt left no room.
                "max_new_tokens": self._new_token_budget(prompt_length),
                "temperature": 0.01,
                "do_sample": True,
                "top_k": 2,
                "top_p": 0.95,
            }

            with torch.no_grad():
                outputs = self.model.generate(**inputs, **generation_config)

            # Trim on token ids *before* decoding: decoding with
            # ``skip_special_tokens=True`` drops the marker tokens, so a
            # string split on them afterwards can never match.
            answer_ids = self._answer_token_ids(
                outputs[0].tolist(),
                self._header_end_id,
                self._turn_end_id,
            )
            return self.tokenizer.decode(answer_ids, skip_special_tokens=True).strip()
        except Exception as e:
            return f"Error generating response: {str(e)}"

    def create_interface(self) -> "gr.Interface":
        """Build the Gradio interface that calls :meth:`generate_response`.

        Returns:
            A ``gr.Interface`` with one multi-line input box, one
            multi-line output box and two example inputs.
        """
        input_box = gr.Textbox(
            lines=5,
            placeholder="Metninizi buraya girin...",
            label="Giriş Metni",
        )
        output_box = gr.Textbox(
            lines=5,
            label="Model Çıktısı",
        )
        example_inputs = [
            ["Akıllı saati uzun süre kullandım ve şık tasarımı, harika sağlık takibi özellikleri ve uzun pil ömrüyle çok memnun kaldım."],
            ["Ürünü aldım ama pil ömrü kısa, ekran parlaklığı yetersiz ve sağlık takibi doğru sonuçlar vermedi."],
        ]
        return gr.Interface(
            fn=self.generate_response,
            inputs=input_box,
            outputs=output_box,
            title="Anlam-Lab Duygu Analizi",
            description="Metin girişi yaparak duygu analizi sonucunu alabilirsiniz.",
            examples=example_inputs,
            theme="default",
        )
def main():
    """Create the model-backed interface and launch its web server.

    The server is launched on host ``0.0.0.0``, port 7860, with debug mode
    on and public sharing off. Any exception raised while building or
    launching is printed and then re-raised.
    """
    launch_options = {
        "share": False,
        "debug": True,
        "server_name": "0.0.0.0",
        "server_port": 7860,
    }
    try:
        app = LlamaInterface().create_interface()
        app.launch(**launch_options)
    except Exception as exc:
        print(f"Error launching interface: {exc}")
        raise
# Launch the app only when this file is run as a script, not when imported.
if __name__ == "__main__":
    main()