File size: 1,668 Bytes
99e744f
70f6a45
32e6acc
99e744f
2ac4210
55a77e3
99e744f
dad185d
99e744f
 
00742e9
2ac4210
052ff21
 
99e744f
 
 
 
 
 
 
173b5f1
 
 
 
 
72b02de
2ac4210
638094e
3ef330a
2ac4210
 
173b5f1
 
 
 
 
 
 
 
 
99e744f
 
 
173b5f1
99e744f
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import os
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline, AutoModelForCausalLM
from transformers import LEDForConditionalGeneration, LEDTokenizer
from langchain_openai import OpenAI
# from huggingface_hub import login
from dotenv import load_dotenv
from logging import getLogger
# import streamlit as st
import torch

# Populate the process environment from a local .env file (if any) before
# any environment variable is read below.
load_dotenv()
# Hugging Face access token; None when HF_TOKEN is not set in the environment.
hf_token = os.environ.get("HF_TOKEN")
# Disabled alternatives: read the token from Streamlit secrets and/or log in
# to the Hugging Face hub explicitly.
# # hf_token = st.secrets["HF_TOKEN"]
# login(token=hf_token)
# Module-level logger named after this module.
logger = getLogger(__name__)
# Preferred torch device string: "cuda" when a CUDA device is available, else "cpu".
device = "cuda" if torch.cuda.is_available() else "cpu"

def get_local_model(model_name_or_path: str) -> torch.nn.Module:
    """Load a causal language model from a local path or the Hugging Face hub.

    The weights are loaded in ``bfloat16`` via
    ``AutoModelForCausalLM.from_pretrained`` and authenticated with the
    module-level ``hf_token``.

    Args:
        model_name_or_path: Hub model id or filesystem path of the checkpoint.

    Returns:
        The loaded model object (not a ``transformers`` pipeline). No
        tokenizer is created here; callers must load one themselves.

    NOTE(review): this function does not move the model to the module-level
    ``device``; it stays wherever ``from_pretrained`` placed it. Callers that
    need GPU execution should call ``.to(device)`` on the result -- confirm
    against the call sites before adding placement here.
    """
    # The tokenizer and the "summarization" pipeline that used to be built
    # here were removed when switching to a Llama model, so only the bare
    # model is loaded and returned.
    model = AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        torch_dtype=torch.bfloat16,
        token=hf_token,
    )

    # Report what was actually loaded and where it actually lives, instead of
    # claiming a pipeline was created on `device`.
    logger.info("Causal LM %s loaded on %s", model_name_or_path, model.device)

    return model

def get_endpoint(api_key: str):
    """Build and return a LangChain ``OpenAI`` LLM client for *api_key*."""
    return OpenAI(openai_api_key=api_key)

def get_model(model_type, model_name_or_path, api_key=None):
    """Return the LLM backend selected by *model_type*.

    ``"openai"`` yields the hosted endpoint built from *api_key*; every other
    value loads the local model identified by *model_name_or_path*.
    """
    # Anything that is not exactly "openai" is treated as a local model.
    if model_type != "openai":
        return get_local_model(model_name_or_path)
    return get_endpoint(api_key)