At my company we use Milvus as our vector database. We had multiple collections in production with customer data. At some point we needed to change the schemas (add new fields and remove old ones) without losing any of that data.
The problem was simple: Milvus had no native ALTER command at the time. The only official workaround mentioned in their GitHub issues was to create a new collection and migrate the data yourself. So that's exactly what I did, but wrapped it into a reusable Django management command so anyone on the team could run it safely.
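Nothing fancy. The command takes four things: the database name, the collection name, the new schema, and the batch size used when copying the data.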
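It then does this: it drops any leftover temporary collection, compares the new fields with the old ones (ignoring `id`), adds default values to the new fields, converts the new schema into a Milvus collection schema, copies the data in batches from the old collection into the new one, and finally drops the old collection and renames the new one to take its place.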
Removals are handled automatically: if a field isn't in the new schema, it just doesn't get copied. Additions get a default value assigned based on their datatype. You can also pass your own default.
One thing it doesn't support is updating existing values in a field. That would need some changes to the existing script, but we never needed it, so I left it out.
You pass the new schema as a list of field definitions:
# Example target schema: one dict of field parameters per field, with the
# datatype given by name as a string.
schema = [
    {"field_name": "file_id", "datatype": "VARCHAR", "max_length": 500},
    {"field_name": "vector", "datatype": "FLOAT_VECTOR", "dim": 1024},
    {"field_name": "text", "datatype": "VARCHAR", "max_length": 65535},
    {"field_name": "metadata", "datatype": "JSON"},
    {"field_name": "page_num", "datatype": "VARCHAR", "max_length": 500},
    {"field_name": "name", "datatype": "VARCHAR", "max_length": 500},
]
The `id` primary key is handled automatically, so you don't need to include it. You can also assign a default value to a field here.
from pymilvus import MilvusClient, DataType
class MilvusManager:
    """Thin convenience wrapper around a ``MilvusClient`` bound to one database.

    The ``delete_rows``, ``query`` and ``upsert`` helpers load the target
    collection, run the operation and then release the collection again. The
    release happens in a ``finally`` clause so a failing operation does not
    leave the collection loaded.
    """

    def __init__(self, database):
        """Create the client and switch it to ``database``.

        Args:
            database: Name handed to ``MilvusClient.use_database``.
        """
        # NOTE(review): uri/token are empty strings here -- presumably
        # placeholders; supply real connection details before use.
        self.client = MilvusClient(
            uri="",
            token="",
        )
        self.client.use_database(database)

    def create_schema(self):
        """Return a fresh schema holding only the auto-id ``INT64`` primary key ``id``."""
        schema = self.client.create_schema()
        schema.add_field(
            field_name="id", datatype=DataType.INT64, is_primary=True, auto_id=True
        )
        return schema

    def create_index(self):
        """Return the client's prepared index parameters with nothing added to them."""
        index_params = self.client.prepare_index_params()
        return index_params

    def create_collection(self, collection_name):
        """Create ``collection_name`` with the default schema unless it already exists.

        Returns:
            ``None`` when the collection already exists, otherwise whatever
            the client's ``create_collection`` returns.
        """
        if self.client.has_collection(collection_name):
            return
        return self.client.create_collection(
            collection_name=collection_name,
            schema=self.create_schema(),
            index_params=self.create_index(),
        )

    def drop_collection(self, collection_name):
        """Drop ``collection_name`` if it exists; do nothing otherwise."""
        if not self.client.has_collection(collection_name):
            return
        return self.client.drop_collection(collection_name=collection_name)

    def insert_row(self, collection_name, data):
        """Insert ``data``, creating the collection first when it is missing.

        Returns:
            The result of the client's ``insert`` call.
        """
        if not self.client.has_collection(collection_name):
            self.create_collection(collection_name)
        return self.client.insert(collection_name=collection_name, data=data)

    def delete_rows(self, collection_name, filter_expr):
        """Delete the rows matching ``filter_expr`` and return the client's result."""
        self.client.load_collection(collection_name=collection_name)
        try:
            return self.client.delete(
                collection_name=collection_name, filter=filter_expr
            )
        finally:
            # Release even when the delete raises.
            self.client.release_collection(collection_name=collection_name)

    def query(self, collection_name, filter, output_fields=None):
        """Run a filtered query and return the matching rows.

        Args:
            collection_name: Collection to query.
            filter: Filter expression forwarded to the client.
            output_fields: Fields to return; a falsy value selects ``["*"]``.
        """
        if not output_fields:
            output_fields = ["*"]
        self.client.load_collection(collection_name=collection_name)
        try:
            return self.client.query(
                collection_name=collection_name,
                filter=filter,
                output_fields=output_fields,
            )
        finally:
            # Release even when the query raises.
            self.client.release_collection(collection_name=collection_name)

    def upsert(self, collection_name, data):
        """Upsert ``data`` into ``collection_name``; returns nothing."""
        self.client.load_collection(collection_name=collection_name)
        try:
            self.client.upsert(collection_name=collection_name, data=data)
        finally:
            # Release even when the upsert raises.
            self.client.release_collection(collection_name=collection_name)

    def close(self):
        """Close the underlying client."""
        self.client.close()
Quick note: it was created for our internal use case.
import logging
import numpy as np
from milvus_manager import MilvusManager
from pymilvus import DataType
logger = logging.getLogger("django")
def get_field_schema(client, new_schema):
    """Build a Milvus collection schema from JSON-style field definitions.

    Args:
        client: Object exposing ``create_schema()``.
        new_schema: Iterable of dicts holding ``add_field`` keyword arguments.
            ``datatype`` is the name of a ``DataType`` member in any letter
            case (e.g. ``"varchar"``) or a ``DataType`` member itself.

    Returns:
        The schema with every requested field, followed by an auto-id
        ``INT64`` primary key named ``id``.

    Raises:
        ValueError: If a field's ``datatype`` is missing or unknown.
        Exception: Anything ``schema.add_field`` raises is logged and
            re-raised unchanged.
    """
    logger.info("Creating new schema...")
    schema = client.create_schema()
    values_by_name = {member.name: member.value for member in DataType}
    for position, field in enumerate(new_schema):
        raw_type = field.get("datatype")
        if isinstance(raw_type, DataType):
            value = raw_type.value
        elif isinstance(raw_type, str):
            value = values_by_name.get(raw_type.upper())
        else:
            # Missing or non-string datatype: report it as invalid instead of
            # failing on ``.upper()``.
            value = None
        # A falsy value means "unknown"; this also rejects any member whose
        # numeric value is 0, as before.
        if not value:
            logger.error(f"Invalid datatype at pos {position} {field}")
            raise ValueError("Invalid datatype. Check logs for more info")
        # Work on a copy so the caller's definition keeps its own datatype.
        params = {**field, "datatype": DataType(value)}
        try:
            schema.add_field(**params)
        except Exception:
            logger.error(f"Invalid parameter at pos {position} {params}")
            raise
    schema.add_field(field_name="id", datatype=DataType.INT64, is_primary=True, auto_id=True)
    return schema
def get_clean_fields(fields):
    """Replace rich values in field descriptions with plain ones, in place.

    For every field dict, ``type`` is replaced by its ``.name`` and, when a
    ``default_value`` is present, it is replaced by the last item of the first
    entry returned by its ``ListFields()``.

    Entries that are already plain are left untouched, so the function can be
    applied to the same list more than once.

    Args:
        fields: List of field dicts, each with a ``type`` key.

    Returns:
        The same list object that was passed in.
    """
    for field in fields:
        field_type = field["type"]
        # A value without ``.name`` (e.g. an already converted string) is kept.
        field["type"] = getattr(field_type, "name", field_type)
        if "default_value" in field:
            list_fields = getattr(field["default_value"], "ListFields", None)
            if list_fields is not None:
                populated = list_fields()
                # NOTE(review): an empty ListFields() is mapped to None rather
                # than raising IndexError -- confirm None is the right marker.
                field["default_value"] = populated[0][-1] if populated else None
    return fields
def get_fields(milvus_manager, collection_name):
    """Return the existing collection's fields, leaving out the ``id`` field.

    Args:
        milvus_manager: Object whose ``client`` exposes ``describe_collection``.
        collection_name: Name of the collection to describe.

    Returns:
        A ``(fields, names)`` tuple: the list of field descriptions whose
        name is not ``id``, and the set of those fields' names.
    """
    description = milvus_manager.client.describe_collection(collection_name)
    kept_fields = []
    kept_names = set()
    for entry in description["fields"]:
        if entry["name"] == "id":
            continue
        kept_fields.append(entry)
        kept_names.add(entry["name"])
    return kept_fields, kept_names
def convert_new_schema(new_schema, output_fields):
    """Add default values to the fields that are new in ``new_schema``.

    Args:
        new_schema: List of field definition dicts (``field_name``,
            ``datatype`` and any further field parameters).
        output_fields: Names of the fields that already exist in the old
            collection; those definitions are passed through untouched.

    Returns:
        A list of field definitions without the ``id`` field. New fields
        without a caller-supplied ``default_value`` get one chosen by
        datatype (and ``max_length`` 65535 for a VARCHAR that has none);
        datatypes without a known default are returned as given. The input
        dicts are not modified.

    Raises:
        ValueError: If a new field has no ``datatype``.
    """
    # Default per lower-cased datatype name; other datatypes get no default.
    defaults_by_type = {
        "varchar": "",
        "bool": True,
        "int8": np.int8(0),
        "int16": np.int16(0),
        "int32": np.int32(0),
        "int64": np.int64(0),
        "float": np.float32(3.14),
        "double": np.float64(3.14),
    }
    fields = []
    for field in new_schema:
        # The primary key is recognised by its field name; the truthy "id"
        # key is still honoured for backward compatibility.
        if field.get("field_name") == "id" or field.get("id"):
            logger.info("Field id detected.. It will be automatically rewritten")
            continue
        if field.get("field_name") in output_fields:
            fields.append(field)
            continue
        typ = field.get("datatype")
        if not typ:
            logger.error(f"Data type not provided terminating .... {field}")
            raise ValueError("Datatype not provided...")
        # ``is not None`` keeps caller defaults such as 0, False or "".
        if field.get("default_value") is not None:
            fields.append(field)
            continue
        key = typ.lower()
        if key in defaults_by_type:
            # Copy before adding keys so the caller's dict stays unchanged.
            field = dict(field)
            field["default_value"] = defaults_by_type[key]
            if key == "varchar" and not field.get("max_length"):
                field["max_length"] = 65535
        fields.append(field)
    return fields
def create_collection(milvus_manager, new_collection, schema):
    """Create ``new_collection`` from ``schema`` with dynamic fields enabled.

    Args:
        milvus_manager: Manager providing ``create_index()`` and ``client``.
        new_collection: Name of the collection to create.
        schema: Collection schema handed to the client.
    """
    logger.info(f"Creating collection {new_collection}")
    creation_kwargs = {
        "collection_name": new_collection,
        "schema": schema,
        "index_params": milvus_manager.create_index(),
        "enable_dynamic_field": True,
    }
    milvus_manager.client.create_collection(**creation_kwargs)
    logger.info(f"Collection {new_collection} created")
def get_collection_iterator(milvus_manager, collection_name, output_fields, batch_size):
    """Load ``collection_name`` and return a batched query iterator over it.

    Args:
        milvus_manager: Manager whose ``client`` runs the query.
        collection_name: Collection to read from.
        output_fields: Field names to fetch for every row.
        batch_size: Batch size forwarded to ``query_iterator``.

    Returns:
        The iterator produced by the client's ``query_iterator`` using the
        filter ``id>0``.
    """
    logger.info(f"Fetching data from {collection_name} with batch size {batch_size}")
    client = milvus_manager.client
    client.load_collection(collection_name)
    iterator_kwargs = dict(
        collection_name=collection_name,
        filter="id>0",
        batch_size=batch_size,
        output_fields=output_fields,
    )
    iterator = client.query_iterator(**iterator_kwargs)
    logger.info("Data fetched")
    return iterator
def insert_into_collection(milvus_manager, new_collection, results):
    """Copy every batch from the ``results`` iterator into ``new_collection``.

    Args:
        milvus_manager: Manager whose ``client`` performs the inserts.
        new_collection: Name of the collection receiving the rows.
        results: Query iterator exposing ``next()`` and ``close()``; an empty
            batch marks the end of the data.

    The ``id`` key is removed from every row before inserting. The iterator
    is closed when copying finishes, including when an insert raises.
    """
    client = milvus_manager.client
    client.load_collection(collection_name=new_collection)
    logger.info(f"Inserting into collection {new_collection}")
    try:
        while True:
            batch = results.next()
            if not batch:
                break
            for row in batch:
                row.pop("id", None)
            client.insert(collection_name=new_collection, data=batch)
    finally:
        # Close the iterator on both the success and the failure path.
        results.close()
    logger.info(f"Insertion into collection {new_collection} completed")
def cleanup(milvus_manager, collection_name, new_collection):
    """Drop the old collection and rename the new one to take its name.

    Args:
        milvus_manager: Manager providing ``drop_collection`` and ``client``.
        collection_name: Name of the old collection; also the final name.
        new_collection: Name of the collection to rename.
    """
    logger.info(f"Dropping old collection {collection_name}")
    milvus_manager.drop_collection(collection_name)
    logger.info(f"Old collection {collection_name} dropped")
    milvus_manager.client.rename_collection(new_collection, collection_name)
    # Log only after the rename call has returned, so the message is not
    # emitted for a rename that failed.
    logger.info(f"Collection {new_collection} renamed")
def alter_collection(milvus_manager, collection_name, new_schema, batch_size):
    """Rebuild ``collection_name`` with ``new_schema`` while keeping its data.

    Steps:
    1) Drop the temporary ``<collection_name>_temp`` collection if it exists
    2) Compare the new field names with the old ones, ignoring the `id` field
    3) Add default values to the new schema
    4) Convert the new schema (json) to a collection schema
    5) Copy the data in batches from the old collection into the temporary one
    6) Drop the old collection and rename the temporary collection

    Only field names are compared, so a schema with the same names as the old
    collection is treated as "no changes".

    Args:
        milvus_manager: Manager wrapping the Milvus client.
        collection_name: Collection to alter.
        new_schema: List of field definition dicts, each with ``field_name``.
        batch_size: Number of rows fetched per batch while copying.

    Errors raised while creating, filling or swapping the collections are
    logged and not re-raised.
    """
    new_collection = f"{collection_name}_temp"
    milvus_manager.drop_collection(new_collection)
    _, old_names = get_fields(milvus_manager, collection_name)
    # The primary key is recreated by the schema builder, so a caller-supplied
    # `id` entry is removed; otherwise it would be added twice.
    new_schema = [field for field in new_schema if field["field_name"] != "id"]
    new_names = {field["field_name"] for field in new_schema}
    # With `id` excluded, equal name sets mean nothing was added or removed.
    if new_names == old_names:
        logger.info("No changes detected terminating...")
        return
    output_fields = list(new_names & old_names)
    new_schema = convert_new_schema(new_schema, output_fields)
    schema = get_field_schema(milvus_manager.client, new_schema)
    results = get_collection_iterator(milvus_manager, collection_name, output_fields, batch_size)
    try:
        create_collection(milvus_manager, new_collection, schema)
        insert_into_collection(milvus_manager, new_collection, results)
        cleanup(milvus_manager, collection_name, new_collection)
    except Exception as e:
        # Logged at error level so a failed migration is not lost among
        # routine info messages.
        logger.error(f"Got a error for collection {collection_name} : {str(e)} ", exc_info=True)
def main(database, collection_name, new_schema, batch_size):
    """The entrypoint: connect to ``database`` and alter ``collection_name``.

    Args:
        database: Database name passed to ``MilvusManager``.
        collection_name: Collection whose schema is changed.
        new_schema: List of field definition dicts for the new schema.
        batch_size: Number of rows copied per batch.
    """
    milvus = MilvusManager(database)
    try:
        alter_collection(milvus, collection_name, new_schema, batch_size)
    finally:
        # Close the client connection whether or not the migration succeeded.
        milvus.close()
After this, I used this function inside a Django management command.
# Add a nullable VARCHAR field named "new_field" (max_length 500) to the
# collection "my_collection".
# NOTE(review): `client` is not defined in this snippet -- presumably a
# MilvusClient instance; confirm against the surrounding text.
client.add_collection_field(
collection_name="my_collection",
field_name="new_field",
datatype=DataType.VARCHAR,
max_length=500,
nullable=True
)
It does not support removing old fields or adding default values.
Thanks for reading! Let's connect: