text embeddings for plot
This commit is contained in:
BIN
plot_embeddings.npy
Normal file
BIN
plot_embeddings.npy
Normal file
Binary file not shown.
Binary file not shown.
@@ -18,10 +18,6 @@ sym_spell = SymSpell(max_dictionary_edit_distance=2, prefix_length=7)
|
||||
stemmer = PorterStemmer()
|
||||
lemmatizer = WordNetLemmatizer()
|
||||
|
||||
# model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
|
||||
|
||||
# df = pd.read_excel('C:\\Users\\ishaa\\OneDrive\\Documents\\MSU\\Spring 2026\\Data mining\\Project\\sample_data.xlsx', engine='openpyxl')
|
||||
|
||||
def clean_plot(text):
|
||||
text = text.lower()
|
||||
text = text.translate(str.maketrans('', '', string.punctuation)) # Remove punctuation
|
||||
@@ -35,11 +31,12 @@ def clean_plot(text):
|
||||
return text
|
||||
|
||||
def get_genre(row):
|
||||
if pd.isna(row['Genre']):
|
||||
return ""
|
||||
movie = row['Title']
|
||||
print(movie)
|
||||
text = row['Genre']
|
||||
text = text.split(".")[0]
|
||||
text = text.replace(movie, "")
|
||||
text = text.split(".")[0]
|
||||
text = text.lower()
|
||||
match = re.search(r'is a ((?:\S+\s+){4}\S+)', text)
|
||||
if match:
|
||||
@@ -53,13 +50,14 @@ def get_genre(row):
|
||||
return text
|
||||
|
||||
def pre_director(text):
|
||||
if not text:
|
||||
if pd.isna(text) or not text:
|
||||
return ""
|
||||
text = text.lower().strip()
|
||||
return text
|
||||
|
||||
def clean_cast(text, top_k=5):
|
||||
if not text:
|
||||
def clean_cast(text):
|
||||
print(f"Original cast: {text}")
|
||||
if pd.isna(text) or not text:
|
||||
return []
|
||||
text = text.lower()
|
||||
cast_list = [actor.strip() for actor in text.split(",")]
|
||||
|
||||
@@ -2,10 +2,12 @@ import pandas as pd
|
||||
from sentence_transformers import SentenceTransformer
|
||||
from preprocessing import clean_plot, get_genre, pre_director, clean_cast
|
||||
from sklearn.metrics.pairwise import cosine_similarity
|
||||
import numpy as np
|
||||
|
||||
df = pd.read_excel('C:\\Users\\ishaa\\OneDrive\\Documents\\MSU\\Spring 2026\\Data mining\\Project\\updated_data.xlsx', engine='openpyxl')
|
||||
df = pd.read_excel('C:\\Users\\ishaa\\OneDrive\\Documents\\MSU\\Spring 2026\\Data mining\\Project\\updated_datav2.xlsx', engine='openpyxl')
|
||||
df = df.dropna(subset=['Plot'])
|
||||
|
||||
df = df.dropna(subset=['Genre', 'Plot'])
|
||||
print(len(df))
|
||||
|
||||
df['Processed_Plot'] = df['Plot'].apply(clean_plot)
|
||||
|
||||
@@ -14,4 +16,18 @@ df['Pre_genre'] = df[['Genre', 'Title']].apply(get_genre, axis=1)
|
||||
df['Pre_director'] = df['Director'].apply(pre_director)
|
||||
df['Pre_cast'] = df['Cast'].apply(clean_cast)
|
||||
|
||||
|
||||
# Load embedding model
|
||||
model = SentenceTransformer("all-MiniLM-L6-v2")
|
||||
|
||||
# Generate one embedding per movie plot
|
||||
plot_embeddings = model.encode(
|
||||
df["Processed_Plot"].tolist(),
|
||||
show_progress_bar=True,
|
||||
convert_to_numpy=True
|
||||
)
|
||||
|
||||
|
||||
np.save(r"C:\Users\ishaa\OneDrive\Documents\MSU\Spring 2026\Data mining\Project\plot_embeddings.npy",plot_embeddings)
|
||||
|
||||
df.to_excel('C:\\Users\\ishaa\\OneDrive\\Documents\\MSU\\Spring 2026\\Data mining\\Project\\preprocessed_data.xlsx', index=False)
|
||||
BIN
updated_datav2.xlsx
Normal file
BIN
updated_datav2.xlsx
Normal file
Binary file not shown.
Reference in New Issue
Block a user