Test_E5 / app.py
Prathmesh48's picture
Update app.py
dc0d8d3 verified
raw
history blame
1.22 kB
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModel
# Build the tokenizer/model pair used by the app.
@st.cache_resource
def load_model():
    """Load the GTE tokenizer and model and place the model on the CPU.

    The function is wrapped in ``st.cache_resource``, so the returned objects
    are meant to be reused rather than rebuilt on every call.

    Returns:
        A ``(tokenizer, model)`` tuple for the
        ``Alibaba-NLP/gte-base-en-v1.5`` checkpoint.
    """
    checkpoint = 'Alibaba-NLP/gte-base-en-v1.5'
    # NOTE(review): trust_remote_code=True executes code shipped with the
    # checkpoint repository; keep the checkpoint pinned to a trusted source.
    tok = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
    net = AutoModel.from_pretrained(checkpoint, trust_remote_code=True)
    net.to('cpu')
    return tok, net
# Module-level tokenizer and model, obtained from load_model() when the script runs.
tokenizer, model = load_model()
def extract_embeddings(text, tokenizer, model):
    """Return mean-pooled last-hidden-state embeddings for ``text``.

    Args:
        text: Input passed straight to ``tokenizer`` (a string, or whatever
            batch form the tokenizer accepts).
        tokenizer: Callable returning a mapping of tensors when invoked with
            ``return_tensors="pt"``.
        model: Callable whose output exposes ``last_hidden_state``.

    Returns:
        A NumPy array holding the pooled embedding(s), with size-1 dimensions
        squeezed out.
    """
    # Tokenize the input text and keep every tensor on the CPU.
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
    inputs = {k: v.to('cpu') for k, v in inputs.items()}

    # Inference only: no gradients are needed.
    with torch.no_grad():
        outputs = model(**inputs)

    hidden = outputs.last_hidden_state
    mask = inputs.get("attention_mask")
    if mask is None:
        # No mask available: fall back to averaging over every position.
        embeddings = hidden.mean(dim=1)
    else:
        # padding=True can add padding positions; weight by the attention mask
        # so those positions do not dilute the average.
        weights = mask.unsqueeze(-1).to(hidden.dtype)
        # clamp avoids a division by zero if a row has no attended tokens.
        token_counts = weights.sum(dim=1).clamp(min=1)
        embeddings = (hidden * weights).sum(dim=1) / token_counts

    return embeddings.squeeze().cpu().numpy()
# --- Streamlit page: title, input box, and an on-demand embedding readout ---
st.title("Text Embeddings Extractor")

text = st.text_area(
    "Enter text to extract embeddings:",
    "This is an example sentence.",
)

# The embedding is only computed on the run in which the button is pressed.
extract_clicked = st.button("Extract Embeddings")
if extract_clicked:
    embeddings = extract_embeddings(text, tokenizer, model)
    st.write("Embeddings:")
    st.write(embeddings)