# Author: Kkordik
# Commit: 48de1c3 "Update app.py" (verified)
import os
import subprocess
import sys
from pathlib import Path

import gradio as gr
import spaces
from huggingface_hub import snapshot_download
# Install required packages at startup (best effort).
#
# The *_SKIP_CUDA_BUILD flags ask each package's build step not to compile
# CUDA extensions at install time. NOTE(review): MISTRAL_INFERENCE_SKIP_CUDA_BUILD
# may have no effect for mistral_inference — verify.
#
# The flags are layered on top of a copy of the current environment rather
# than replacing it: passing a bare dict as ``env`` drops PATH, HOME,
# VIRTUAL_ENV, etc., so ``pip`` could resolve to a different interpreter than
# the one running this app. ``sys.executable -m pip`` pins the install target,
# and an argv list avoids going through a shell.
def _pip_install(package, extra_env):
    """Run ``pip install <package> --no-build-isolation`` for this interpreter.

    Args:
        package: Distribution name to install.
        extra_env: Environment variables added (on top of ``os.environ``)
            for the pip child process only.

    The return code is deliberately not checked, to keep the original
    best-effort behavior; a failed install will typically surface later as
    an ImportError.
    """
    subprocess.run(
        [sys.executable, "-m", "pip", "install", package, "--no-build-isolation"],
        env={**os.environ, **extra_env},
    )


_pip_install("causal-conv1d", {"CAUSAL_CONV1D_SKIP_CUDA_BUILD": "TRUE"})
_pip_install("mamba-ssm", {"MAMBA_SKIP_CUDA_BUILD": "TRUE"})
_pip_install("mistral_inference", {"MISTRAL_INFERENCE_SKIP_CUDA_BUILD": "TRUE"})
# Import after installation
from mistral_inference.mamba import Mamba
from mistral_inference.generate import generate
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from mistral_common.protocol.instruct.messages import UserMessage, AssistantMessage
from mistral_common.protocol.instruct.request import ChatCompletionRequest
# Fetch only the checkpoint files the Mamba loader and tokenizer need into
# ~/mistral_models/mamba-codestral-7B-v0.1 (created if missing).
mistral_models_path = Path.home() / "mistral_models" / "mamba-codestral-7B-v0.1"
mistral_models_path.mkdir(exist_ok=True, parents=True)
snapshot_download(
    repo_id="mistralai/mamba-codestral-7B-v0.1",
    allow_patterns=["params.json", "consolidated.safetensors", "tokenizer.model.v3"],
    local_dir=mistral_models_path,
)
MODEL_PATH = str(mistral_models_path)

# Load the tokenizer and the model once at module import; generate_response
# reuses these module-level objects on every request.
tokenizer = MistralTokenizer.from_file(str(mistral_models_path / "tokenizer.model.v3"))
model = Mamba.from_folder(MODEL_PATH)
def _history_to_messages(history):
    """Convert Gradio chat history into a list of mistral_common messages.

    Accepts both Gradio history shapes: ``[user, assistant]`` pairs (the
    ``type="tuples"`` format) and ``{"role": ..., "content": ...}`` dicts
    (the ``type="messages"`` format). Turns whose content is empty, ``None``
    or not a string (e.g. file attachments) are skipped.
    """
    messages = []
    for turn in history or []:
        if isinstance(turn, dict):
            role, content = turn.get("role"), turn.get("content")
            if not isinstance(content, str) or not content:
                continue
            if role == "user":
                messages.append(UserMessage(content=content))
            elif role == "assistant":
                messages.append(AssistantMessage(content=content))
            continue
        human, assistant = turn
        if isinstance(human, str) and human:
            messages.append(UserMessage(content=human))
        if isinstance(assistant, str) and assistant:
            messages.append(AssistantMessage(content=assistant))
    return messages


@spaces.GPU()
def generate_response(message, history, *, max_tokens=256, temperature=0.7):
    """Generate the assistant's reply to ``message`` given prior chat turns.

    Args:
        message: The new user message.
        history: Prior turns as supplied by ``gr.ChatInterface`` (pairs or
            role/content dicts; see ``_history_to_messages``).
        max_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature.

    Returns:
        The decoded reply text.
    """
    # Mamba checkpoints go through generate_mamba; the plain `generate` helper
    # targets Transformer models.
    from mistral_inference.generate import generate_mamba

    messages = _history_to_messages(history)
    messages.append(UserMessage(content=message))

    completion_request = ChatCompletionRequest(messages=messages)
    tokens = tokenizer.encode_chat_completion(completion_request).tokens

    inner_tokenizer = tokenizer.instruct_tokenizer.tokenizer
    # Returns (generated token lists, logprobs), one entry per prompt; only
    # the tokens of our single prompt are needed.
    out_tokens, _ = generate_mamba(
        [tokens],
        model,
        max_tokens=max_tokens,
        temperature=temperature,
        eos_id=inner_tokenizer.eos_id,
    )
    return inner_tokenizer.decode(out_tokens[0])
# Chat UI: gr.ChatInterface calls generate_response with the new message and
# the accumulated history on every turn.
iface = gr.ChatInterface(
    fn=generate_response,
    title="Mamba Codestral Chat (ZeroGPU)",
    description="Chat with the Mamba Codestral 7B model using Hugging Face Spaces ZeroGPU feature.",
)

if __name__ == "__main__":
    # Serve the UI only when executed as a script, not when imported.
    iface.launch()