# app.py
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
# Cache loaded models so each checkpoint is downloaded and moved onto the
# GPU only once, instead of on every request.
_model_cache = {}


def load_model(model_size: str = "32B"):
    """
    Load the Qwen2.5-Coder model and tokenizer for the selected size.
    """
    # Hub IDs for the Qwen2.5-Coder Instruct checkpoints.
    model_map = {
        "0.5B": "Qwen/Qwen2.5-Coder-0.5B-Instruct",
        "1.5B": "Qwen/Qwen2.5-Coder-1.5B-Instruct",
        "3B": "Qwen/Qwen2.5-Coder-3B-Instruct",
        "7B": "Qwen/Qwen2.5-Coder-7B-Instruct",
        "14B": "Qwen/Qwen2.5-Coder-14B-Instruct",
        "32B": "Qwen/Qwen2.5-Coder-32B-Instruct",
    }
    model_id = model_map.get(model_size, "Qwen/Qwen2.5-Coder-7B-Instruct")  # default to 7B if size not found

    if model_id not in _model_cache:
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        # float16 weights take ~2 bytes per parameter, so the 32B checkpoint
        # needs on the order of 64 GB of GPU memory.
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.float16,
            device_map="auto",
        )
        _model_cache[model_id] = (model, tokenizer)
    return _model_cache[model_id]
def process_query(query: str, model_size: str = "7B") -> str:
    """
    Process a single query and return the generated response.
    """
    if not query:
        return ""
    try:
        model, tokenizer = load_model(model_size)

        # Prepare the input
        inputs = tokenizer(query, return_tensors="pt").to(model.device)

        # Generate a response; do_sample=True is required for temperature
        # and top_p to actually take effect.
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=512,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
                pad_token_id=(tokenizer.pad_token_id
                              if tokenizer.pad_token_id is not None
                              else tokenizer.eos_token_id),
            )

        # Decode only the newly generated tokens rather than string-replacing
        # the prompt, which breaks whenever the model rephrases the input.
        new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
        return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
    except Exception as e:
        return f"Error: {str(e)}"
def main():
    with gr.Blocks() as demo:
        with gr.Row():
            model_size = gr.Radio(
                choices=["0.5B", "1.5B", "3B", "7B", "14B", "32B"],
                label="Qwen2.5-Coder Model Size:",
                value="32B",
            )
        with gr.Row():
            input_text = gr.Textbox(
                lines=5,
                label="Input",
                placeholder="Enter your query here...",
            )
        with gr.Row():
            output_text = gr.Textbox(
                lines=10,
                label="Output",
            )
        submit_btn = gr.Button("Generate")
        submit_btn.click(
            fn=process_query,
            inputs=[input_text, model_size],
            outputs=output_text,
        )
    demo.launch(max_threads=5)


if __name__ == "__main__":
    main()
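
# Dependencies for this Space (a sketch of requirements.txt; the original
# file was not included, so treat this list as an assumption):
#
#   gradio
#   torch
#   transformers
#   accelerate   # needed by device_map="auto" in from_pretrained
#
# Run locally with: python app.py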