Usage



from peft import PeftModel, PeftConfig
from transformers import AutoModelForCausalLM
from transformers import AutoTokenizer
import torch

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Configuration of the fine-tuned adapter (not referenced again in the
# code visible here).
config = PeftConfig.from_pretrained("Ashishkr/llama2-qrecc-context-resolution")

# Tokenizer and base weights both come from the same base checkpoint.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")

# Wrap the base model with the adapter weights, then move it to `device`.
model = PeftModel.from_pretrained(model, "Ashishkr/llama2-qrecc-context-resolution")
model = model.to(device)

def response_generate(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    prompt: str,
    max_new_tokens: int = 128,
    temperature: float = 0.7,
):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    inputs = tokenizer(
        [prompt],
        return_tensors="pt",
        return_token_type_ids=False,
    ).to(
        device
    )

    with torch.autocast("cuda", dtype=torch.bfloat16):
        response = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            return_dict_in_generate=True,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
        )

    decoded_output = tokenizer.decode(
        response["sequences"][0],
        skip_special_tokens=True,
    )

    return decoded_output

# Instruction line followed by the tagged conversation. The text is passed to
# the model verbatim, so its exact wording and spacing are kept as-is.
prompt = (
    " Strictly use the context provided, to generate the repsonse. No additional information to be added. Re-write the user query using the context . \n"
    ">>CONTEXT<<Where did jessica go to school? Where did she work at?>>USER<<What did she do next for work?>>REWRITE<<"
)

# Short generation budget and low temperature for the rewritten query.
response = response_generate(model, tokenizer, prompt, max_new_tokens=20, temperature=0.1)

print(response)
Downloads last month

-

Downloads are not tracked for this model. How to track
Inference Providers NEW
This model is not currently available via any of the supported Inference Providers.
The model cannot be deployed to the HF Inference API: The model has no library tag.