# megabeam-chat / app.py
from vllm import LLM, SamplingParams
import gradio as gr
model_name = "aws-prototyping/MegaBeam-Mistral-7B-512k"

# Initialize the vLLM engine once at startup. The original also loaded a
# second copy of the weights via transformers' AutoModelForCausalLM that was
# never used for generation; keeping both would roughly double GPU memory use.
# (max_model_len can be lowered if the model's full 512k context overflows
# available KV-cache memory.)
llm = LLM(model=model_name, trust_remote_code=True)
sampling = SamplingParams(temperature=0.7, max_tokens=512)

def chat(prompt: str) -> str:
    # Reuse the shared engine rather than constructing a new LLM per request,
    # which would re-initialize the whole model on every call.
    outputs = llm.generate([prompt], sampling)
    return outputs[0].outputs[0].text
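
# MegaBeam-Mistral-7B is instruction-tuned, so wrapping the raw prompt in the
# model's chat template generally yields better answers than passing it bare.
# A minimal sketch, assuming the model repo ships a chat template that
# transformers' apply_chat_template can render:
#
#     from transformers import AutoTokenizer
#     tokenizer = AutoTokenizer.from_pretrained(model_name)
#
#     def chat(prompt: str) -> str:
#         formatted = tokenizer.apply_chat_template(
#             [{"role": "user", "content": prompt}],
#             tokenize=False,
#             add_generation_prompt=True,
#         )
#         outputs = llm.generate([formatted], sampling)
#         return outputs[0].outputs[0].text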
iface = gr.Interface(fn=chat, inputs="text", outputs="text")
iface.launch()
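
# Usage: `python app.py` serves the demo on Gradio's default
# http://127.0.0.1:7860; passing share=True to launch() gives a public link.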