/*++ Copyright (c) 1989 Microsoft Corporation Module Name: bind.c Abstract: Contains AfdBind for binding an endpoint to a transport address. Author: David Treadwell (davidtr) 25-Feb-1992 Revision History: --*/ #include "afdp.h" NTSTATUS AfdRestartGetAddress ( IN PDEVICE_OBJECT DeviceObject, IN PIRP Irp, IN PVOID Context ); #ifdef ALLOC_PRAGMA #pragma alloc_text( PAGE, AfdBind ) #pragma alloc_text( PAGE, AfdGetAddress ) #pragma alloc_text( PAGEAFD, AfdAreTransportAddressesEqual ) #pragma alloc_text( PAGEAFD, AfdRestartGetAddress ) #endif NTSTATUS AfdBind ( IN PIRP Irp, IN PIO_STACK_LOCATION IrpSp ) /*++ Routine Description: Handles the IOCTL_AFD_BIND IOCTL. Arguments: Irp - Pointer to I/O request packet. IrpSp - pointer to the IO stack location to use for this request. Return Value: NTSTATUS -- Indicates whether the request was successfully queued. --*/ { NTSTATUS status; OBJECT_ATTRIBUTES objectAttributes; IO_STATUS_BLOCK iosb; PTRANSPORT_ADDRESS transportAddress; PTRANSPORT_ADDRESS requestedAddress; ULONG requestedAddressLength; PAFD_ENDPOINT endpoint; PFILE_FULL_EA_INFORMATION ea; ULONG eaBufferLength; PAGED_CODE( ); // // Set up local pointers. // requestedAddress = Irp->AssociatedIrp.SystemBuffer; requestedAddressLength = IrpSp->Parameters.DeviceIoControl.InputBufferLength; endpoint = IrpSp->FileObject->FsContext; ASSERT( IS_AFD_ENDPOINT_TYPE( endpoint ) ); // // Bomb off if this is a helper endpoint. // if ( endpoint->Type == AfdBlockTypeHelper ) { return STATUS_INVALID_PARAMETER; } // // If the client wants a unique address, make sure that there are no // other sockets with this address. ExAcquireResourceExclusive( AfdResource, TRUE ); if ( IrpSp->Parameters.DeviceIoControl.OutputBufferLength != 0 ) { PLIST_ENTRY listEntry; // // Walk the global list of endpoints, // and compare this address againat the address on each endpoint. // for ( listEntry = AfdEndpointListHead.Flink; listEntry != &AfdEndpointListHead; listEntry = listEntry->Flink ) { PAFD_ENDPOINT compareEndpoint; compareEndpoint = CONTAINING_RECORD( listEntry, AFD_ENDPOINT, GlobalEndpointListEntry ); ASSERT( IS_AFD_ENDPOINT_TYPE( compareEndpoint ) ); // // Check whether the endpoint has a local address, whether // the endpoint has been disconnected, and whether the // endpoint is in the process of closing. If any of these // is true, don't compare addresses with this endpoint. // if ( compareEndpoint->LocalAddress != NULL && ( (compareEndpoint->DisconnectMode & (AFD_PARTIAL_DISCONNECT_SEND | AFD_ABORTIVE_DISCONNECT) ) == 0 ) && (compareEndpoint->State != AfdEndpointStateClosing) ) { // // Compare the bits in the endpoint's address and the // address we're attempting to bind to. Note that we // also compare the transport device names on the // endpoints, as it is legal to bind to the same address // on different transports (e.g. bind to same port in // TCP and UDP). We can just compare the transport // device name pointers because unique names are stored // globally. // if ( compareEndpoint->LocalAddressLength == IrpSp->Parameters.DeviceIoControl.InputBufferLength && AfdAreTransportAddressesEqual( compareEndpoint->LocalAddress, compareEndpoint->LocalAddressLength, requestedAddress, requestedAddressLength, FALSE ) && endpoint->TransportInfo == compareEndpoint->TransportInfo ) { // // The addresses are equal. Fail the request. // ExReleaseResource( AfdResource ); Irp->IoStatus.Information = 0; Irp->IoStatus.Status = STATUS_SHARING_VIOLATION; return STATUS_SHARING_VIOLATION; } } } } // // Store the address to which the endpoint is bound. // endpoint->LocalAddress = AFD_ALLOCATE_POOL( NonPagedPool, requestedAddressLength, AFD_LOCAL_ADDRESS_POOL_TAG ); if ( endpoint->LocalAddress == NULL ) { ExReleaseResource( AfdResource ); Irp->IoStatus.Information = 0; Irp->IoStatus.Status = STATUS_INSUFFICIENT_RESOURCES; return STATUS_INSUFFICIENT_RESOURCES; } endpoint->LocalAddressLength = IrpSp->Parameters.DeviceIoControl.InputBufferLength; RtlMoveMemory( endpoint->LocalAddress, requestedAddress, endpoint->LocalAddressLength ); ExReleaseResource( AfdResource ); // // Allocate memory to hold the EA buffer we'll use to specify the // transport address to NtCreateFile. // eaBufferLength = sizeof(FILE_FULL_EA_INFORMATION) - 1 + TDI_TRANSPORT_ADDRESS_LENGTH + 1 + IrpSp->Parameters.DeviceIoControl.InputBufferLength; #if DBG ea = AFD_ALLOCATE_POOL( NonPagedPool, eaBufferLength, AFD_EA_POOL_TAG ); #else ea = AFD_ALLOCATE_POOL( PagedPool, eaBufferLength, AFD_EA_POOL_TAG ); #endif if ( ea == NULL ) { return STATUS_INSUFFICIENT_RESOURCES; } // // Initialize the EA. // ea->NextEntryOffset = 0; ea->Flags = 0; ea->EaNameLength = TDI_TRANSPORT_ADDRESS_LENGTH; ea->EaValueLength = (USHORT)IrpSp->Parameters.DeviceIoControl.InputBufferLength; RtlMoveMemory( ea->EaName, TdiTransportAddress, ea->EaNameLength + 1 ); transportAddress = (PTRANSPORT_ADDRESS)(&ea->EaName[ea->EaNameLength + 1]); RtlMoveMemory( transportAddress, requestedAddress, ea->EaValueLength ); // // Prepare for opening the address object. // InitializeObjectAttributes( &objectAttributes, &endpoint->TransportInfo->TransportDeviceName, OBJ_CASE_INSENSITIVE, // attributes NULL, NULL ); // // Perform the actual open of the address object. // KeAttachProcess( AfdSystemProcess ); status = ZwCreateFile( &endpoint->AddressHandle, GENERIC_READ | GENERIC_WRITE | SYNCHRONIZE, &objectAttributes, &iosb, // returned status information. 0, // block size (unused). 0, // file attributes. FILE_SHARE_READ | FILE_SHARE_WRITE, FILE_CREATE, // create disposition. 0, // create options. ea, eaBufferLength ); AFD_FREE_POOL( ea, AFD_EA_POOL_TAG ); if ( !NT_SUCCESS(status) ) { // // We store the local address in a local before freeing it to // avoid a timing window. // PVOID localAddress = endpoint->LocalAddress; endpoint->LocalAddress = NULL; endpoint->LocalAddressLength = 0; AFD_FREE_POOL( localAddress, AFD_LOCAL_ADDRESS_POOL_TAG ); KeDetachProcess( ); return status; } AfdRecordAddrOpened(); // // Get a pointer to the file object of the address. // status = ObReferenceObjectByHandle( endpoint->AddressHandle, 0L, // DesiredAccess NULL, KernelMode, (PVOID *)&endpoint->AddressFileObject, NULL ); ASSERT( NT_SUCCESS(status) ); AfdRecordAddrRef(); IF_DEBUG(BIND) { KdPrint(( "AfdBind: address file object for endpoint %lx at %lx\n", endpoint, endpoint->AddressFileObject )); } // // Remember the device object to which we need to give requests for // this address object. We can't just use the // fileObject->DeviceObject pointer because there may be a device // attached to the transport protocol. // endpoint->AddressDeviceObject = IoGetRelatedDeviceObject( endpoint->AddressFileObject ); // // Determine whether the TDI provider supports data bufferring. // If the provider doesn't, then we have to do it. // if ( (endpoint->TransportInfo->ProviderInfo.ServiceFlags & TDI_SERVICE_INTERNAL_BUFFERING) != 0 ) { endpoint->TdiBufferring = TRUE; } else { endpoint->TdiBufferring = FALSE; } // // Determine whether the TDI provider is message or stream oriented. // if ( (endpoint->TransportInfo->ProviderInfo.ServiceFlags & TDI_SERVICE_MESSAGE_MODE) != 0 ) { endpoint->TdiMessageMode = TRUE; } else { endpoint->TdiMessageMode = FALSE; } // // Remember that the endpoint has been bound to a transport address. // endpoint->State = AfdEndpointStateBound; // // Set up indication handlers on the address object. Only set up // appropriate event handlers--don't set unnecessary event handlers. // status = AfdSetEventHandler( endpoint->AddressFileObject, TDI_EVENT_ERROR, AfdErrorEventHandler, endpoint ); #if DBG if ( !NT_SUCCESS(status) ) { DbgPrint( "AFD: Setting TDI_EVENT_ERROR failed: %lx\n", status ); } #endif if ( IS_DGRAM_ENDPOINT(endpoint) ) { endpoint->EventsActive = AFD_POLL_SEND; IF_DEBUG(EVENT_SELECT) { KdPrint(( "AfdBind: Endp %08lX, Active %08lX\n", endpoint, endpoint->EventsActive )); } status = AfdSetEventHandler( endpoint->AddressFileObject, TDI_EVENT_RECEIVE_DATAGRAM, AfdReceiveDatagramEventHandler, endpoint ); #if DBG if ( !NT_SUCCESS(status) ) { DbgPrint( "AFD: Setting TDI_EVENT_RECEIVE_DATAGRAM failed: %lx\n", status ); } #endif } else { status = AfdSetEventHandler( endpoint->AddressFileObject, TDI_EVENT_DISCONNECT, AfdDisconnectEventHandler, endpoint ); #if DBG if ( !NT_SUCCESS(status) ) { DbgPrint( "AFD: Setting TDI_EVENT_DISCONNECT failed: %lx\n", status ); } #endif if ( endpoint->TdiBufferring ) { status = AfdSetEventHandler( endpoint->AddressFileObject, TDI_EVENT_RECEIVE, AfdReceiveEventHandler, endpoint ); #if DBG if ( !NT_SUCCESS(status) ) { DbgPrint( "AFD: Setting TDI_EVENT_RECEIVE failed: %lx\n", status ); } #endif status = AfdSetEventHandler( endpoint->AddressFileObject, TDI_EVENT_RECEIVE_EXPEDITED, AfdReceiveExpeditedEventHandler, endpoint ); #if DBG if ( !NT_SUCCESS(status) ) { DbgPrint( "AFD: Setting TDI_EVENT_RECEIVE_EXPEDITED failed: %lx\n", status ); } #endif status = AfdSetEventHandler( endpoint->AddressFileObject, TDI_EVENT_SEND_POSSIBLE, AfdSendPossibleEventHandler, endpoint ); #if DBG if ( !NT_SUCCESS(status) ) { DbgPrint( "AFD: Setting TDI_EVENT_SEND_POSSIBLE failed: %lx\n", status ); } #endif } else { status = AfdSetEventHandler( endpoint->AddressFileObject, TDI_EVENT_RECEIVE, AfdBReceiveEventHandler, endpoint ); #if DBG if ( !NT_SUCCESS(status) ) { DbgPrint( "AFD: Setting TDI_EVENT_RECEIVE failed: %lx\n", status ); } #endif // // Only attempt to set the expedited event handler if the // TDI provider supports expedited data. // if ( (endpoint->TransportInfo->ProviderInfo.ServiceFlags & TDI_SERVICE_EXPEDITED_DATA) != 0 ) { status = AfdSetEventHandler( endpoint->AddressFileObject, TDI_EVENT_RECEIVE_EXPEDITED, AfdBReceiveExpeditedEventHandler, endpoint ); #if DBG if ( !NT_SUCCESS(status) ) { DbgPrint( "AFD: Setting TDI_EVENT_RECEIVE_EXPEDITED failed: %lx\n", status ); } #endif } } } KeDetachProcess( ); Irp->IoStatus.Information = 0; return STATUS_SUCCESS; } // AfdBind NTSTATUS AfdGetAddress ( IN PIRP Irp, IN PIO_STACK_LOCATION IrpSp ) /*++ Routine Description: Handles the IOCTL_AFD_BIND IOCTL. Arguments: Irp - Pointer to I/O request packet. IrpSp - pointer to the IO stack location to use for this request. Return Value: NTSTATUS -- Indicates whether the request was successfully queued. --*/ { NTSTATUS status; PAFD_ENDPOINT endpoint; PFILE_OBJECT fileObject; PDEVICE_OBJECT deviceObject; PAGED_CODE( ); Irp->IoStatus.Information = 0; // // Make sure that the endpoint is in the correct state. // endpoint = IrpSp->FileObject->FsContext; ASSERT( IS_AFD_ENDPOINT_TYPE( endpoint ) ); if ( endpoint->AddressFileObject == NULL && endpoint->State != AfdEndpointStateConnected ) { status = STATUS_INVALID_PARAMETER; goto complete; } // // If the endpoint is connected, use the connection's file object. // Otherwise, use the address file object. Don't use the connection // file object if this is a Netbios endpoint because NETBT cannot // support this TDI feature. // if ( endpoint->Type == AfdBlockTypeVcConnecting && endpoint->Common.VcConnecting.Connection != NULL && endpoint->LocalAddress->Address[0].AddressType != TDI_ADDRESS_TYPE_NETBIOS ) { ASSERT( endpoint->Common.VcConnecting.Connection->Type == AfdBlockTypeConnection ); fileObject = endpoint->Common.VcConnecting.Connection->FileObject; deviceObject = endpoint->Common.VcConnecting.Connection->DeviceObject; } else { fileObject = endpoint->AddressFileObject; deviceObject = endpoint->AddressDeviceObject; } // // Set up the query info to the TDI provider. // ASSERT( Irp->MdlAddress != NULL ); TdiBuildQueryInformation( Irp, deviceObject, fileObject, AfdRestartGetAddress, endpoint, TDI_QUERY_ADDRESS_INFO, Irp->MdlAddress ); // // Call the TDI provider to get the address. // return AfdIoCallDriver( endpoint, deviceObject, Irp ); complete: Irp->IoStatus.Status = status; IoCompleteRequest( Irp, AfdPriorityBoost ); return status; } // AfdGetAddress NTSTATUS AfdRestartGetAddress ( IN PDEVICE_OBJECT DeviceObject, IN PIRP Irp, IN PVOID Context ) { NTSTATUS status; PAFD_ENDPOINT endpoint = Context; KIRQL oldIrql; PMDL mdl; ULONG addressLength; // // If the request succeeded, save the address in the endpoint so // we can use it to handle address sharing. // if ( NT_SUCCESS(Irp->IoStatus.Status) ) { // // First determine the length of the address by walking the MDL // chain. // mdl = Irp->MdlAddress; ASSERT( mdl != NULL ); addressLength = 0; do { addressLength += MmGetMdlByteCount( mdl ); mdl = mdl->Next; } while ( mdl != NULL ); AfdAcquireSpinLock( &AfdSpinLock, &oldIrql ); // // If the new address is longer than the original address, allocate // a new local address buffer. The +4 accounts for the ActivityCount // field that is returned by a query address but is not part // of a TRANSPORT_ADDRESS. // if ( addressLength > endpoint->LocalAddressLength + 4 ) { PVOID newAddress; newAddress = AFD_ALLOCATE_POOL( NonPagedPool, addressLength-4, AFD_LOCAL_ADDRESS_POOL_TAG ); if ( newAddress == NULL ) { AfdReleaseSpinLock( &AfdSpinLock, oldIrql ); return STATUS_INSUFFICIENT_RESOURCES; } AFD_FREE_POOL( endpoint->LocalAddress, AFD_LOCAL_ADDRESS_POOL_TAG ); endpoint->LocalAddress = newAddress; endpoint->LocalAddressLength = addressLength-4; } status = TdiCopyMdlToBuffer( Irp->MdlAddress, 4, endpoint->LocalAddress, 0, endpoint->LocalAddressLength, &endpoint->LocalAddressLength ); ASSERT( NT_SUCCESS(status) ); AfdReleaseSpinLock( &AfdSpinLock, oldIrql ); } AfdCompleteOutstandingIrp( endpoint, Irp ); // // If pending has been returned for this irp then mark the current // stack as pending. // if ( Irp->PendingReturned ) { IoMarkIrpPending( Irp ); } return STATUS_SUCCESS; } // AfdRestartGetAddress CHAR ZeroNodeAddress[6]; BOOLEAN AfdAreTransportAddressesEqual ( IN PTRANSPORT_ADDRESS EndpointAddress, IN ULONG EndpointAddressLength, IN PTRANSPORT_ADDRESS RequestAddress, IN ULONG RequestAddressLength, IN BOOLEAN HonorWildcardIpPortInEndpointAddress ) { if ( EndpointAddress->Address[0].AddressType == TDI_ADDRESS_TYPE_IP && RequestAddress->Address[0].AddressType == TDI_ADDRESS_TYPE_IP ) { TDI_ADDRESS_IP UNALIGNED *ipEndpointAddress; TDI_ADDRESS_IP UNALIGNED *ipRequestAddress; // // They are both IP addresses. If the ports are the same, and // the IP addresses are or _could_be_ the same, then the addresses // are equal. The "cound be" part is true if either IP address // is 0, the "wildcard" IP address. // ipEndpointAddress = (TDI_ADDRESS_IP UNALIGNED *)&EndpointAddress->Address[0].Address[0]; ipRequestAddress = (TDI_ADDRESS_IP UNALIGNED *)&RequestAddress->Address[0].Address[0]; if ( ( ipEndpointAddress->sin_port == ipRequestAddress->sin_port || ( HonorWildcardIpPortInEndpointAddress && ipEndpointAddress->sin_port == 0 ) ) && ( ipEndpointAddress->in_addr == ipRequestAddress->in_addr || ipEndpointAddress->in_addr == 0 || ipRequestAddress->in_addr == 0 ) ) { return TRUE; } // // The addresses are not equal. // return FALSE; } if ( EndpointAddress->Address[0].AddressType == TDI_ADDRESS_TYPE_IPX && RequestAddress->Address[0].AddressType == TDI_ADDRESS_TYPE_IPX ) { TDI_ADDRESS_IPX UNALIGNED *ipxEndpointAddress; TDI_ADDRESS_IPX UNALIGNED *ipxRequestAddress; ipxEndpointAddress = (TDI_ADDRESS_IPX UNALIGNED *)&EndpointAddress->Address[0].Address[0]; ipxRequestAddress = (TDI_ADDRESS_IPX UNALIGNED *)&RequestAddress->Address[0].Address[0]; // // They are both IPX addresses. Check the network addresses // first--if they don't match and both != 0, the addresses // are different. // if ( ipxEndpointAddress->NetworkAddress != ipxRequestAddress->NetworkAddress && ipxEndpointAddress->NetworkAddress != 0 && ipxRequestAddress->NetworkAddress != 0 ) { return FALSE; } // // Now check the node addresses. Again, if they don't match // and neither is 0, the addresses don't match. // ASSERT( ZeroNodeAddress[0] == 0 ); ASSERT( ZeroNodeAddress[1] == 0 ); ASSERT( ZeroNodeAddress[2] == 0 ); ASSERT( ZeroNodeAddress[3] == 0 ); ASSERT( ZeroNodeAddress[4] == 0 ); ASSERT( ZeroNodeAddress[5] == 0 ); if ( !RtlEqualMemory( ipxEndpointAddress->NodeAddress, ipxRequestAddress->NodeAddress, 6 ) && !RtlEqualMemory( ipxEndpointAddress->NodeAddress, ZeroNodeAddress, 6 ) && !RtlEqualMemory( ipxRequestAddress->NodeAddress, ZeroNodeAddress, 6 ) ) { return FALSE; } // // Finally, make sure the socket numbers match. // if ( ipxEndpointAddress->Socket != ipxRequestAddress->Socket ) { return FALSE; } return TRUE; } // // If either address is not of a known address type, then do a // simple memory compare. // return ( EndpointAddressLength == RtlCompareMemory( EndpointAddress, RequestAddress, RequestAddressLength ) ); } // AfdAreTransportAddressesEqual